Compare commits
5 Commits
d46ba47c63
...
0cd0d2a402
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cd0d2a402 | |||
| 10a3f0156b | |||
| 3a75f64d44 | |||
| d79156a1d7 | |||
| 81214d42e5 |
@@ -535,6 +535,18 @@ func sanitizeForUsername(s string) string {
|
||||
// If these fail, the admin user exists without a wallet/profile.
|
||||
// This is acceptable because EnsureAdmin runs once at startup and failures
|
||||
// are logged as warnings. A restart will skip user creation (already exists).
|
||||
// VerifyPassword checks if the password matches the user's stored hash.
|
||||
func (s *Service) VerifyPassword(userID uint, password string) error {
|
||||
user, err := s.repo.FindByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
|
||||
return fmt.Errorf("invalid password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) EnsureAdmin(username, password string) error {
|
||||
if _, err := s.repo.FindByUsername(username); err == nil {
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package chain
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -620,6 +621,39 @@ func (h *Handler) RegisterTemplate(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusCreated).JSON(result)
|
||||
}
|
||||
|
||||
// ExportWallet godoc
|
||||
// @Summary 개인키 내보내기
|
||||
// @Description 비밀번호 확인 후 현재 유저의 지갑 개인키를 반환합니다
|
||||
// @Tags Chain
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param body body exportRequest true "비밀번호"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400 {object} docs.ErrorResponse
|
||||
// @Failure 401 {object} docs.ErrorResponse
|
||||
// @Router /api/chain/wallet/export [post]
|
||||
type exportRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *Handler) ExportWallet(c *fiber.Ctx) error {
|
||||
userID, err := getUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req exportRequest
|
||||
if err := c.BodyParser(&req); err != nil || req.Password == "" {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "password is required"})
|
||||
}
|
||||
slog.Warn("wallet export requested", "userID", userID, "ip", c.IP())
|
||||
privKeyHex, err := h.svc.ExportPrivKey(userID, req.Password)
|
||||
if err != nil {
|
||||
return c.Status(401).JSON(fiber.Map{"error": "invalid password"})
|
||||
}
|
||||
return c.JSON(fiber.Map{"privateKey": privKeyHex})
|
||||
}
|
||||
|
||||
// ---- Internal Handlers (game server, username-based) ----
|
||||
|
||||
// InternalGrantReward godoc
|
||||
|
||||
@@ -17,4 +17,6 @@ type UserWallet struct {
|
||||
Address string `json:"address" gorm:"type:varchar(40);uniqueIndex;not null"`
|
||||
EncryptedPrivKey string `json:"-" gorm:"type:varchar(512);not null"`
|
||||
EncNonce string `json:"-" gorm:"type:varchar(48);not null"`
|
||||
KeyVersion int `json:"-" gorm:"type:tinyint;default:1;not null"`
|
||||
HKDFSalt string `json:"-" gorm:"type:varchar(32)"` // 16 bytes hex, nullable for v1
|
||||
}
|
||||
|
||||
@@ -29,3 +29,22 @@ func (r *Repository) FindByPubKeyHex(pubKeyHex string) (*UserWallet, error) {
|
||||
}
|
||||
return &w, nil
|
||||
}
|
||||
|
||||
// FindAllByKeyVersion returns all wallets with the given key version.
|
||||
func (r *Repository) FindAllByKeyVersion(version int) ([]UserWallet, error) {
|
||||
var wallets []UserWallet
|
||||
if err := r.db.Where("key_version = ?", version).Find(&wallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wallets, nil
|
||||
}
|
||||
|
||||
// UpdateEncryption updates the encryption fields of a wallet.
|
||||
func (r *Repository) UpdateEncryption(id uint, encPrivKey, encNonce, hkdfSalt string, keyVersion int) error {
|
||||
return r.db.Model(&UserWallet{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"encrypted_priv_key": encPrivKey,
|
||||
"enc_nonce": encNonce,
|
||||
"hkdf_salt": hkdfSalt,
|
||||
"key_version": keyVersion,
|
||||
}).Error
|
||||
}
|
||||
|
||||
@@ -4,17 +4,20 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tolelom/tolchain/core"
|
||||
tocrypto "github.com/tolelom/tolchain/crypto"
|
||||
"github.com/tolelom/tolchain/wallet"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@@ -24,6 +27,7 @@ type Service struct {
|
||||
operatorWallet *wallet.Wallet
|
||||
encKeyBytes []byte // 32-byte AES-256 key
|
||||
userResolver func(username string) (uint, error)
|
||||
passwordVerifier func(userID uint, password string) error
|
||||
operatorMu sync.Mutex // serialises operator-nonce transactions
|
||||
userMu sync.Map // per-user mutex (keyed by userID uint)
|
||||
}
|
||||
@@ -33,6 +37,24 @@ func (s *Service) SetUserResolver(fn func(username string) (uint, error)) {
|
||||
s.userResolver = fn
|
||||
}
|
||||
|
||||
func (s *Service) SetPasswordVerifier(fn func(userID uint, password string) error) {
|
||||
s.passwordVerifier = fn
|
||||
}
|
||||
|
||||
func (s *Service) ExportPrivKey(userID uint, password string) (string, error) {
|
||||
if s.passwordVerifier == nil {
|
||||
return "", fmt.Errorf("password verifier not configured")
|
||||
}
|
||||
if err := s.passwordVerifier(userID, password); err != nil {
|
||||
return "", err
|
||||
}
|
||||
w, _, err := s.loadUserWallet(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return w.PrivKey().Hex(), nil
|
||||
}
|
||||
|
||||
// resolveUsername converts a username to the user's on-chain pubKeyHex.
|
||||
// If the user exists but has no wallet (e.g. legacy user or failed creation),
|
||||
// a wallet is auto-created on the fly.
|
||||
@@ -93,6 +115,16 @@ func NewService(
|
||||
|
||||
// ---- Wallet Encryption (AES-256-GCM) ----
|
||||
|
||||
func (s *Service) derivePerWalletKey(salt []byte, userID uint) ([]byte, error) {
|
||||
info := []byte("wallet:" + strconv.FormatUint(uint64(userID), 10))
|
||||
r := hkdf.New(sha256.New, s.encKeyBytes, salt, info)
|
||||
key := make([]byte, 32)
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return nil, fmt.Errorf("HKDF key derivation failed: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (s *Service) encryptPrivKey(privKey tocrypto.PrivateKey) (cipherHex, nonceHex string, err error) {
|
||||
block, err := aes.NewCipher(s.encKeyBytes)
|
||||
if err != nil {
|
||||
@@ -134,6 +166,101 @@ func (s *Service) decryptPrivKey(cipherHex, nonceHex string) (tocrypto.PrivateKe
|
||||
return tocrypto.PrivateKey(plaintext), nil
|
||||
}
|
||||
|
||||
func (s *Service) encryptPrivKeyV2(privKey tocrypto.PrivateKey, userID uint) (cipherHex, nonceHex, saltHex string, err error) {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
key, err := s.derivePerWalletKey(salt, userID)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
cipherText := gcm.Seal(nil, nonce, []byte(privKey), nil)
|
||||
return hex.EncodeToString(cipherText), hex.EncodeToString(nonce), hex.EncodeToString(salt), nil
|
||||
}
|
||||
|
||||
func (s *Service) decryptPrivKeyV2(cipherHex, nonceHex, saltHex string, userID uint) (tocrypto.PrivateKey, error) {
|
||||
cipherText, err := hex.DecodeString(cipherHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce, err := hex.DecodeString(nonceHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
salt, err := hex.DecodeString(saltHex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := s.derivePerWalletKey(salt, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext, err := gcm.Open(nil, nonce, cipherText, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wallet decryption failed: %w", err)
|
||||
}
|
||||
return tocrypto.PrivateKey(plaintext), nil
|
||||
}
|
||||
|
||||
// ---- Wallet Migration ----
|
||||
|
||||
// MigrateWalletKeys re-encrypts all v1 wallets using HKDF per-wallet keys.
|
||||
// Each wallet is migrated individually; failures are logged and skipped.
|
||||
func (s *Service) MigrateWalletKeys() error {
|
||||
wallets, err := s.repo.FindAllByKeyVersion(1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query v1 wallets: %w", err)
|
||||
}
|
||||
if len(wallets) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Printf("INFO: migrating %d v1 wallets to v2 (HKDF)", len(wallets))
|
||||
var migrated, failed int
|
||||
for _, uw := range wallets {
|
||||
privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: v1 decrypt failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(privKey, uw.UserID)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: v2 encrypt failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
if err := s.repo.UpdateEncryption(uw.ID, cipherHex, nonceHex, saltHex, 2); err != nil {
|
||||
log.Printf("ERROR: DB update failed for walletID=%d userID=%d: %v", uw.ID, uw.UserID, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
migrated++
|
||||
}
|
||||
log.Printf("INFO: wallet migration complete: %d migrated, %d failed", migrated, failed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- Wallet Management ----
|
||||
|
||||
// CreateWallet generates a new keypair, encrypts it, and stores in DB.
|
||||
@@ -142,18 +269,18 @@ func (s *Service) CreateWallet(userID uint) (*UserWallet, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("key generation failed: %w", err)
|
||||
}
|
||||
|
||||
cipherHex, nonceHex, err := s.encryptPrivKey(w.PrivKey())
|
||||
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(w.PrivKey(), userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("key encryption failed: %w", err)
|
||||
}
|
||||
|
||||
uw := &UserWallet{
|
||||
UserID: userID,
|
||||
PubKeyHex: w.PubKey(),
|
||||
Address: w.Address(),
|
||||
EncryptedPrivKey: cipherHex,
|
||||
EncNonce: nonceHex,
|
||||
KeyVersion: 2,
|
||||
HKDFSalt: saltHex,
|
||||
}
|
||||
if err := s.repo.Create(uw); err != nil {
|
||||
return nil, fmt.Errorf("wallet save failed: %w", err)
|
||||
@@ -171,7 +298,12 @@ func (s *Service) loadUserWallet(userID uint) (*wallet.Wallet, string, error) {
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("wallet not found: %w", err)
|
||||
}
|
||||
privKey, err := s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
|
||||
var privKey tocrypto.PrivateKey
|
||||
if uw.KeyVersion >= 2 {
|
||||
privKey, err = s.decryptPrivKeyV2(uw.EncryptedPrivKey, uw.EncNonce, uw.HKDFSalt, uw.UserID)
|
||||
} else {
|
||||
privKey, err = s.decryptPrivKey(uw.EncryptedPrivKey, uw.EncNonce)
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("WARNING: wallet decryption failed for userID=%d: %v", userID, err)
|
||||
return nil, "", fmt.Errorf("wallet decryption failed")
|
||||
|
||||
46
internal/chain/service_encryption_test.go
Normal file
46
internal/chain/service_encryption_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package chain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
tocrypto "github.com/tolelom/tolchain/crypto"
|
||||
)
|
||||
|
||||
func TestEncryptDecryptV2_Roundtrip(t *testing.T) {
|
||||
s := newTestService()
|
||||
priv, _, err := tocrypto.GenerateKeyPair()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cipherHex, nonceHex, saltHex, err := s.encryptPrivKeyV2(priv, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := s.decryptPrivKeyV2(cipherHex, nonceHex, saltHex, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.Hex() != priv.Hex() {
|
||||
t.Errorf("roundtrip mismatch: got %s, want %s", got.Hex(), priv.Hex())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptV2_WrongUserID_Fails(t *testing.T) {
|
||||
s := newTestService()
|
||||
priv, _, _ := tocrypto.GenerateKeyPair()
|
||||
cipherHex, nonceHex, saltHex, _ := s.encryptPrivKeyV2(priv, 42)
|
||||
_, err := s.decryptPrivKeyV2(cipherHex, nonceHex, saltHex, 99)
|
||||
if err == nil {
|
||||
t.Error("expected error for wrong userID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV1V2_DifferentCiphertext(t *testing.T) {
|
||||
s := newTestService()
|
||||
priv, _, _ := tocrypto.GenerateKeyPair()
|
||||
v1cipher, _, _ := s.encryptPrivKey(priv)
|
||||
v2cipher, _, _, _ := s.encryptPrivKeyV2(priv, 1)
|
||||
if v1cipher == v2cipher {
|
||||
t.Error("v1 and v2 should produce different ciphertext")
|
||||
}
|
||||
}
|
||||
6
main.go
6
main.go
@@ -75,6 +75,11 @@ func main() {
|
||||
}
|
||||
chainHandler := chain.NewHandler(chainSvc)
|
||||
|
||||
// Migrate v1 wallets to v2 (HKDF per-wallet keys)
|
||||
if err := chainSvc.MigrateWalletKeys(); err != nil {
|
||||
log.Fatalf("wallet key migration failed: %v", err)
|
||||
}
|
||||
|
||||
userResolver := func(username string) (uint, error) {
|
||||
user, err := authRepo.FindByUsername(username)
|
||||
if err != nil {
|
||||
@@ -88,6 +93,7 @@ func main() {
|
||||
_, err := chainSvc.CreateWallet(userID)
|
||||
return err
|
||||
})
|
||||
chainSvc.SetPasswordVerifier(authSvc.VerifyPassword)
|
||||
|
||||
playerRepo := player.NewRepository(db)
|
||||
playerSvc := player.NewService(playerRepo)
|
||||
|
||||
@@ -113,6 +113,7 @@ func Register(
|
||||
// Chain - Queries (authenticated)
|
||||
ch := api.Group("/chain", authMw)
|
||||
ch.Get("/wallet", chainH.GetWalletInfo)
|
||||
ch.Post("/wallet/export", chainH.ExportWallet)
|
||||
ch.Get("/balance", chainH.GetBalance)
|
||||
ch.Get("/assets", chainH.GetAssets)
|
||||
ch.Get("/asset/:id", chainH.GetAsset)
|
||||
|
||||
Reference in New Issue
Block a user