diff --git a/crypto/aes.go b/crypto/aes.go index 1bd1d89..17f60e4 100644 --- a/crypto/aes.go +++ b/crypto/aes.go @@ -4,70 +4,103 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "fmt" "github.com/Luzifer/go-openssl/v3" ) -func paddingPKCS7(ciphertext []byte, blockSize int) []byte { +func addPKCS7Padding(ciphertext []byte, blockSize int) []byte { padding := blockSize - len(ciphertext)%blockSize - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - return append(ciphertext, padtext...) + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padText...) } -func unpaddingPKCS7(origData []byte) []byte { - length := len(origData) - unpadding := int(origData[length-1]) - return origData[:(length - unpadding)] +func removePKCS7Padding(data []byte) ([]byte, error) { + if len(data) == 0 { + return nil, fmt.Errorf("data cannot be empty") + } + + padding := int(data[len(data)-1]) + + if padding > len(data) { + return nil, fmt.Errorf("invalid padding size") + } + + return data[:len(data)-padding], nil } -//EncryptAES 加密函式 +// EncryptAES 加密函式 func EncryptAES(plaintext, key, iv []byte) ([]byte, error) { + // 驗證密鑰長度是否符合 AES 的要求 + if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 { + return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key)) + } + + // 創建 AES 密碼器 block, err := aes.NewCipher(key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create AES cipher: %w", err) } + + // 確保輸入的資料長度符合 block size,並進行填充 blockSize := block.BlockSize() - plaintext = paddingPKCS7(plaintext, blockSize) + plaintext = addPKCS7Padding(plaintext, blockSize) + + // 創建 CBC 模式加密器 blockMode := cipher.NewCBCEncrypter(block, iv) - crypted := make([]byte, len(plaintext)) - blockMode.CryptBlocks(crypted, plaintext) - return crypted, nil + + // 用來保存加密結果的緩衝區 + enc := make([]byte, len(plaintext)) + blockMode.CryptBlocks(enc, plaintext) + + return enc, nil } -// DecryptAES 解密函式 +// DecryptAES 使用 AES CBC 模式進行解密 func DecryptAES(ciphertext, key, iv []byte) ([]byte, error) { + // 驗證密鑰長度是否符合 AES 的要求 + if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 { + return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key)) + } + + // 創建 AES 密碼器 block, err := aes.NewCipher(key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create AES cipher: %w", err) } + + // 確保輸入的資料長度符合 block size,並進行解密 blockSize := block.BlockSize() blockMode := cipher.NewCBCDecrypter(block, iv[:blockSize]) + + // 用來保存解密結果的緩衝區 origData := make([]byte, len(ciphertext)) blockMode.CryptBlocks(origData, ciphertext) - origData = unpaddingPKCS7(origData) - return origData, nil + + // 去除 PKCS7 填充 + return removePKCS7Padding(origData) } -// DecryptAESWithOpenSSL 使用Openssl 解密AES -func DecryptAESWithOpenSSL(value, key string) ([]byte, error) { +// EncryptAESWithOpenSSL 使用Openssl 加密AES +func EncryptAESWithOpenSSL(value, key string) ([]byte, error) { o := openssl.New() - dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum) + enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to encrypt with OpenSSL: %w", err) } - return dec, nil + return enc, nil } -// EncryptAESWithOpenSSL 使用Openssl 加密AES -func EncryptAESWithOpenSSL(value, key string) ([]byte, error) { +// DecryptAESWithOpenSSL 使用Openssl 解密AES +func DecryptAESWithOpenSSL(value, key string) ([]byte, error) { o := openssl.New() - enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum) + dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to decrypt with OpenSSL: %w", err) } - return enc, nil + return dec, nil } diff --git a/crypto/crypto.go b/crypto/crypto.go index 1d7a569..7bdf019 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -9,24 +9,15 @@ import ( // MD5 回傳md5加密 func MD5(v string) string { - h := md5.New() - h.Write([]byte(v)) - bs := h.Sum(nil) - return fmt.Sprintf("%x", bs) + return fmt.Sprintf("%x", md5.Sum([]byte(v))) } // SHA1 回傳sha1加密 func SHA1(v string) string { - h := sha1.New() - h.Write([]byte(v)) - bs := h.Sum(nil) - return fmt.Sprintf("%x", bs) + return fmt.Sprintf("%x", sha1.Sum([]byte(v))) } // SHA256 回傳sha256加密 func SHA256(v string) string { - h := sha256.New() - h.Write([]byte(v)) - bs := h.Sum(nil) - return fmt.Sprintf("%x", bs) + return fmt.Sprintf("%x", sha256.Sum256([]byte(v))) } diff --git a/crypto/password.go b/crypto/password.go index 1ab8d4b..71c82b0 100644 --- a/crypto/password.go +++ b/crypto/password.go @@ -1,6 +1,8 @@ package crypto import ( + "fmt" + "golang.org/x/crypto/bcrypt" ) @@ -8,7 +10,7 @@ import ( func EncryptPassword(pwd string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost) if err != nil { - return "", err + return "", fmt.Errorf("failed to encrypt password: %w", err) } return string(hash), nil @@ -16,5 +18,12 @@ func EncryptPassword(pwd string) (string, error) { // CheckPassword 檢查密碼 func CheckPassword(pwd, hash string) error { - return bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd)) + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd)) + if err != nil { + // 返回帶有更多上下文信息的錯誤 + return fmt.Errorf("password does not match: %w", err) + } + + // 如果密碼正確,返回 nil + return nil } diff --git a/crypto/rsa.go b/crypto/rsa.go index 1266c30..53ef35f 100644 --- a/crypto/rsa.go +++ b/crypto/rsa.go @@ -6,19 +6,20 @@ import ( "crypto/x509" "encoding/pem" "errors" + "fmt" ) func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) { - // Generate RSA key. + // 生成 RSA 密鑰對 key, err := rsa.GenerateKey(rand.Reader, bitSize) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to generate RSA key: %w", err) } - // Extract public component. + // 提取公鑰部分並轉換為 PEM 格式 pubBytes, err := x509.MarshalPKIXPublicKey(key.Public()) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) } pubPEM = pem.EncodeToMemory( &pem.Block{ @@ -27,7 +28,7 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) { }, ) - // Encode private key to PKCS#1 ASN.1 PEM. + // 編碼私鑰為 PKCS#1 ASN.1 PEM 格式 keyPEM = pem.EncodeToMemory( &pem.Block{ Type: "RSA PRIVATE KEY", @@ -42,29 +43,44 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) { func EncryptRSA(value, publicKey []byte) ([]byte, error) { block, _ := pem.Decode(publicKey) if block == nil { - return nil, errors.New("public key error") + return nil, errors.New("failed to decode public key PEM") } pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - pub := pubInterface.(*rsa.PublicKey) - return rsa.EncryptPKCS1v15(rand.Reader, pub, value) + pub, ok := pubInterface.(*rsa.PublicKey) + if !ok { + return nil, errors.New("invalid public key type") + } + + // 使用公鑰進行加密 + enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, value) + if err != nil { + return nil, fmt.Errorf("failed to encrypt with RSA: %w", err) + } + + return enc, nil } // DecryptRSA rsa解密 func DecryptRSA(ciphertext, privateKey []byte) ([]byte, error) { block, _ := pem.Decode(privateKey) if block == nil { - return nil, errors.New("private key error") + return nil, errors.New("failed to decode private key PEM") } priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + dec, err := rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to decrypt with RSA: %w", err) } - return rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext) + return dec, nil } diff --git a/generate/generate.go b/generate/generate.go index 4e7bff7..1a8703b 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -2,52 +2,37 @@ package generate import ( "crypto/rand" - "math/big" + "fmt" ) // GetRandomString 取得隨機字串 func GetRandomString(n int) (string, error) { const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - - buffer := make([]byte, n) - max := big.NewInt(int64(len(alphaNum))) - - for i := 0; i < n; i++ { - index, err := randomInt(max) - if err != nil { - return "", err - } - - buffer[i] = alphaNum[index] - } - - return string(buffer), nil + return getRandomFromCharset(n, alphaNum) } // GetRandomKey 取得隨機金鑰 func GetRandomKey(n int) (string, error) { const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz~!@#$%^&*()-_=+" + return getRandomFromCharset(n, alphaNum) +} +// getRandomFromCharset 通用隨機字串生成器 +func getRandomFromCharset(n int, charset string) (string, error) { + // 預先計算字符集的長度 + charsLen := len(charset) buffer := make([]byte, n) - max := big.NewInt(int64(len(alphaNum))) - for i := 0; i < n; i++ { - index, err := randomInt(max) - if err != nil { - return "", err - } - - buffer[i] = alphaNum[index] + // 使用 crypto/rand 生成隨機數據 + _, err := rand.Read(buffer) + if err != nil { + return "", fmt.Errorf("failed to generate random data: %w", err) } - return string(buffer), nil -} - -func randomInt(max *big.Int) (int, error) { - random, err := rand.Int(rand.Reader, max) - if err != nil { - return 0, err + // 從字元集選擇字符 + for i := 0; i < n; i++ { + buffer[i] = charset[int(buffer[i])%charsLen] } - return int(random.Int64()), nil + return string(buffer), nil } diff --git a/image/image.go b/image/image.go index 2d792be..6aa33f1 100644 --- a/image/image.go +++ b/image/image.go @@ -6,7 +6,6 @@ import ( "image/jpeg" "image/png" "io" - "io/ioutil" "mime/multipart" "os" ) @@ -14,31 +13,42 @@ import ( // CompressImage 壓縮圖片至指定位置 func CompressImage(input interface{}, output string) (err error) { var file io.Reader - switch input.(type) { + var fileToClose io.Closer + + switch v := input.(type) { case string: - if file, err = os.Open(input.(string)); err != nil { + f, err := os.Open(v) + if err != nil { return err } + + file = f + fileToClose = f defer func() { - _ = file.(*os.File).Close() + if fileToClose != nil { + _ = fileToClose.Close() + } + _ = os.Remove(v) }() - defer os.Remove(input.(string)) - case *multipart.FileHeader: - fileHeader := input.(*multipart.FileHeader) - file, err = fileHeader.Open() + f, err := v.Open() if err != nil { return err } + file = f + fileToClose = f defer func() { - _ = file.(multipart.File).Close() + if fileToClose != nil { + _ = fileToClose.Close() + } }() } - bs, err := ioutil.ReadAll(file) - if err != nil { + var buf bytes.Buffer + if _, err = io.Copy(&buf, file); err != nil { return err } + bs := buf.Bytes() var ext string if _, err = jpeg.Decode(bytes.NewReader(bs)); err == nil { @@ -53,7 +63,11 @@ func CompressImage(input interface{}, output string) (err error) { if err != nil { return err } - defer outputFile.Close() + defer func() { + if outputFile != nil { + _ = outputFile.Close() + } + }() switch ext { case "jpg": @@ -61,34 +75,22 @@ func CompressImage(input interface{}, output string) (err error) { if err != nil { return err } - if err = jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70}); err != nil { - return err - } - + return jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70}) case "png": img, err := png.Decode(bytes.NewReader(bs)) if err != nil { return err } encoder := png.Encoder{CompressionLevel: png.BestCompression} - if err = encoder.Encode(outputFile, img); err != nil { - return err - } - + return encoder.Encode(outputFile, img) case "gif": img, err := gif.DecodeAll(bytes.NewReader(bs)) if err != nil { return err } - if err = gif.EncodeAll(outputFile, img); err != nil { - return err - } - + return gif.EncodeAll(outputFile, img) default: - if _, err = outputFile.Write(bs); err != nil { - return err - } + _, err = outputFile.Write(bs) + return err } - - return nil } diff --git a/jwt/jwt.go b/jwt/jwt.go index 00a45a9..22d3a8d 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -12,15 +12,14 @@ import ( // Encode jwt編碼 func Encode(values types.Data, key string) (string, error) { claims := jwt.MapClaims{} - for key, value := range values { - claims[key] = value + for k, v := range values { + claims[k] = v } - tokenString, err := jwt. - NewWithClaims(jwt.SigningMethodHS256, claims). + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims). SignedString([]byte(key)) if err != nil { - return "", err + return "", fmt.Errorf("failed to sign JWT with HS256: %w", err) } return tokenString, nil @@ -29,15 +28,14 @@ func Encode(values types.Data, key string) (string, error) { // EncodeRS256 jwt編碼 func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) { claims := jwt.MapClaims{} - for key, value := range values { - claims[key] = value + for k, v := range values { + claims[k] = v } - tokenString, err := jwt. - NewWithClaims(jwt.SigningMethodRS256, claims). + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims). SignedString(key) if err != nil { - return "", err + return "", fmt.Errorf("failed to sign JWT with RS256: %w", err) } return tokenString, nil @@ -47,22 +45,27 @@ func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) { func Decode(tokenString string, key string) (types.Data, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected error: %v. ", token.Header["alg"]) + return nil, fmt.Errorf("unexpected signing method: %w", token.Header["alg"]) } return []byte(key), nil }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + if token == nil || !token.Valid { + return nil, fmt.Errorf("invalid token or claims") } claims, ok := token.Claims.(jwt.MapClaims) - if !ok || !token.Valid { - return nil, fmt.Errorf("Token error. ") + if !ok { + return nil, fmt.Errorf("failed to parse token claims") } return types.Data(claims), nil } +// IsExpired 檢查 token 是否過期 func IsExpired(err error) bool { return errors.Is(err, jwt.ErrTokenExpired) }