master v1.5.8
Stanly 4 weeks ago
parent 826dee2e35
commit 26602e7552

@ -4,70 +4,103 @@ import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"fmt"
"github.com/Luzifer/go-openssl/v3" "github.com/Luzifer/go-openssl/v3"
) )
func paddingPKCS7(ciphertext []byte, blockSize int) []byte { func addPKCS7Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding) padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...) return append(ciphertext, padText...)
} }
func unpaddingPKCS7(origData []byte) []byte { func removePKCS7Padding(data []byte) ([]byte, error) {
length := len(origData) if len(data) == 0 {
unpadding := int(origData[length-1]) return nil, fmt.Errorf("data cannot be empty")
return origData[:(length - unpadding)] }
padding := int(data[len(data)-1])
if padding > len(data) {
return nil, fmt.Errorf("invalid padding size")
}
return data[:len(data)-padding], nil
} }
//EncryptAES 加密函式 // EncryptAES 加密函式
func EncryptAES(plaintext, key, iv []byte) ([]byte, error) { func EncryptAES(plaintext, key, iv []byte) ([]byte, error) {
// 驗證密鑰長度是否符合 AES 的要求
if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
}
// 創建 AES 密碼器
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create AES cipher: %w", err)
} }
// 確保輸入的資料長度符合 block size並進行填充
blockSize := block.BlockSize() blockSize := block.BlockSize()
plaintext = paddingPKCS7(plaintext, blockSize) plaintext = addPKCS7Padding(plaintext, blockSize)
// 創建 CBC 模式加密器
blockMode := cipher.NewCBCEncrypter(block, iv) blockMode := cipher.NewCBCEncrypter(block, iv)
crypted := make([]byte, len(plaintext))
blockMode.CryptBlocks(crypted, plaintext) // 用來保存加密結果的緩衝區
return crypted, nil enc := make([]byte, len(plaintext))
blockMode.CryptBlocks(enc, plaintext)
return enc, nil
} }
// DecryptAES 解密函式 // DecryptAES 使用 AES CBC 模式進行解密
func DecryptAES(ciphertext, key, iv []byte) ([]byte, error) { func DecryptAES(ciphertext, key, iv []byte) ([]byte, error) {
// 驗證密鑰長度是否符合 AES 的要求
if len(key) != aes.BlockSize && len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, fmt.Errorf("invalid key size: %d, expected 16, 24, or 32", len(key))
}
// 創建 AES 密碼器
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create AES cipher: %w", err)
} }
// 確保輸入的資料長度符合 block size並進行解密
blockSize := block.BlockSize() blockSize := block.BlockSize()
blockMode := cipher.NewCBCDecrypter(block, iv[:blockSize]) blockMode := cipher.NewCBCDecrypter(block, iv[:blockSize])
// 用來保存解密結果的緩衝區
origData := make([]byte, len(ciphertext)) origData := make([]byte, len(ciphertext))
blockMode.CryptBlocks(origData, ciphertext) blockMode.CryptBlocks(origData, ciphertext)
origData = unpaddingPKCS7(origData)
return origData, nil // 去除 PKCS7 填充
return removePKCS7Padding(origData)
} }
// DecryptAESWithOpenSSL 使用Openssl 解密AES // EncryptAESWithOpenSSL 使用Openssl 加密AES
func DecryptAESWithOpenSSL(value, key string) ([]byte, error) { func EncryptAESWithOpenSSL(value, key string) ([]byte, error) {
o := openssl.New() o := openssl.New()
dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum) enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to encrypt with OpenSSL: %w", err)
} }
return dec, nil return enc, nil
} }
// EncryptAESWithOpenSSL 使用Openssl 加密AES // DecryptAESWithOpenSSL 使用Openssl 解密AES
func EncryptAESWithOpenSSL(value, key string) ([]byte, error) { func DecryptAESWithOpenSSL(value, key string) ([]byte, error) {
o := openssl.New() o := openssl.New()
enc, err := o.EncryptBytes(key, []byte(value), openssl.DigestMD5Sum) dec, err := o.DecryptBytes(key, []byte(value), openssl.DigestMD5Sum)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decrypt with OpenSSL: %w", err)
} }
return enc, nil return dec, nil
} }

@ -9,24 +9,15 @@ import (
// MD5 回傳md5加密 // MD5 回傳md5加密
func MD5(v string) string { func MD5(v string) string {
h := md5.New() return fmt.Sprintf("%x", md5.Sum([]byte(v)))
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
} }
// SHA1 回傳sha1加密 // SHA1 回傳sha1加密
func SHA1(v string) string { func SHA1(v string) string {
h := sha1.New() return fmt.Sprintf("%x", sha1.Sum([]byte(v)))
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
} }
// SHA256 回傳sha256加密 // SHA256 回傳sha256加密
func SHA256(v string) string { func SHA256(v string) string {
h := sha256.New() return fmt.Sprintf("%x", sha256.Sum256([]byte(v)))
h.Write([]byte(v))
bs := h.Sum(nil)
return fmt.Sprintf("%x", bs)
} }

@ -1,6 +1,8 @@
package crypto package crypto
import ( import (
"fmt"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -8,7 +10,7 @@ import (
func EncryptPassword(pwd string) (string, error) { func EncryptPassword(pwd string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost) hash, err := bcrypt.GenerateFromPassword([]byte(pwd), bcrypt.DefaultCost)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to encrypt password: %w", err)
} }
return string(hash), nil return string(hash), nil
@ -16,5 +18,12 @@ func EncryptPassword(pwd string) (string, error) {
// CheckPassword 檢查密碼 // CheckPassword 檢查密碼
func CheckPassword(pwd, hash string) error { func CheckPassword(pwd, hash string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd)) err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
if err != nil {
// 返回帶有更多上下文信息的錯誤
return fmt.Errorf("password does not match: %w", err)
}
// 如果密碼正確,返回 nil
return nil
} }

@ -6,19 +6,20 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt"
) )
func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) { func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
// Generate RSA key. // 生成 RSA 密鑰對
key, err := rsa.GenerateKey(rand.Reader, bitSize) key, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("failed to generate RSA key: %w", err)
} }
// Extract public component. // 提取公鑰部分並轉換為 PEM 格式
pubBytes, err := x509.MarshalPKIXPublicKey(key.Public()) pubBytes, err := x509.MarshalPKIXPublicKey(key.Public())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
} }
pubPEM = pem.EncodeToMemory( pubPEM = pem.EncodeToMemory(
&pem.Block{ &pem.Block{
@ -27,7 +28,7 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
}, },
) )
// Encode private key to PKCS#1 ASN.1 PEM. // 編碼私鑰為 PKCS#1 ASN.1 PEM 格式
keyPEM = pem.EncodeToMemory( keyPEM = pem.EncodeToMemory(
&pem.Block{ &pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
@ -42,29 +43,44 @@ func NewKeyRSA(bitSize int) (pubPEM []byte, keyPEM []byte, err error) {
func EncryptRSA(value, publicKey []byte) ([]byte, error) { func EncryptRSA(value, publicKey []byte) ([]byte, error) {
block, _ := pem.Decode(publicKey) block, _ := pem.Decode(publicKey)
if block == nil { if block == nil {
return nil, errors.New("public key error") return nil, errors.New("failed to decode public key PEM")
} }
pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to parse public key: %w", err)
} }
pub := pubInterface.(*rsa.PublicKey) pub, ok := pubInterface.(*rsa.PublicKey)
return rsa.EncryptPKCS1v15(rand.Reader, pub, value) if !ok {
return nil, errors.New("invalid public key type")
}
// 使用公鑰進行加密
enc, err := rsa.EncryptPKCS1v15(rand.Reader, pub, value)
if err != nil {
return nil, fmt.Errorf("failed to encrypt with RSA: %w", err)
}
return enc, nil
} }
// DecryptRSA rsa解密 // DecryptRSA rsa解密
func DecryptRSA(ciphertext, privateKey []byte) ([]byte, error) { func DecryptRSA(ciphertext, privateKey []byte) ([]byte, error) {
block, _ := pem.Decode(privateKey) block, _ := pem.Decode(privateKey)
if block == nil { if block == nil {
return nil, errors.New("private key error") return nil, errors.New("failed to decode private key PEM")
} }
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to parse private key: %w", err)
}
dec, err := rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt with RSA: %w", err)
} }
return rsa.DecryptPKCS1v15(rand.Reader, priv, ciphertext) return dec, nil
} }

@ -2,52 +2,37 @@ package generate
import ( import (
"crypto/rand" "crypto/rand"
"math/big" "fmt"
) )
// GetRandomString 取得隨機字串 // GetRandomString 取得隨機字串
func GetRandomString(n int) (string, error) { func GetRandomString(n int) (string, error) {
const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
return getRandomFromCharset(n, alphaNum)
buffer := make([]byte, n)
max := big.NewInt(int64(len(alphaNum)))
for i := 0; i < n; i++ {
index, err := randomInt(max)
if err != nil {
return "", err
}
buffer[i] = alphaNum[index]
}
return string(buffer), nil
} }
// GetRandomKey 取得隨機金鑰 // GetRandomKey 取得隨機金鑰
func GetRandomKey(n int) (string, error) { func GetRandomKey(n int) (string, error) {
const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz~!@#$%^&*()-_=+" const alphaNum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz~!@#$%^&*()-_=+"
return getRandomFromCharset(n, alphaNum)
}
// getRandomFromCharset 通用隨機字串生成器
func getRandomFromCharset(n int, charset string) (string, error) {
// 預先計算字符集的長度
charsLen := len(charset)
buffer := make([]byte, n) buffer := make([]byte, n)
max := big.NewInt(int64(len(alphaNum)))
for i := 0; i < n; i++ { // 使用 crypto/rand 生成隨機數據
index, err := randomInt(max) _, err := rand.Read(buffer)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to generate random data: %w", err)
}
buffer[i] = alphaNum[index]
} }
return string(buffer), nil // 從字元集選擇字符
} for i := 0; i < n; i++ {
buffer[i] = charset[int(buffer[i])%charsLen]
func randomInt(max *big.Int) (int, error) {
random, err := rand.Int(rand.Reader, max)
if err != nil {
return 0, err
} }
return int(random.Int64()), nil return string(buffer), nil
} }

@ -6,7 +6,6 @@ import (
"image/jpeg" "image/jpeg"
"image/png" "image/png"
"io" "io"
"io/ioutil"
"mime/multipart" "mime/multipart"
"os" "os"
) )
@ -14,31 +13,42 @@ import (
// CompressImage 壓縮圖片至指定位置 // CompressImage 壓縮圖片至指定位置
func CompressImage(input interface{}, output string) (err error) { func CompressImage(input interface{}, output string) (err error) {
var file io.Reader var file io.Reader
switch input.(type) { var fileToClose io.Closer
switch v := input.(type) {
case string: case string:
if file, err = os.Open(input.(string)); err != nil { f, err := os.Open(v)
if err != nil {
return err return err
} }
file = f
fileToClose = f
defer func() { defer func() {
_ = file.(*os.File).Close() if fileToClose != nil {
_ = fileToClose.Close()
}
_ = os.Remove(v)
}() }()
defer os.Remove(input.(string))
case *multipart.FileHeader: case *multipart.FileHeader:
fileHeader := input.(*multipart.FileHeader) f, err := v.Open()
file, err = fileHeader.Open()
if err != nil { if err != nil {
return err return err
} }
file = f
fileToClose = f
defer func() { defer func() {
_ = file.(multipart.File).Close() if fileToClose != nil {
_ = fileToClose.Close()
}
}() }()
} }
bs, err := ioutil.ReadAll(file) var buf bytes.Buffer
if err != nil { if _, err = io.Copy(&buf, file); err != nil {
return err return err
} }
bs := buf.Bytes()
var ext string var ext string
if _, err = jpeg.Decode(bytes.NewReader(bs)); err == nil { if _, err = jpeg.Decode(bytes.NewReader(bs)); err == nil {
@ -53,7 +63,11 @@ func CompressImage(input interface{}, output string) (err error) {
if err != nil { if err != nil {
return err return err
} }
defer outputFile.Close() defer func() {
if outputFile != nil {
_ = outputFile.Close()
}
}()
switch ext { switch ext {
case "jpg": case "jpg":
@ -61,34 +75,22 @@ func CompressImage(input interface{}, output string) (err error) {
if err != nil { if err != nil {
return err return err
} }
if err = jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70}); err != nil { return jpeg.Encode(outputFile, img, &jpeg.Options{Quality: 70})
return err
}
case "png": case "png":
img, err := png.Decode(bytes.NewReader(bs)) img, err := png.Decode(bytes.NewReader(bs))
if err != nil { if err != nil {
return err return err
} }
encoder := png.Encoder{CompressionLevel: png.BestCompression} encoder := png.Encoder{CompressionLevel: png.BestCompression}
if err = encoder.Encode(outputFile, img); err != nil { return encoder.Encode(outputFile, img)
return err
}
case "gif": case "gif":
img, err := gif.DecodeAll(bytes.NewReader(bs)) img, err := gif.DecodeAll(bytes.NewReader(bs))
if err != nil { if err != nil {
return err return err
} }
if err = gif.EncodeAll(outputFile, img); err != nil { return gif.EncodeAll(outputFile, img)
return err
}
default: default:
if _, err = outputFile.Write(bs); err != nil { _, err = outputFile.Write(bs)
return err return err
}
} }
return nil
} }

@ -12,15 +12,14 @@ import (
// Encode jwt編碼 // Encode jwt編碼
func Encode(values types.Data, key string) (string, error) { func Encode(values types.Data, key string) (string, error) {
claims := jwt.MapClaims{} claims := jwt.MapClaims{}
for key, value := range values { for k, v := range values {
claims[key] = value claims[k] = v
} }
tokenString, err := jwt. tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
NewWithClaims(jwt.SigningMethodHS256, claims).
SignedString([]byte(key)) SignedString([]byte(key))
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to sign JWT with HS256: %w", err)
} }
return tokenString, nil return tokenString, nil
@ -29,15 +28,14 @@ func Encode(values types.Data, key string) (string, error) {
// EncodeRS256 jwt編碼 // EncodeRS256 jwt編碼
func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) { func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
claims := jwt.MapClaims{} claims := jwt.MapClaims{}
for key, value := range values { for k, v := range values {
claims[key] = value claims[k] = v
} }
tokenString, err := jwt. tokenString, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).
NewWithClaims(jwt.SigningMethodRS256, claims).
SignedString(key) SignedString(key)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to sign JWT with RS256: %w", err)
} }
return tokenString, nil return tokenString, nil
@ -47,22 +45,27 @@ func EncodeRS256(values types.Data, key *rsa.PrivateKey) (string, error) {
func Decode(tokenString string, key string) (types.Data, error) { func Decode(tokenString string, key string) (types.Data, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected error: %v. ", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %w", token.Header["alg"])
} }
return []byte(key), nil return []byte(key), nil
}) })
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to parse JWT: %w", err)
}
if token == nil || !token.Valid {
return nil, fmt.Errorf("invalid token or claims")
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid { if !ok {
return nil, fmt.Errorf("Token error. ") return nil, fmt.Errorf("failed to parse token claims")
} }
return types.Data(claims), nil return types.Data(claims), nil
} }
// IsExpired 檢查 token 是否過期
func IsExpired(err error) bool { func IsExpired(err error) bool {
return errors.Is(err, jwt.ErrTokenExpired) return errors.Is(err, jwt.ErrTokenExpired)
} }

Loading…
Cancel
Save