master v1.5.8
Stanly 4 weeks ago
parent 826dee2e35
commit 26602e7552

@ -4,70 +4,103 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"github.com/Luzifer/go-openssl/v3"
)
func paddingPKCS7(ciphertext []byte, blockSize int) []byte {
func addPKCS7Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func unpaddingPKCS7(origData []byte) []byte {
length := len(origData)
unpadding := int(origData[length-1])
return origData[:(length - unpadding)]
func removePKCS7Padding(data []byte) ([]byte, error) {
if len(data) == 0 {
return nil, fmt.Errorf("data cannot be empty")
}
padding := int(data[len(data)-1])
if padding > len(data) {
return nil, fmt.Errorf("invalid padding size")
}
return data[:len(data)-padding], nil
}
//EncryptAES 加密函式
// EncryptAES 加密函式
func EncryptAES(plaintext, key, iv []byte) ([]byte, error) {
// 驗證密鑰長度是否符合 AES 的要求
if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
}
// 創建 AES 密碼器
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
// 確保輸入的資料長度符合 block size並進行填充
blockSize := block.BlockSize()
plaintext = paddingPKCS7(plaintext, blockSize)
plaintext = addPKCS7Padding(plaintext, blockSize)
// 創建 CBC 模式加密器
blockMode := cipher.NewCBCEncrypter(block, iv)
crypted := make([]byte, len(plaintext))
blockMode.CryptBlocks(crypted, plaintext)
return crypted, nil
// 用來保存加密結果的緩衝區
enc := make([]byte, len(plaintext))
blockMode.CryptBlocks(enc, plaintext)
return enc, nil
}
// DecryptAES 解密函式
// DecryptAES 使用 AES CBC 模式進行解密
func DecryptAES(ciphertext, key, iv []byte) ([]byte, error) {
// 驗證密鑰長度是否符合 AES 的要求
if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
}
// 創建 AES 密碼器
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
// 確保輸入的資料長度符合 block size並進行解密
blockSize := block.BlockSize()
blockMode := cipher.NewCBCDecrypter(block, iv[:blockSize])
// 用來保存解密結果的緩衝區
origData := make([]byte, len(ciphertext))
blockMode.CryptBlocks(origData, ciphertext)
origData = unpaddingPKCS7(origData)
return origData, nil
// 去除 PKCS7 填充
return removePKCS7Padding(origData)
}
// DecryptAESWithOpenSSL 使用Openssl 解密AES
func DecryptAESWithOpenSSL(value, key string) ([]byte, error) {
// EncryptAESWithOpenSSL 使用Openssl 加密AES
func EncryptAESWithOpenSSL(value, key string) ([]byte, error) {
o := openssl.New()
dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum)
enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to encrypt with OpenSSL: %w", err)
}
return dec, nil
return enc, nil
}
// EncryptAESWithOpenSSL 使用Openssl 加密AES
func EncryptAESWithOpenSSL(value, key string) ([]byte, error) {
// DecryptAESWithOpenSSL 使用Openssl 解密AES
func DecryptAESWithOpenSSL(value, key string) ([]byte, error) {
o := openssl.New()
enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum)
dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to decrypt with OpenSSL: %w", err)
}
return enc, nil
return dec, nil
}

@ -9,24 +9,15 @@ import (
// MD5 回傳md5加密
func MD5(v string) string {
h := md5.New()
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
return fmt.Sprintf("%x", md5.Sum([]byte(v)))
}
// SHA1 回傳sha1加密
func SHA1(v string) string {
h := sha1.New()
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
return fmt.Sprintf("%x", sha1.Sum([]byte(v)))
}
// SHA256 回傳sha256加密
func SHA256(v string) string {
h := sha256.New()
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
return fmt.Sprintf("%x", sha256.Sum256([]byte(v)))
}

@ -1,6 +1,8 @@
package crypto
import (
"fmt"
"golang.org/x/crypto/bcrypt"
)
@ -8,7 +10,7 @@ import (
func EncryptPassword(pwd string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil {
return "", err
return "", fmt.Errorf("failed to encrypt password: %w", err)
}
return string(hash), nil
@ -16,5 +18,12 @@ func EncryptPassword(pwd string) (string, error) {
// CheckPassword 檢查密碼
func CheckPassword(pwd, hash string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
if err != nil {
// 返回帶有更多上下文信息的錯誤
return fmt.Errorf("password does not match: %w", err)
}
// 如果密碼正確,返回 nil
return nil
}

@ -6,19 +6,20 @@ import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
)
func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
// Generate RSA key.
// 生成 RSA 密鑰對
key, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("failed to generate RSA key: %w", err)
}
// Extract public component.
// 提取公鑰部分並轉換為 PEM 格式
pubBytes, err := x509.MarshalPKIXPublicKey(key.Public())
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubPEM = pem.EncodeToMemory(
&pem.Block{
@ -27,7 +28,7 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
},
)
// Encode private key to PKCS#1 ASN.1 PEM.
// 編碼私鑰為 PKCS#1 ASN.1 PEM 格式
keyPEM = pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
@ -42,29 +43,44 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
func EncryptRSA(value, publicKey []byte) ([]byte, error) {
block, _ := pem.Decode(publicKey)
if block == nil {
return nil, errors.New("public key error")
return nil, errors.New("failed to decode public key PEM")
}
pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
pub := pubInterface.(*rsa.PublicKey)
return rsa.EncryptPKCS1v15(rand.Reader, pub, value)
pub, ok := pubInterface.(*rsa.PublicKey)
if !ok {
return nil, errors.New("invalid public key type")
}
// 使用公鑰進行加密
enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, value)
if err != nil {
return nil, fmt.Errorf("failed to encrypt with RSA: %w", err)
}
return enc, nil
}
// DecryptRSA rsa解密
func DecryptRSA(ciphertext, privateKey []byte) ([]byte, error) {
block, _ := pem.Decode(privateKey)
if block == nil {
return nil, errors.New("private key error")
return nil, errors.New("failed to decode private key PEM")
}
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
dec, err := rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt with RSA: %w", err)
}
return rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext)
return dec, nil
}

@ -2,52 +2,37 @@ package generate
import (
"crypto/rand"
"math/big"
"fmt"
)
// GetRandomString 取得隨機字串
func GetRandomString(n int) (string, error) {
const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
buffer := make([]byte, n)
max := big.NewInt(int64(len(alphaNum)))
for i := 0; i < n; i++ {
index, err := randomInt(max)
if err != nil {
return "", err
}
buffer[i] = alphaNum[index]
}
return string(buffer), nil
return getRandomFromCharset(n, alphaNum)
}
// GetRandomKey 取得隨機金鑰
func GetRandomKey(n int) (string, error) {
const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz~!@#$%^&*()-_=+"
return getRandomFromCharset(n, alphaNum)
}
// getRandomFromCharset 通用隨機字串生成器
func getRandomFromCharset(n int, charset string) (string, error) {
// 預先計算字符集的長度
charsLen := len(charset)
buffer := make([]byte, n)
max := big.NewInt(int64(len(alphaNum)))
for i := 0; i < n; i++ {
index, err := randomInt(max)
if err != nil {
return "", err
}
buffer[i] = alphaNum[index]
// 使用 crypto/rand 生成隨機數據
_, err := rand.Read(buffer)
if err != nil {
return "", fmt.Errorf("failed to generate random data: %w", err)
}
return string(buffer), nil
}
func randomInt(max *big.Int) (int, error) {
random, err := rand.Int(rand.Reader, max)
if err != nil {
return 0, err
// 從字元集選擇字符
for i := 0; i < n; i++ {
buffer[i] = charset[int(buffer[i])%charsLen]
}
return int(random.Int64()), nil
return string(buffer), nil
}

@ -6,7 +6,6 @@ import (
"image/jpeg"
"image/png"
"io"
"io/ioutil"
"mime/multipart"
"os"
)
@ -14,31 +13,42 @@ import (
// CompressImage 壓縮圖片至指定位置
func CompressImage(input interface{}, output string) (err error) {
var file io.Reader
switch input.(type) {
var fileToClose io.Closer
switch v := input.(type) {
case string:
if file, err = os.Open(input.(string)); err != nil {
f, err := os.Open(v)
if err != nil {
return err
}
file = f
fileToClose = f
defer func() {
_ = file.(*os.File).Close()
if fileToClose != nil {
_ = fileToClose.Close()
}
_ = os.Remove(v)
}()
defer os.Remove(input.(string))
case *multipart.FileHeader:
fileHeader := input.(*multipart.FileHeader)
file, err = fileHeader.Open()
f, err := v.Open()
if err != nil {
return err
}
file = f
fileToClose = f
defer func() {
_ = file.(multipart.File).Close()
if fileToClose != nil {
_ = fileToClose.Close()
}
}()
}
bs, err := ioutil.ReadAll(file)
if err != nil {
var buf bytes.Buffer
if _, err = io.Copy(&buf, file); err != nil {
return err
}
bs := buf.Bytes()
var ext string
if _, err = jpeg.Decode(bytes.NewReader(bs)); err == nil {
@ -53,7 +63,11 @@ func CompressImage(input interface{}, output string) (err error) {
if err != nil {
return err
}
defer outputFile.Close()
defer func() {
if outputFile != nil {
_ = outputFile.Close()
}
}()
switch ext {
case "jpg":
@ -61,34 +75,22 @@ func CompressImage(input interface{}, output string) (err error) {
if err != nil {
return err
}
if err = jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70}); err != nil {
return err
}
return jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70})
case "png":
img, err := png.Decode(bytes.NewReader(bs))
if err != nil {
return err
}
encoder := png.Encoder{CompressionLevel: png.BestCompression}
if err = encoder.Encode(outputFile, img); err != nil {
return err
}
return encoder.Encode(outputFile, img)
case "gif":
img, err := gif.DecodeAll(bytes.NewReader(bs))
if err != nil {
return err
}
if err = gif.EncodeAll(outputFile, img); err != nil {
return err
}
return gif.EncodeAll(outputFile, img)
default:
if _, err = outputFile.Write(bs); err != nil {
return err
}
_, err = outputFile.Write(bs)
return err
}
return nil
}

@ -12,15 +12,14 @@ import (
// Encode jwt編碼
func Encode(values types.Data, key string) (string, error) {
claims := jwt.MapClaims{}
for key, value := range values {
claims[key] = value
for k, v := range values {
claims[k] = v
}
tokenString, err := jwt.
NewWithClaims(jwt.SigningMethodHS256, claims).
tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
SignedString([]byte(key))
if err != nil {
return "", err
return "", fmt.Errorf("failed to sign JWT with HS256: %w", err)
}
return tokenString, nil
@ -29,15 +28,14 @@ func Encode(values types.Data, key string) (string, error) {
// EncodeRS256 jwt編碼
func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
claims := jwt.MapClaims{}
for key, value := range values {
claims[key] = value
for k, v := range values {
claims[k] = v
}
tokenString, err := jwt.
NewWithClaims(jwt.SigningMethodRS256, claims).
tokenString, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).
SignedString(key)
if err != nil {
return "", err
return "", fmt.Errorf("failed to sign JWT with RS256: %w", err)
}
return tokenString, nil
@ -47,22 +45,27 @@ func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
func Decode(tokenString string, key string) (types.Data, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected error: %v. ", token.Header["alg"])
return nil, fmt.Errorf("unexpected signing method: %w", token.Header["alg"])
}
return []byte(key), nil
})
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}
if token == nil || !token.Valid {
return nil, fmt.Errorf("invalid token or claims")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return nil, fmt.Errorf("Token error. ")
if !ok {
return nil, fmt.Errorf("failed to parse token claims")
}
return types.Data(claims), nil
}
// IsExpired 檢查 token 是否過期
func IsExpired(err error) bool {
return errors.Is(err, jwt.ErrTokenExpired)
}

Loading…
Cancel
Save