From 26602e75529b4ab242f2bc0a3ebafb615c40cbbe Mon Sep 17 00:00:00 2001
From: Stanly <eashingliaw@gmail.com>
Date: Mon, 10 Mar 2025 01:51:18 +0800
Subject: [PATCH] Optimize

---
 crypto/aes.go        | 87 ++++++++++++++++++++++++++++++++++++----------------
 crypto/crypto.go     | 15 ++-------
 crypto/password.go   | 13 ++++++--
 crypto/rsa.go        | 40 ++++++++++++++++--------
 generate/generate.go | 47 ++++++++++------------------
 image/image.go       | 60 ++++++++++++++++++------------------
 jwt/jwt.go           | 31 ++++++++++---------
 7 files changed, 166 insertions(+), 127 deletions(-)

diff --git a/crypto/aes.go b/crypto/aes.go
index 1bd1d89..17f60e4 100644
--- a/crypto/aes.go
+++ b/crypto/aes.go
@@ -4,70 +4,103 @@ import (
 	"bytes"
 	"crypto/aes"
 	"crypto/cipher"
+	"fmt"
 
 	"github.com/Luzifer/go-openssl/v3"
 )
 
-func paddingPKCS7(ciphertext []byte, blockSize int) []byte {
+func addPKCS7Padding(ciphertext []byte, blockSize int) []byte {
 	padding := blockSize - len(ciphertext)%blockSize
-	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
-	return append(ciphertext, padtext...)
+	padText := bytes.Repeat([]byte{byte(padding)}, padding)
+	return append(ciphertext, padText...)
 }
 
-func unpaddingPKCS7(origData []byte) []byte {
-	length := len(origData)
-	unpadding := int(origData[length-1])
-	return origData[:(length - unpadding)]
+func removePKCS7Padding(data []byte) ([]byte, error) {
+	if len(data) == 0 {
+		return nil, fmt.Errorf("data cannot be empty")
+	}
+
+	padding := int(data[len(data)-1])
+
+	if padding > len(data) {
+		return nil, fmt.Errorf("invalid padding size")
+	}
+
+	return data[:len(data)-padding], nil
 }
 
-//EncryptAES 加密函式
+// EncryptAES 加密函式
 func EncryptAES(plaintext, key, iv []byte) ([]byte, error) {
+	// 驗證密鑰長度是否符合 AES 的要求
+	if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
+		return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
+	}
+
+	// 創建 AES 密碼器
 	block, err := aes.NewCipher(key)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
 	}
+
+	// 確保輸入的資料長度符合 block size,並進行填充
 	blockSize := block.BlockSize()
-	plaintext = paddingPKCS7(plaintext, blockSize)
+	plaintext = addPKCS7Padding(plaintext, blockSize)
+
+	// 創建 CBC 模式加密器
 	blockMode := cipher.NewCBCEncrypter(block, iv)
-	crypted := make([]byte, len(plaintext))
-	blockMode.CryptBlocks(crypted, plaintext)
-	return crypted, nil
+
+	// 用來保存加密結果的緩衝區
+	enc := make([]byte, len(plaintext))
+	blockMode.CryptBlocks(enc, plaintext)
+
+	return enc, nil
 }
 
-// DecryptAES 解密函式
+// DecryptAES 使用 AES CBC 模式進行解密
 func DecryptAES(ciphertext, key, iv []byte) ([]byte, error) {
+	// 驗證密鑰長度是否符合 AES 的要求
+	if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
+		return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
+	}
+
+	// 創建 AES 密碼器
 	block, err := aes.NewCipher(key)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
 	}
+
+	// 確保輸入的資料長度符合 block size,並進行解密
 	blockSize := block.BlockSize()
 	blockMode := cipher.NewCBCDecrypter(block, iv[:blockSize])
+
+	// 用來保存解密結果的緩衝區
 	origData := make([]byte, len(ciphertext))
 	blockMode.CryptBlocks(origData, ciphertext)
-	origData = unpaddingPKCS7(origData)
-	return origData, nil
+
+	// 去除 PKCS7 填充
+	return removePKCS7Padding(origData)
 }
 
-// DecryptAESWithOpenSSL 使用Openssl 解密AES
-func DecryptAESWithOpenSSL(value, key string) ([]byte, error) {
+// EncryptAESWithOpenSSL 使用Openssl 加密AES
+func EncryptAESWithOpenSSL(value, key string) ([]byte, error) {
 	o := openssl.New()
 
-	dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum)
+	enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to encrypt with OpenSSL: %w", err)
 	}
 
-	return dec, nil
+	return enc, nil
 }
 
-// EncryptAESWithOpenSSL 使用Openssl 加密AES
-func EncryptAESWithOpenSSL(value, key string) ([]byte, error) {
+// DecryptAESWithOpenSSL 使用Openssl 解密AES
+func DecryptAESWithOpenSSL(value, key string) ([]byte, error) {
 	o := openssl.New()
 
-	enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum)
+	dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to decrypt with OpenSSL: %w", err)
 	}
 
-	return enc, nil
+	return dec, nil
 }
diff --git a/crypto/crypto.go b/crypto/crypto.go
index 1d7a569..7bdf019 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -9,24 +9,15 @@ import (
 
 // MD5 回傳md5加密
 func MD5(v string) string {
-	h := md5.New()
-	h.Write([]byte(v))
-	bs := h.Sum(nil)
-	return fmt.Sprintf("%x", bs)
+	return fmt.Sprintf("%x", md5.Sum([]byte(v)))
 }
 
 // SHA1 回傳sha1加密
 func SHA1(v string) string {
-	h := sha1.New()
-	h.Write([]byte(v))
-	bs := h.Sum(nil)
-	return fmt.Sprintf("%x", bs)
+	return fmt.Sprintf("%x", sha1.Sum([]byte(v)))
 }
 
 // SHA256 回傳sha256加密
 func SHA256(v string) string {
-	h := sha256.New()
-	h.Write([]byte(v))
-	bs := h.Sum(nil)
-	return fmt.Sprintf("%x", bs)
+	return fmt.Sprintf("%x", sha256.Sum256([]byte(v)))
 }
diff --git a/crypto/password.go b/crypto/password.go
index 1ab8d4b..71c82b0 100644
--- a/crypto/password.go
+++ b/crypto/password.go
@@ -1,6 +1,8 @@
 package crypto
 
 import (
+	"fmt"
+
 	"golang.org/x/crypto/bcrypt"
 )
 
@@ -8,7 +10,7 @@ import (
 func EncryptPassword(pwd string) (string, error) {
 	hash, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("failed to encrypt password: %w", err)
 	}
 
 	return string(hash), nil
@@ -16,5 +18,12 @@ func EncryptPassword(pwd string) (string, error) {
 
 // CheckPassword 檢查密碼
 func CheckPassword(pwd, hash string) error {
-	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
+	err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
+	if err != nil {
+		// 返回帶有更多上下文信息的錯誤
+		return fmt.Errorf("password does not match: %w", err)
+	}
+
+	// 如果密碼正確,返回 nil
+	return nil
 }
diff --git a/crypto/rsa.go b/crypto/rsa.go
index 1266c30..53ef35f 100644
--- a/crypto/rsa.go
+++ b/crypto/rsa.go
@@ -6,19 +6,20 @@ import (
 	"crypto/x509"
 	"encoding/pem"
 	"errors"
+	"fmt"
 )
 
 func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
-	// Generate RSA key.
+	// 生成 RSA 密鑰對
 	key, err := rsa.GenerateKey(rand.Reader, bitSize)
 	if err != nil {
-		return nil, nil, err
+		return nil, nil, fmt.Errorf("failed to generate RSA key: %w", err)
 	}
 
-	// Extract public component.
+	// 提取公鑰部分並轉換為 PEM 格式
 	pubBytes, err := x509.MarshalPKIXPublicKey(key.Public())
 	if err != nil {
-		return nil, nil, err
+		return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
 	}
 	pubPEM = pem.EncodeToMemory(
 		&pem.Block{
@@ -27,7 +28,7 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
 		},
 	)
 
-	// Encode private key to PKCS#1 ASN.1 PEM.
+	// 編碼私鑰為 PKCS#1 ASN.1 PEM 格式
 	keyPEM = pem.EncodeToMemory(
 		&pem.Block{
 			Type:  "RSA PRIVATE KEY",
@@ -42,29 +43,44 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
 func EncryptRSA(value, publicKey []byte) ([]byte, error) {
 	block, _ := pem.Decode(publicKey)
 	if block == nil {
-		return nil, errors.New("public key error")
+		return nil, errors.New("failed to decode public key PEM")
 	}
 
 	pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to parse public key: %w", err)
 	}
 
-	pub := pubInterface.(*rsa.PublicKey)
-	return rsa.EncryptPKCS1v15(rand.Reader, pub, value)
+	pub, ok := pubInterface.(*rsa.PublicKey)
+	if !ok {
+		return nil, errors.New("invalid public key type")
+	}
+
+	// 使用公鑰進行加密
+	enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, value)
+	if err != nil {
+		return nil, fmt.Errorf("failed to encrypt with RSA: %w", err)
+	}
+
+	return enc, nil
 }
 
 // DecryptRSA rsa解密
 func DecryptRSA(ciphertext, privateKey []byte) ([]byte, error) {
 	block, _ := pem.Decode(privateKey)
 	if block == nil {
-		return nil, errors.New("private key error")
+		return nil, errors.New("failed to decode private key PEM")
 	}
 
 	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to parse private key: %w", err)
+	}
+
+	dec, err := rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decrypt with RSA: %w", err)
 	}
 
-	return rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext)
+	return dec, nil
 }
diff --git a/generate/generate.go b/generate/generate.go
index 4e7bff7..1a8703b 100644
--- a/generate/generate.go
+++ b/generate/generate.go
@@ -2,52 +2,37 @@ package generate
 
 import (
 	"crypto/rand"
-	"math/big"
+	"fmt"
 )
 
 // GetRandomString 取得隨機字串
 func GetRandomString(n int) (string, error) {
 	const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-
-	buffer := make([]byte, n)
-	max := big.NewInt(int64(len(alphaNum)))
-
-	for i := 0; i < n; i++ {
-		index, err := randomInt(max)
-		if err != nil {
-			return "", err
-		}
-
-		buffer[i] = alphaNum[index]
-	}
-
-	return string(buffer), nil
+	return getRandomFromCharset(n, alphaNum)
 }
 
 // GetRandomKey 取得隨機金鑰
 func GetRandomKey(n int) (string, error) {
 	const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz~!@#$%^&*()-_=+"
+	return getRandomFromCharset(n, alphaNum)
+}
 
+// getRandomFromCharset 通用隨機字串生成器
+func getRandomFromCharset(n int, charset string) (string, error) {
+	// 預先計算字符集的長度
+	charsLen := len(charset)
 	buffer := make([]byte, n)
-	max := big.NewInt(int64(len(alphaNum)))
 
-	for i := 0; i < n; i++ {
-		index, err := randomInt(max)
-		if err != nil {
-			return "", err
-		}
-
-		buffer[i] = alphaNum[index]
+	// 使用 crypto/rand 生成隨機數據
+	_, err := rand.Read(buffer)
+	if err != nil {
+		return "", fmt.Errorf("failed to generate random data: %w", err)
 	}
 
-	return string(buffer), nil
-}
-
-func randomInt(max *big.Int) (int, error) {
-	random, err := rand.Int(rand.Reader, max)
-	if err != nil {
-		return 0, err
+	// 從字元集選擇字符
+	for i := 0; i < n; i++ {
+		buffer[i] = charset[int(buffer[i])%charsLen]
 	}
 
-	return int(random.Int64()), nil
+	return string(buffer), nil
 }
diff --git a/image/image.go b/image/image.go
index 2d792be..6aa33f1 100644
--- a/image/image.go
+++ b/image/image.go
@@ -6,7 +6,6 @@ import (
 	"image/jpeg"
 	"image/png"
 	"io"
-	"io/ioutil"
 	"mime/multipart"
 	"os"
 )
@@ -14,31 +13,42 @@ import (
 // CompressImage 壓縮圖片至指定位置
 func CompressImage(input interface{}, output string) (err error) {
 	var file io.Reader
-	switch input.(type) {
+	var fileToClose io.Closer
+
+	switch v := input.(type) {
 	case string:
-		if file, err = os.Open(input.(string)); err != nil {
+		f, err := os.Open(v)
+		if err != nil {
 			return err
 		}
+
+		file = f
+		fileToClose = f
 		defer func() {
-			_ = file.(*os.File).Close()
+			if fileToClose != nil {
+				_ = fileToClose.Close()
+			}
+			_ = os.Remove(v)
 		}()
-		defer os.Remove(input.(string))
-
 	case *multipart.FileHeader:
-		fileHeader := input.(*multipart.FileHeader)
-		file, err = fileHeader.Open()
+		f, err := v.Open()
 		if err != nil {
 			return err
 		}
+		file = f
+		fileToClose = f
 		defer func() {
-			_ = file.(multipart.File).Close()
+			if fileToClose != nil {
+				_ = fileToClose.Close()
+			}
 		}()
 	}
 
-	bs, err := ioutil.ReadAll(file)
-	if err != nil {
+	var buf bytes.Buffer
+	if _, err = io.Copy(&buf, file); err != nil {
 		return err
 	}
+	bs := buf.Bytes()
 
 	var ext string
 	if _, err = jpeg.Decode(bytes.NewReader(bs)); err == nil {
@@ -53,7 +63,11 @@ func CompressImage(input interface{}, output string) (err error) {
 	if err != nil {
 		return err
 	}
-	defer outputFile.Close()
+	defer func() {
+		if outputFile != nil {
+			_ = outputFile.Close()
+		}
+	}()
 
 	switch ext {
 	case "jpg":
@@ -61,34 +75,22 @@ func CompressImage(input interface{}, output string) (err error) {
 		if err != nil {
 			return err
 		}
-		if err = jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70}); err != nil {
-			return err
-		}
-
+		return jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70})
 	case "png":
 		img, err := png.Decode(bytes.NewReader(bs))
 		if err != nil {
 			return err
 		}
 		encoder := png.Encoder{CompressionLevel: png.BestCompression}
-		if err = encoder.Encode(outputFile, img); err != nil {
-			return err
-		}
-
+		return encoder.Encode(outputFile, img)
 	case "gif":
 		img, err := gif.DecodeAll(bytes.NewReader(bs))
 		if err != nil {
 			return err
 		}
-		if err = gif.EncodeAll(outputFile, img); err != nil {
-			return err
-		}
-
+		return gif.EncodeAll(outputFile, img)
 	default:
-		if _, err = outputFile.Write(bs); err != nil {
-			return err
-		}
+		_, err = outputFile.Write(bs)
+		return err
 	}
-
-	return nil
 }
diff --git a/jwt/jwt.go b/jwt/jwt.go
index 00a45a9..22d3a8d 100644
--- a/jwt/jwt.go
+++ b/jwt/jwt.go
@@ -12,15 +12,14 @@ import (
 // Encode jwt編碼
 func Encode(values types.Data, key string) (string, error) {
 	claims := jwt.MapClaims{}
-	for key, value := range values {
-		claims[key] = value
+	for k, v := range values {
+		claims[k] = v
 	}
 
-	tokenString, err := jwt.
-		NewWithClaims(jwt.SigningMethodHS256, claims).
+	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
 		SignedString([]byte(key))
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("failed to sign JWT with HS256: %w", err)
 	}
 
 	return tokenString, nil
@@ -29,15 +28,14 @@ func Encode(values types.Data, key string) (string, error) {
 // EncodeRS256 jwt編碼
 func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
 	claims := jwt.MapClaims{}
-	for key, value := range values {
-		claims[key] = value
+	for k, v := range values {
+		claims[k] = v
 	}
 
-	tokenString, err := jwt.
-		NewWithClaims(jwt.SigningMethodRS256, claims).
+	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).
 		SignedString(key)
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("failed to sign JWT with RS256: %w", err)
 	}
 
 	return tokenString, nil
@@ -47,22 +45,27 @@ func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
 func Decode(tokenString string, key string) (types.Data, error) {
 	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
-			return nil, fmt.Errorf("Unexpected error: %v. ", token.Header["alg"])
+			return nil, fmt.Errorf("unexpected signing method: %w", token.Header["alg"])
 		}
 		return []byte(key), nil
 	})
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to parse JWT: %w", err)
+	}
+
+	if token == nil || !token.Valid {
+		return nil, fmt.Errorf("invalid token or claims")
 	}
 
 	claims, ok := token.Claims.(jwt.MapClaims)
-	if !ok || !token.Valid {
-		return nil, fmt.Errorf("Token error. ")
+	if !ok {
+		return nil, fmt.Errorf("failed to parse token claims")
 	}
 
 	return types.Data(claims), nil
 }
 
+// IsExpired 檢查 token 是否過期
 func IsExpired(err error) bool {
 	return errors.Is(err, jwt.ErrTokenExpired)
 }