feat: HKDF per-wallet key derivation for wallet encryption

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-23 10:39:46 +09:00
parent 81214d42e5
commit d79156a1d7

View File

@@ -4,17 +4,20 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log" "log"
"strconv"
"sync" "sync"
"time" "time"
"github.com/tolelom/tolchain/core" "github.com/tolelom/tolchain/core"
tocrypto "github.com/tolelom/tolchain/crypto" tocrypto "github.com/tolelom/tolchain/crypto"
"github.com/tolelom/tolchain/wallet" "github.com/tolelom/tolchain/wallet"
"golang.org/x/crypto/hkdf"
) )
type Service struct { type Service struct {
@@ -93,6 +96,16 @@ func NewService(
// ---- Wallet Encryption (AES-256-GCM) ---- // ---- Wallet Encryption (AES-256-GCM) ----
func (s *Service) derivePerWalletKey(salt []byte, userID uint) ([]byte, error) {
info := []byte("wallet:" + strconv.FormatUint(uint64(userID), 10))
r := hkdf.New(sha256.New, s.encKeyBytes, salt, info)
key := make([]byte, 32)
if _, err := io.ReadFull(r, key); err != nil {
return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
}
return key, nil
}
func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) { func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) {
block, err := aes.NewCipher(s.encKeyBytes) block, err := aes.NewCipher(s.encKeyBytes)
if err != nil { if err != nil {
@@ -134,6 +147,63 @@ func (s *Service) decryptPrivKey(cipherHex, nonceHex string) (tocrypto.PrivateKe
return tocrypto.PrivateKey(plaintext), nil return tocrypto.PrivateKey(plaintext), nil
} }
func (s *Service) encryptPrivKeyV2(privKey tocrypto.PrivateKey, userID uint) (cipherHex, nonceHex, saltHex string, err error) {
salt := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return "", "", "", err
}
key, err := s.derivePerWalletKey(salt, userID)
if err != nil {
return "", "", "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", "", "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", "", "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", "", "", err
}
cipherText := gcm.Seal(nil, nonce, []byte(privKey), nil)
return hex.EncodeToString(cipherText), hex.EncodeToString(nonce), hex.EncodeToString(salt), nil
}
func (s *Service) decryptPrivKeyV2(cipherHex, nonceHex, saltHex string, userID uint) (tocrypto.PrivateKey, error) {
cipherText, err := hex.DecodeString(cipherHex)
if err != nil {
return nil, err
}
nonce, err := hex.DecodeString(nonceHex)
if err != nil {
return nil, err
}
salt, err := hex.DecodeString(saltHex)
if err != nil {
return nil, err
}
key, err := s.derivePerWalletKey(salt, userID)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plaintext, err := gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return nil, fmt.Errorf("wallet decryption failed: %w", err)
}
return tocrypto.PrivateKey(plaintext), nil
}
// ---- Wallet Management ---- // ---- Wallet Management ----
// CreateWallet generates a new keypair, encrypts it, and stores in DB. // CreateWallet generates a new keypair, encrypts it, and stores in DB.
@@ -142,18 +212,18 @@ func (s *Service) CreateWallet(userID uint) (*UserWallet, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("key generation failed: %w", err) return nil, fmt.Errorf("key generation failed: %w", err)
} }
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(w.PrivKey(), userID)
cipherHex, nonceHex, err := s.encryptPrivKey(w.PrivKey())
if err != nil { if err != nil {
return nil, fmt.Errorf("key encryption failed: %w", err) return nil, fmt.Errorf("key encryption failed: %w", err)
} }
uw := &UserWallet{ uw := &UserWallet{
UserID: userID, UserID: userID,
PubKeyHex: w.PubKey(), PubKeyHex: w.PubKey(),
Address: w.Address(), Address: w.Address(),
EncryptedPrivKey: cipherHex, EncryptedPrivKey: cipherHex,
EncNonce: nonceHex, EncNonce: nonceHex,
KeyVersion: 2,
HKDFSalt: saltHex,
} }
if err := s.repo.Create(uw); err != nil { if err := s.repo.Create(uw); err != nil {
return nil, fmt.Errorf("wallet save failed: %w", err) return nil, fmt.Errorf("wallet save failed: %w", err)
@@ -171,7 +241,12 @@ func (s *Service) loadUserWallet(userID uint) (*wallet.Wallet, string, error) {
if err != nil { if err != nil {
return nil, "", fmt.Errorf("wallet not found: %w", err) return nil, "", fmt.Errorf("wallet not found: %w", err)
} }
privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce) var privKey tocrypto.PrivateKey
if uw.KeyVersion >= 2 {
privKey, err = s.decryptPrivKeyV2(uw.EncryptedPrivKey, uw.EncNonce, uw.HKDFSalt, uw.UserID)
} else {
privKey, err = s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
}
if err != nil { if err != nil {
log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err) log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err)
return nil, "", fmt.Errorf("wallet decryption failed") return nil, "", fmt.Errorf("wallet decryption failed")