From d79156a1d721adbf0b1db652725cdcfa039ac404 Mon Sep 17 00:00:00 2001 From: tolelom <98kimsungmin@naver.com> Date: Mon, 23 Mar 2026 10:39:46 +0900 Subject: [PATCH] feat: HKDF per-wallet key derivation for wallet encryption Co-Authored-By: Claude Sonnet 4.6 --- internal/chain/service.go | 83 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 4 deletions(-) diff --git a/internal/chain/service.go b/internal/chain/service.go index 111128d..cefb060 100644 --- a/internal/chain/service.go +++ b/internal/chain/service.go @@ -4,17 +4,20 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "io" "log" + "strconv" "sync" "time" "github.com/tolelom/tolchain/core" tocrypto "github.com/tolelom/tolchain/crypto" "github.com/tolelom/tolchain/wallet" + "golang.org/x/crypto/hkdf" ) type Service struct { @@ -93,6 +96,16 @@ func NewService( // ---- Wallet Encryption (AES-256-GCM) ---- +func (s *Service) derivePerWalletKey(salt []byte, userID uint) ([]byte, error) { + info := []byte("wallet:" + strconv.FormatUint(uint64(userID), 10)) + r := hkdf.New(sha256.New, s.encKeyBytes, salt, info) + key := make([]byte, 32) + if _, err := io.ReadFull(r, key); err != nil { + return nil, fmt.Errorf("HKDF key derivation failed: %w", err) + } + return key, nil +} + func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) { block, err := aes.NewCipher(s.encKeyBytes) if err != nil { @@ -134,6 +147,63 @@ func (s *Service) decryptPrivKey(cipherHex, nonceHex string) (tocrypto.PrivateKe return tocrypto.PrivateKey(plaintext), nil } +func (s *Service) encryptPrivKeyV2(privKey tocrypto.PrivateKey, userID uint) (cipherHex, nonceHex, saltHex string, err error) { + salt := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return "", "", "", err + } + key, err := s.derivePerWalletKey(salt, userID) + if err != nil { + return "", "", "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", "", "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", "", "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", "", "", err + } + cipherText := gcm.Seal(nil, nonce, []byte(privKey), nil) + return hex.EncodeToString(cipherText), hex.EncodeToString(nonce), hex.EncodeToString(salt), nil +} + +func (s *Service) decryptPrivKeyV2(cipherHex, nonceHex, saltHex string, userID uint) (tocrypto.PrivateKey, error) { + cipherText, err := hex.DecodeString(cipherHex) + if err != nil { + return nil, err + } + nonce, err := hex.DecodeString(nonceHex) + if err != nil { + return nil, err + } + salt, err := hex.DecodeString(saltHex) + if err != nil { + return nil, err + } + key, err := s.derivePerWalletKey(salt, userID) + if err != nil { + return nil, err + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + plaintext, err := gcm.Open(nil, nonce, cipherText, nil) + if err != nil { + return nil, fmt.Errorf("wallet decryption failed: %w", err) + } + return tocrypto.PrivateKey(plaintext), nil +} + // ---- Wallet Management ---- // CreateWallet generates a new keypair, encrypts it, and stores in DB. @@ -142,18 +212,18 @@ func (s *Service) CreateWallet(userID uint) (*UserWallet, error) { if err != nil { return nil, fmt.Errorf("key generation failed: %w", err) } - - cipherHex, nonceHex, err := s.encryptPrivKey(w.PrivKey()) + cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(w.PrivKey(), userID) if err != nil { return nil, fmt.Errorf("key encryption failed: %w", err) } - uw := &UserWallet{ UserID: userID, PubKeyHex: w.PubKey(), Address: w.Address(), EncryptedPrivKey: cipherHex, EncNonce: nonceHex, + KeyVersion: 2, + HKDFSalt: saltHex, } if err := s.repo.Create(uw); err != nil { return nil, fmt.Errorf("wallet save failed: %w", err) @@ -171,7 +241,12 @@ func (s *Service) loadUserWallet(userID uint) (*wallet.Wallet, string, error) { if err != nil { return nil, "", fmt.Errorf("wallet not found: %w", err) } - privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce) + var privKey tocrypto.PrivateKey + if uw.KeyVersion >= 2 { + privKey, err = s.decryptPrivKeyV2(uw.EncryptedPrivKey, uw.EncNonce, uw.HKDFSalt, uw.UserID) + } else { + privKey, err = s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce) + } if err != nil { log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err) return nil, "", fmt.Errorf("wallet decryption failed")