Skip to content

Instantly share code, notes, and snippets.

@apis
Created February 6, 2024 15:11
Show Gist options
  • Save apis/9684ff2fd22a7fd06bb2306ae916cbc0 to your computer and use it in GitHub Desktop.
Save apis/9684ff2fd22a7fd06bb2306ae916cbc0 to your computer and use it in GitHub Desktop.
Encode/decode with PBKDF2 key derivation function and AES-256 cipher in Go
package secret
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"fmt"
"golang.org/x/crypto/pbkdf2"
)
const (
	// keyLength is the size in bytes of the key derived from the password;
	// a 32-byte key makes aes.NewCipher select AES-256.
	keyLength = 32
	// saltLength is the size in bytes of the random PBKDF2 salt that is
	// stored at the very start of the cipher text.
	saltLength = 8
	// keyIterationCount is the number of PBKDF2 iterations used when
	// deriving the key.
	keyIterationCount = 4096
)
// CipherText is encrypted with AES-256 in CFB mode, using a key derived from the
// password with the PBKDF2 key derivation function (HMAC-SHA1).
// The resulting CipherText byte array is converted into a string using standard BASE64 encoding.
// CipherText byte array has the following format:
// +-------------------------------------------------------------------------------+
// | PBKDF2 Salt (8 bytes) | AES-256 IV (16 bytes) | AES-256 Ciphertext (variable) |
// +-------------------------------------------------------------------------------+
// DecryptPbkdf2Aes256 decrypts cipherText with a key derived from password and
// returns the recovered message as a string.
//
// cipherText must be a standard BASE64 string whose decoded bytes start with
// the PBKDF2 salt (saltLength bytes) followed by the AES IV (aes.BlockSize
// bytes) and then the encrypted payload, which may be empty.
//
// An error is returned when cipherText is not valid BASE64, when the decoded
// data is too short to hold both salt and IV, or when decrypt fails. The data
// carries no integrity check, so a wrong password does not produce an error;
// it produces garbage output instead.
func DecryptPbkdf2Aes256(password string, cipherText string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(cipherText)
	if err != nil {
		return "", fmt.Errorf("%s: %w", "base64.StdEncoding.DecodeString() failed", err)
	}

	// Salt and IV are mandatory; only the payload after them is optional.
	if minLength := saltLength + aes.BlockSize; len(raw) < minLength {
		return "", fmt.Errorf("%s", "invalid input cipherText size")
	}

	salt, sealed := raw[:saltLength], raw[saltLength:]
	plain, err := decrypt(createKey(password, salt), sealed)
	if err != nil {
		return "", fmt.Errorf("%s: %w", "secret.decrypt() failed", err)
	}
	return string(plain), nil
}
// EncryptPbkdf2Aes256 encrypts message with a key derived from password and
// returns the result as a standard BASE64 string.
//
// A fresh random salt is generated on every call, so encrypting the same
// message twice yields different output. An error is returned when the salt
// cannot be generated or when encrypt fails.
func EncryptPbkdf2Aes256(password string, message string) (string, error) {
	salt, err := createSalt()
	if err != nil {
		return "", fmt.Errorf("%s: %w", "secret.createSalt() failed", err)
	}

	// The salt is handed to encrypt as well so it can be stored in front of
	// the IV and the encrypted payload.
	sealed, err := encrypt(salt, createKey(password, salt), []byte(message))
	if err != nil {
		return "", fmt.Errorf("%s: %w", "secret.encrypt() failed", err)
	}
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// createSalt returns saltLength bytes read from crypto/rand, or an error if
// the random source cannot be read.
func createSalt() ([]byte, error) {
	buf := make([]byte, saltLength)
	if _, err := rand.Read(buf); err != nil {
		return nil, fmt.Errorf("%s: %w", "rand.Read() failed", err)
	}
	return buf, nil
}
// createKey derives a keyLength-byte key from password and salt using PBKDF2
// with SHA-1 as the underlying hash and keyIterationCount iterations.
// The same password and salt always produce the same key.
func createKey(password string, salt []byte) []byte {
	secretBytes := []byte(password)
	return pbkdf2.Key(secretBytes, salt, keyIterationCount, keyLength, sha1.New)
}
// encrypt encrypts messageBuffer with AES in CFB mode under key and returns a
// newly allocated buffer laid out as salt | IV | encrypted payload.
//
// salt is copied verbatim into the first saltLength bytes, the IV is filled
// from crypto/rand, and messageBuffer itself is left untouched. An error is
// returned when key is not a valid AES key or the random source fails.
func encrypt(salt []byte, key []byte, messageBuffer []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", "aes.NewCipher() failed", err)
	}

	// Everything in front of the payload: the salt followed by the IV.
	const headerLength = saltLength + aes.BlockSize

	out := make([]byte, headerLength+len(messageBuffer))
	copy(out, salt)

	// The IV is generated directly inside the output buffer.
	iv := out[saltLength:headerLength]
	if _, err := rand.Read(iv); err != nil {
		return nil, fmt.Errorf("%s: %w", "rand.Read() failed", err)
	}

	cipher.NewCFBEncrypter(block, iv).XORKeyStream(out[headerLength:], messageBuffer)
	return out, nil
}
// decrypt decrypts an AES-CFB payload that is prefixed with its IV.
//
// cipherBuffer must start with the IV (aes.BlockSize bytes) followed by the
// encrypted payload, which may be empty. Decryption happens in place: the
// returned slice aliases the payload part of cipherBuffer, so the caller's
// buffer is overwritten with the plaintext.
//
// An error is returned when key is not a valid AES key or when cipherBuffer
// is too short to contain an IV.
func decrypt(key []byte, cipherBuffer []byte) ([]byte, error) {
	cipherBlock, err := aes.NewCipher(key)
	if err != nil {
		msg := "aes.NewCipher() failed"
		return nil, fmt.Errorf("%s: %w", msg, err)
	}

	// Guard the slicing below: without a full IV the slice expression would
	// panic instead of reporting the malformed input to the caller.
	if len(cipherBuffer) < aes.BlockSize {
		return nil, fmt.Errorf("cipherBuffer holds %d bytes, need at least %d for the IV",
			len(cipherBuffer), aes.BlockSize)
	}

	iv := cipherBuffer[:aes.BlockSize]
	payload := cipherBuffer[aes.BlockSize:]

	stream := cipher.NewCFBDecrypter(cipherBlock, iv)
	stream.XORKeyStream(payload, payload)
	return payload, nil
}
package secret
import (
"encoding/base64"
"strings"
"testing"
)
// TestEncryptAndDecrypt checks that a message encrypted with a password is
// recovered unchanged when decrypted with the same password.
func TestEncryptAndDecrypt(t *testing.T) {
	const (
		password = "test_password"
		message  = "test_message"
	)

	encoded, err := EncryptPbkdf2Aes256(password, message)
	if err != nil {
		t.Fatalf("EncryptPbkdf2Aes256 failed: %v", err)
	}

	got, err := DecryptPbkdf2Aes256(password, encoded)
	if err != nil {
		t.Fatalf("DecryptPbkdf2Aes256 failed: %v", err)
	}
	if got != message {
		t.Fatalf("decrypted message does not match original message")
	}
}
// TestEncryptAndDecryptWithZeroPayload checks that an empty message survives
// an encrypt/decrypt round trip.
func TestEncryptAndDecryptWithZeroPayload(t *testing.T) {
	const (
		password = "test_password"
		message  = ""
	)

	encoded, err := EncryptPbkdf2Aes256(password, message)
	if err != nil {
		t.Fatalf("EncryptPbkdf2Aes256 failed: %v", err)
	}

	got, err := DecryptPbkdf2Aes256(password, encoded)
	if err != nil {
		t.Fatalf("DecryptPbkdf2Aes256 failed: %v", err)
	}
	if got != message {
		t.Fatalf("decrypted message does not match original message")
	}
}
// TestEncryptAndDecryptWithEmptyPassword checks that the round trip also works
// when the password is the empty string.
func TestEncryptAndDecryptWithEmptyPassword(t *testing.T) {
	const (
		password = ""
		message  = "test_message"
	)

	encoded, err := EncryptPbkdf2Aes256(password, message)
	if err != nil {
		t.Fatalf("EncryptPbkdf2Aes256 failed: %v", err)
	}

	got, err := DecryptPbkdf2Aes256(password, encoded)
	if err != nil {
		t.Fatalf("DecryptPbkdf2Aes256 failed: %v", err)
	}
	if got != message {
		t.Fatalf("decrypted message does not match original message")
	}
}
// TestDecryptWithInvalidBase64InputFailure checks that input which is not
// valid BASE64 is rejected with the base64 decoding error.
func TestDecryptWithInvalidBase64InputFailure(t *testing.T) {
	const wantPrefix = "base64.StdEncoding.DecodeString() failed"

	_, err := DecryptPbkdf2Aes256("your_password", "invalid_base64_encoded_cipher")
	if err == nil {
		t.Fatal("Expected DecryptPbkdf2Aes256 to fail, but it succeeded")
	}
	if got := err.Error(); !strings.HasPrefix(got, wantPrefix) {
		t.Fatal("DecryptPbkdf2Aes256 failed with unexpected error")
	}
}
// TestDecryptWithInputShorterThanExpectedFailure checks that valid BASE64 input
// which decodes to too few bytes (here 8, enough for the salt but not the IV)
// is rejected with the size error.
func TestDecryptWithInputShorterThanExpectedFailure(t *testing.T) {
	const wantPrefix = "invalid input cipherText size"

	tooShort := base64.StdEncoding.EncodeToString([]byte("ABCDEFGH"))

	_, err := DecryptPbkdf2Aes256("your_password", tooShort)
	if err == nil {
		t.Fatal("Expected DecryptPbkdf2Aes256 to fail, but it succeeded")
	}
	if got := err.Error(); !strings.HasPrefix(got, wantPrefix) {
		t.Fatal("DecryptPbkdf2Aes256 failed with unexpected error")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment