Compare commits

...

5 Commits

Author SHA1 Message Date
0cd0d2a402 feat: wallet private key export API with password verification
All checks were successful
Server CI/CD / lint-and-build (push) Successful in 39s
Server CI/CD / deploy (push) Successful in 52s
2026-03-23 10:52:27 +09:00
10a3f0156b feat: v1→v2 wallet key migration on server startup
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 10:45:06 +09:00
3a75f64d44 test: HKDF per-wallet encryption unit tests 2026-03-23 10:42:19 +09:00
d79156a1d7 feat: HKDF per-wallet key derivation for wallet encryption
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 10:39:46 +09:00
81214d42e5 feat: add key_version and hkdf_salt columns to UserWallet 2026-03-23 10:36:50 +09:00
8 changed files with 264 additions and 12 deletions

View File

@@ -535,6 +535,18 @@ func sanitizeForUsername(s string) string {
// If these fail, the admin user exists without a wallet/profile. // If these fail, the admin user exists without a wallet/profile.
// This is acceptable because EnsureAdmin runs once at startup and failures // This is acceptable because EnsureAdmin runs once at startup and failures
// are logged as warnings. A restart will skip user creation (already exists). // are logged as warnings. A restart will skip user creation (already exists).
// VerifyPassword checks if the password matches the user's stored hash.
func (s *Service) VerifyPassword(userID uint, password string) error {
user, err := s.repo.FindByID(userID)
if err != nil {
return fmt.Errorf("user not found")
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return fmt.Errorf("invalid password")
}
return nil
}
func (s *Service) EnsureAdmin(username, password string) error { func (s *Service) EnsureAdmin(username, password string) error {
if _, err := s.repo.FindByUsername(username); err == nil { if _, err := s.repo.FindByUsername(username); err == nil {
return nil return nil

View File

@@ -3,6 +3,7 @@ package chain
import ( import (
"errors" "errors"
"log" "log"
"log/slog"
"strconv" "strconv"
"strings" "strings"
@@ -620,6 +621,39 @@ func (h *Handler) RegisterTemplate(c *fiber.Ctx) error {
return c.Status(fiber.StatusCreated).JSON(result) return c.Status(fiber.StatusCreated).JSON(result)
} }
// ExportWallet godoc
// @Summary 개인키 내보내기
// @Description 비밀번호 확인 후 현재 유저의 지갑 개인키를 반환합니다
// @Tags Chain
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param body body exportRequest true "비밀번호"
// @Success 200 {object} map[string]string
// @Failure 400 {object} docs.ErrorResponse
// @Failure 401 {object} docs.ErrorResponse
// @Router /api/chain/wallet/export [post]
type exportRequest struct {
Password string `json:"password"`
}
func (h *Handler) ExportWallet(c *fiber.Ctx) error {
userID, err := getUserID(c)
if err != nil {
return err
}
var req exportRequest
if err := c.BodyParser(&req); err != nil || req.Password == "" {
return c.Status(400).JSON(fiber.Map{"error": "password is required"})
}
slog.Warn("wallet export requested", "userID", userID, "ip", c.IP())
privKeyHex, err := h.svc.ExportPrivKey(userID, req.Password)
if err != nil {
return c.Status(401).JSON(fiber.Map{"error": "invalid password"})
}
return c.JSON(fiber.Map{"privateKey": privKeyHex})
}
// ---- Internal Handlers (game server, username-based) ---- // ---- Internal Handlers (game server, username-based) ----
// InternalGrantReward godoc // InternalGrantReward godoc

View File

@@ -17,4 +17,6 @@ type UserWallet struct {
Address string `json:"address" gorm:"type:varchar(40);uniqueIndex;not null"` Address string `json:"address" gorm:"type:varchar(40);uniqueIndex;not null"`
EncryptedPrivKey string `json:"-" gorm:"type:varchar(512);not null"` EncryptedPrivKey string `json:"-" gorm:"type:varchar(512);not null"`
EncNonce string `json:"-" gorm:"type:varchar(48);not null"` EncNonce string `json:"-" gorm:"type:varchar(48);not null"`
KeyVersion int `json:"-" gorm:"type:tinyint;default:1;not null"`
HKDFSalt string `json:"-" gorm:"type:varchar(32)"` // 16 bytes hex, nullable for v1
} }

View File

@@ -29,3 +29,22 @@ func (r *Repository) FindByPubKeyHex(pubKeyHex string) (*UserWallet, error) {
} }
return &w, nil return &w, nil
} }
// FindAllByKeyVersion returns all wallets with the given key version.
func (r *Repository) FindAllByKeyVersion(version int) ([]UserWallet, error) {
var wallets []UserWallet
if err := r.db.Where("key_version = ?", version).Find(&wallets).Error; err != nil {
return nil, err
}
return wallets, nil
}
// UpdateEncryption updates the encryption fields of a wallet.
func (r *Repository) UpdateEncryption(id uint, encPrivKey, encNonce, hkdfSalt string, keyVersion int) error {
return r.db.Model(&UserWallet{}).Where("id = ?", id).Updates(map[string]any{
"encrypted_priv_key": encPrivKey,
"enc_nonce": encNonce,
"hkdf_salt": hkdfSalt,
"key_version": keyVersion,
}).Error
}

View File

@@ -4,28 +4,32 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log" "log"
"strconv"
"sync" "sync"
"time" "time"
"github.com/tolelom/tolchain/core" "github.com/tolelom/tolchain/core"
tocrypto "github.com/tolelom/tolchain/crypto" tocrypto "github.com/tolelom/tolchain/crypto"
"github.com/tolelom/tolchain/wallet" "github.com/tolelom/tolchain/wallet"
"golang.org/x/crypto/hkdf"
) )
type Service struct { type Service struct {
repo *Repository repo *Repository
client *Client client *Client
chainID string chainID string
operatorWallet *wallet.Wallet operatorWallet *wallet.Wallet
encKeyBytes []byte // 32-byte AES-256 key encKeyBytes []byte // 32-byte AES-256 key
userResolver func(username string) (uint, error) userResolver func(username string) (uint, error)
operatorMu sync.Mutex // serialises operator-nonce transactions passwordVerifier func(userID uint, password string) error
userMu sync.Map // per-user mutex (keyed by userID uint) operatorMu sync.Mutex // serialises operator-nonce transactions
userMu sync.Map // per-user mutex (keyed by userID uint)
} }
// SetUserResolver sets the callback that resolves username → userID. // SetUserResolver sets the callback that resolves username → userID.
@@ -33,6 +37,24 @@ func (s *Service) SetUserResolver(fn func(username string) (uint, error)) {
s.userResolver = fn s.userResolver = fn
} }
func (s *Service) SetPasswordVerifier(fn func(userID uint, password string) error) {
s.passwordVerifier = fn
}
func (s *Service) ExportPrivKey(userID uint, password string) (string, error) {
if s.passwordVerifier == nil {
return "", fmt.Errorf("password verifier not configured")
}
if err := s.passwordVerifier(userID, password); err != nil {
return "", err
}
w, _, err := s.loadUserWallet(userID)
if err != nil {
return "", err
}
return w.PrivKey().Hex(), nil
}
// resolveUsername converts a username to the user's on-chain pubKeyHex. // resolveUsername converts a username to the user's on-chain pubKeyHex.
// If the user exists but has no wallet (e.g. legacy user or failed creation), // If the user exists but has no wallet (e.g. legacy user or failed creation),
// a wallet is auto-created on the fly. // a wallet is auto-created on the fly.
@@ -93,6 +115,16 @@ func NewService(
// ---- Wallet Encryption (AES-256-GCM) ---- // ---- Wallet Encryption (AES-256-GCM) ----
func (s *Service) derivePerWalletKey(salt []byte, userID uint) ([]byte, error) {
info := []byte("wallet:" + strconv.FormatUint(uint64(userID), 10))
r := hkdf.New(sha256.New, s.encKeyBytes, salt, info)
key := make([]byte, 32)
if _, err := io.ReadFull(r, key); err != nil {
return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
}
return key, nil
}
func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) { func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) {
block, err := aes.NewCipher(s.encKeyBytes) block, err := aes.NewCipher(s.encKeyBytes)
if err != nil { if err != nil {
@@ -134,6 +166,101 @@ func (s *Service) decryptPrivKey(cipherHex, nonceHex string) (tocrypto.PrivateKe
return tocrypto.PrivateKey(plaintext), nil return tocrypto.PrivateKey(plaintext), nil
} }
func (s *Service) encryptPrivKeyV2(privKey tocrypto.PrivateKey, userID uint) (cipherHex, nonceHex, saltHex string, err error) {
salt := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return "", "", "", err
}
key, err := s.derivePerWalletKey(salt, userID)
if err != nil {
return "", "", "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", "", "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", "", "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", "", "", err
}
cipherText := gcm.Seal(nil, nonce, []byte(privKey), nil)
return hex.EncodeToString(cipherText), hex.EncodeToString(nonce), hex.EncodeToString(salt), nil
}
func (s *Service) decryptPrivKeyV2(cipherHex, nonceHex, saltHex string, userID uint) (tocrypto.PrivateKey, error) {
cipherText, err := hex.DecodeString(cipherHex)
if err != nil {
return nil, err
}
nonce, err := hex.DecodeString(nonceHex)
if err != nil {
return nil, err
}
salt, err := hex.DecodeString(saltHex)
if err != nil {
return nil, err
}
key, err := s.derivePerWalletKey(salt, userID)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
plaintext, err := gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return nil, fmt.Errorf("wallet decryption failed: %w", err)
}
return tocrypto.PrivateKey(plaintext), nil
}
// ---- Wallet Migration ----
// MigrateWalletKeys re-encrypts all v1 wallets using HKDF per-wallet keys.
// Each wallet is migrated individually; failures are logged and skipped.
func (s *Service) MigrateWalletKeys() error {
wallets, err := s.repo.FindAllByKeyVersion(1)
if err != nil {
return fmt.Errorf("query v1 wallets: %w", err)
}
if len(wallets) == 0 {
return nil
}
log.Printf("INFO: migrating %d v1 wallets to v2 (HKDF)", len(wallets))
var migrated, failed int
for _, uw := range wallets {
privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
if err != nil {
log.Printf("ERROR: v1 decrypt failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
failed++
continue
}
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(privKey, uw.UserID)
if err != nil {
log.Printf("ERROR: v2 encrypt failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
failed++
continue
}
if err := s.repo.UpdateEncryption(uw.ID, cipherHex, nonceHex, saltHex, 2); err != nil {
log.Printf("ERROR: DB update failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
failed++
continue
}
migrated++
}
log.Printf("INFO: wallet migration complete: %d migrated, %d failed", migrated, failed)
return nil
}
// ---- Wallet Management ---- // ---- Wallet Management ----
// CreateWallet generates a new keypair, encrypts it, and stores in DB. // CreateWallet generates a new keypair, encrypts it, and stores in DB.
@@ -142,18 +269,18 @@ func (s *Service) CreateWallet(userID uint) (*UserWallet, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("key generation failed: %w", err) return nil, fmt.Errorf("key generation failed: %w", err)
} }
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(w.PrivKey(), userID)
cipherHex, nonceHex, err := s.encryptPrivKey(w.PrivKey())
if err != nil { if err != nil {
return nil, fmt.Errorf("key encryption failed: %w", err) return nil, fmt.Errorf("key encryption failed: %w", err)
} }
uw := &UserWallet{ uw := &UserWallet{
UserID: userID, UserID: userID,
PubKeyHex: w.PubKey(), PubKeyHex: w.PubKey(),
Address: w.Address(), Address: w.Address(),
EncryptedPrivKey: cipherHex, EncryptedPrivKey: cipherHex,
EncNonce: nonceHex, EncNonce: nonceHex,
KeyVersion: 2,
HKDFSalt: saltHex,
} }
if err := s.repo.Create(uw); err != nil { if err := s.repo.Create(uw); err != nil {
return nil, fmt.Errorf("wallet save failed: %w", err) return nil, fmt.Errorf("wallet save failed: %w", err)
@@ -171,7 +298,12 @@ func (s *Service) loadUserWallet(userID uint) (*wallet.Wallet, string, error) {
if err != nil { if err != nil {
return nil, "", fmt.Errorf("wallet not found: %w", err) return nil, "", fmt.Errorf("wallet not found: %w", err)
} }
privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce) var privKey tocrypto.PrivateKey
if uw.KeyVersion >= 2 {
privKey, err = s.decryptPrivKeyV2(uw.EncryptedPrivKey, uw.EncNonce, uw.HKDFSalt, uw.UserID)
} else {
privKey, err = s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
}
if err != nil { if err != nil {
log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err) log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err)
return nil, "", fmt.Errorf("wallet decryption failed") return nil, "", fmt.Errorf("wallet decryption failed")

View File

@@ -0,0 +1,46 @@
package chain
import (
"testing"
tocrypto "github.com/tolelom/tolchain/crypto"
)
func TestEncryptDecryptV2_Roundtrip(t *testing.T) {
s := newTestService()
priv, _, err := tocrypto.GenerateKeyPair()
if err != nil {
t.Fatal(err)
}
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(priv, 42)
if err != nil {
t.Fatal(err)
}
got, err := s.decryptPrivKeyV2(cipherHex, nonceHex, saltHex, 42)
if err != nil {
t.Fatal(err)
}
if got.Hex() != priv.Hex() {
t.Errorf("roundtrip mismatch: got %s, want %s", got.Hex(), priv.Hex())
}
}
func TestDecryptV2_WrongUserID_Fails(t *testing.T) {
s := newTestService()
priv, _, _ := tocrypto.GenerateKeyPair()
cipherHex, nonceHex, saltHex, _ := s.encryptPrivKeyV2(priv, 42)
_, err := s.decryptPrivKeyV2(cipherHex, nonceHex, saltHex, 99)
if err == nil {
t.Error("expected error for wrong userID")
}
}
func TestV1V2_DifferentCiphertext(t *testing.T) {
s := newTestService()
priv, _, _ := tocrypto.GenerateKeyPair()
v1cipher, _, _ := s.encryptPrivKey(priv)
v2cipher, _, _, _ := s.encryptPrivKeyV2(priv, 1)
if v1cipher == v2cipher {
t.Error("v1 and v2 should produce different ciphertext")
}
}

View File

@@ -75,6 +75,11 @@ func main() {
} }
chainHandler := chain.NewHandler(chainSvc) chainHandler := chain.NewHandler(chainSvc)
// Migrate v1 wallets to v2 (HKDF per-wallet keys)
if err := chainSvc.MigrateWalletKeys(); err != nil {
log.Fatalf("wallet key migration failed: %v", err)
}
userResolver := func(username string) (uint, error) { userResolver := func(username string) (uint, error) {
user, err := authRepo.FindByUsername(username) user, err := authRepo.FindByUsername(username)
if err != nil { if err != nil {
@@ -88,6 +93,7 @@ func main() {
_, err := chainSvc.CreateWallet(userID) _, err := chainSvc.CreateWallet(userID)
return err return err
}) })
chainSvc.SetPasswordVerifier(authSvc.VerifyPassword)
playerRepo := player.NewRepository(db) playerRepo := player.NewRepository(db)
playerSvc := player.NewService(playerRepo) playerSvc := player.NewService(playerRepo)

View File

@@ -113,6 +113,7 @@ func Register(
// Chain - Queries (authenticated) // Chain - Queries (authenticated)
ch := api.Group("/chain", authMw) ch := api.Group("/chain", authMw)
ch.Get("/wallet", chainH.GetWalletInfo) ch.Get("/wallet", chainH.GetWalletInfo)
ch.Post("/wallet/export", chainH.ExportWallet)
ch.Get("/balance", chainH.GetBalance) ch.Get("/balance", chainH.GetBalance)
ch.Get("/assets", chainH.GetAssets) ch.Get("/assets", chainH.GetAssets)
ch.Get("/asset/:id", chainH.GetAsset) ch.Get("/asset/:id", chainH.GetAsset)