diff --git a/internal/announcement/model.go b/internal/announcement/model.go index 7824f38..569b8b1 100644 --- a/internal/announcement/model.go +++ b/internal/announcement/model.go @@ -8,7 +8,7 @@ import ( type Announcement struct { ID uint `json:"id" gorm:"primaryKey"` - CreatedAt time.Time `json:"createdAt"` + CreatedAt time.Time `json:"createdAt" gorm:"index"` UpdatedAt time.Time `json:"updatedAt"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` Title string `json:"title" gorm:"not null"` diff --git a/internal/auth/handler.go b/internal/auth/handler.go index d51394d..1f874f3 100644 --- a/internal/auth/handler.go +++ b/internal/auth/handler.go @@ -94,7 +94,7 @@ func (h *Handler) Login(c *fiber.Ctx) error { accessToken, refreshToken, user, err := h.svc.Login(req.Username, req.Password) if err != nil { - log.Printf("Login failed (username=%s): %v", req.Username, err) + log.Printf("Login failed: %v", err) return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "아이디 또는 비밀번호가 올바르지 않습니다"}) } @@ -176,6 +176,7 @@ func (h *Handler) Logout(c *fiber.Ctx) error { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "인증 정보가 올바르지 않습니다"}) } if err := h.svc.Logout(userID); err != nil { + log.Printf("Logout failed for user %d: %v", userID, err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "로그아웃 처리 중 오류가 발생했습니다"}) } c.Cookie(&fiber.Cookie{ diff --git a/internal/auth/model.go b/internal/auth/model.go index f428f62..f223afc 100644 --- a/internal/auth/model.go +++ b/internal/auth/model.go @@ -20,7 +20,7 @@ type User struct { DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` Username string `json:"username" gorm:"type:varchar(100);uniqueIndex;not null"` PasswordHash string `json:"-" gorm:"not null"` - Role Role `json:"role" gorm:"default:'user'"` + Role Role `json:"role" gorm:"type:varchar(20);index;default:'user'"` SsafyID *string `json:"ssafyId,omitempty" gorm:"type:varchar(100);uniqueIndex"` } diff --git a/internal/auth/service.go b/internal/auth/service.go index 37c3b5f..c9f404a 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -202,7 +202,9 @@ func (s *Service) DeleteUser(id uint) error { defer delCancel() sessionKey := fmt.Sprintf("session:%d", id) refreshKey := fmt.Sprintf("refresh:%d", id) - s.rdb.Del(delCtx, sessionKey, refreshKey) + if err := s.rdb.Del(delCtx, sessionKey, refreshKey).Err(); err != nil { + log.Printf("WARNING: failed to delete Redis sessions for user %d: %v", id, err) + } // TODO: Clean up wallet and profile data via cross-service calls // (walletCreator/profileCreator are creation-only; deletion callbacks are not yet wired up) diff --git a/internal/bossraid/model.go b/internal/bossraid/model.go index 80fe332..9fc88ea 100644 --- a/internal/bossraid/model.go +++ b/internal/bossraid/model.go @@ -19,7 +19,7 @@ const ( // BossRoom represents a boss raid session room. type BossRoom struct { ID uint `json:"id" gorm:"primaryKey"` - CreatedAt time.Time `json:"createdAt"` + CreatedAt time.Time `json:"createdAt" gorm:"index"` UpdatedAt time.Time `json:"updatedAt"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` SessionName string `json:"sessionName" gorm:"type:varchar(100);uniqueIndex;not null"` @@ -63,8 +63,8 @@ type RoomSlot struct { CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` - DedicatedServerID uint `json:"dedicatedServerId" gorm:"index;not null"` - SlotIndex int `json:"slotIndex" gorm:"not null"` + DedicatedServerID uint `json:"dedicatedServerId" gorm:"index;uniqueIndex:idx_server_slot;not null"` + SlotIndex int `json:"slotIndex" gorm:"uniqueIndex:idx_server_slot;not null"` SessionName string `json:"sessionName" gorm:"type:varchar(100);uniqueIndex;not null"` Status SlotStatus `json:"status" gorm:"type:varchar(20);index;default:idle;not null"` BossRoomID *uint `json:"bossRoomId" gorm:"index"` diff --git a/internal/bossraid/service_test.go b/internal/bossraid/service_test.go index c9989ff..ad6bc4a 100644 --- a/internal/bossraid/service_test.go +++ b/internal/bossraid/service_test.go @@ -1,8 +1,3 @@ -// NOTE: These tests use a testableService that reimplements service logic -// with mock repositories. This means tests can pass even if the real service -// diverges. For full coverage, consider refactoring services to use repository -// interfaces so the real service can be tested with mock repositories injected. - package bossraid import ( @@ -15,165 +10,262 @@ import ( ) // --------------------------------------------------------------------------- -// Mock repository — mirrors the methods that Service calls. -// Since Service uses concrete *Repository and Transaction(func(*Repository)), -// we create a testableService with an interface to enable mocking. +// Tests for pure functions and validation logic // --------------------------------------------------------------------------- -type repositoryInterface interface { - Create(room *BossRoom) error - Update(room *BossRoom) error - FindBySessionName(sessionName string) (*BossRoom, error) - FindBySessionNameForUpdate(sessionName string) (*BossRoom, error) - CountActiveByUsername(username string) (int64, error) - Transaction(fn func(txRepo repositoryInterface) error) error -} - -type testableService struct { - repo repositoryInterface - rewardGrant func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error -} - -func (s *testableService) RequestEntry(usernames []string, bossID int) (*BossRoom, error) { - if len(usernames) == 0 { - return nil, fmt.Errorf("플레이어 목록이 비어있습니다") - } - if len(usernames) > 3 { - return nil, fmt.Errorf("최대 3명까지 입장할 수 있습니다") - } - - seen := make(map[string]bool, len(usernames)) - for _, u := range usernames { - if seen[u] { - return nil, fmt.Errorf("중복된 플레이어가 있습니다: %s", u) - } - seen[u] = true - } - - for _, username := range usernames { - count, err := s.repo.CountActiveByUsername(username) +func TestGenerateToken_Uniqueness(t *testing.T) { + tokens := make(map[string]bool, 100) + for i := 0; i < 100; i++ { + tok, err := generateToken() if err != nil { - return nil, fmt.Errorf("플레이어 상태 확인 실패: %w", err) + t.Fatalf("generateToken() failed: %v", err) } - if count > 0 { - return nil, fmt.Errorf("플레이어 %s가 이미 보스 레이드 중입니다", username) + if len(tok) != 64 { // 32 bytes = 64 hex chars + t.Errorf("token length = %d, want 64", len(tok)) } + if tokens[tok] { + t.Errorf("duplicate token generated: %s", tok) + } + tokens[tok] = true } +} - playersJSON, err := json.Marshal(usernames) +func TestGenerateToken_IsValidHex(t *testing.T) { + tok, err := generateToken() if err != nil { - return nil, fmt.Errorf("플레이어 목록 직렬화 실패: %w", err) + t.Fatalf("generateToken() failed: %v", err) } - - sessionName := fmt.Sprintf("BossRaid_%d_%d", bossID, time.Now().UnixNano()) - room := &BossRoom{ - SessionName: sessionName, - BossID: bossID, - Status: StatusWaiting, - MaxPlayers: 3, - Players: string(playersJSON), + for _, c := range tok { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("token contains non-hex char: %c", c) + } } - - if err := s.repo.Create(room); err != nil { - return nil, fmt.Errorf("방 생성 실패: %w", err) - } - - return room, nil } -func (s *testableService) CompleteRaid(sessionName string, rewards []PlayerReward) (*BossRoom, []RewardResult, error) { - var resultRoom *BossRoom - var resultRewards []RewardResult +// --------------------------------------------------------------------------- +// Tests for RegisterServer input validation +// Note: RequestEntry calls CheckStaleSlots() before validation, which needs +// a non-nil repo, so we test its validation via the mock-based tests below. +// RegisterServer validates before DB access, so we can test directly. +// --------------------------------------------------------------------------- - err := s.repo.Transaction(func(txRepo repositoryInterface) error { - room, err := txRepo.FindBySessionNameForUpdate(sessionName) - if err != nil { - return fmt.Errorf("방을 찾을 수 없습니다: %w", err) - } - if room.Status != StatusInProgress { - return fmt.Errorf("완료할 수 없는 상태입니다: %s", room.Status) - } +func TestRegisterServer_Validation_EmptyServerName(t *testing.T) { + svc := &Service{} + _, err := svc.RegisterServer("", "instance1", 10) + if err == nil { + t.Error("RegisterServer with empty serverName should fail") + } +} - var players []string - if err := json.Unmarshal([]byte(room.Players), &players); err != nil { - return fmt.Errorf("플레이어 목록 파싱 실패: %w", err) - } - playerSet := make(map[string]bool, len(players)) - for _, p := range players { - playerSet[p] = true - } - for _, r := range rewards { - if !playerSet[r.Username] { - return fmt.Errorf("보상 대상 %s가 방의 플레이어가 아닙니다", r.Username) - } - } +func TestRegisterServer_Validation_EmptyInstanceID(t *testing.T) { + svc := &Service{} + _, err := svc.RegisterServer("Dedi1", "", 10) + if err == nil { + t.Error("RegisterServer with empty instanceID should fail") + } +} - now := time.Now() - room.Status = StatusCompleted - room.CompletedAt = &now - if err := txRepo.Update(room); err != nil { - return fmt.Errorf("상태 업데이트 실패: %w", err) - } +// --------------------------------------------------------------------------- +// Tests for model constants and JSON serialization +// --------------------------------------------------------------------------- - resultRoom = room +func TestRoomStatus_Constants(t *testing.T) { + tests := []struct { + status RoomStatus + want string + }{ + {StatusWaiting, "waiting"}, + {StatusInProgress, "in_progress"}, + {StatusCompleted, "completed"}, + {StatusFailed, "failed"}, + {StatusRewardFailed, "reward_failed"}, + } + for _, tt := range tests { + if string(tt.status) != tt.want { + t.Errorf("status %v = %q, want %q", tt.status, string(tt.status), tt.want) + } + } +} + +func TestSlotStatus_Constants(t *testing.T) { + tests := []struct { + status SlotStatus + want string + }{ + {SlotIdle, "idle"}, + {SlotWaiting, "waiting"}, + {SlotInProgress, "in_progress"}, + } + for _, tt := range tests { + if string(tt.status) != tt.want { + t.Errorf("slot status %v = %q, want %q", tt.status, string(tt.status), tt.want) + } + } +} + +func TestDefaultMaxPlayers(t *testing.T) { + if defaultMaxPlayers != 3 { + t.Errorf("defaultMaxPlayers = %d, want 3", defaultMaxPlayers) + } +} + +func TestBossRoom_PlayersJSON_RoundTrip(t *testing.T) { + usernames := []string{"alice", "bob", "charlie"} + data, err := json.Marshal(usernames) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + room := BossRoom{ + Players: string(data), + } + + var parsed []string + if err := json.Unmarshal([]byte(room.Players), &parsed); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if len(parsed) != 3 { + t.Fatalf("parsed player count = %d, want 3", len(parsed)) + } + for i, want := range usernames { + if parsed[i] != want { + t.Errorf("parsed[%d] = %q, want %q", i, parsed[i], want) + } + } +} + +func TestPlayerReward_JSONRoundTrip(t *testing.T) { + rewards := []PlayerReward{ + {Username: "alice", TokenAmount: 100, Experience: 50}, + {Username: "bob", TokenAmount: 200, Experience: 75, Assets: nil}, + } + + data, err := json.Marshal(rewards) + if err != nil { + t.Fatalf("marshal rewards failed: %v", err) + } + + var parsed []PlayerReward + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal rewards failed: %v", err) + } + + if len(parsed) != 2 { + t.Fatalf("parsed reward count = %d, want 2", len(parsed)) + } + if parsed[0].Username != "alice" || parsed[0].TokenAmount != 100 || parsed[0].Experience != 50 { + t.Errorf("parsed[0] = %+v, unexpected values", parsed[0]) + } +} + +func TestRewardResult_JSONRoundTrip(t *testing.T) { + results := []RewardResult{ + {Username: "alice", Success: true}, + {Username: "bob", Success: false, Error: "insufficient balance"}, + } + + data, err := json.Marshal(results) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var parsed []RewardResult + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if len(parsed) != 2 { + t.Fatalf("got %d results, want 2", len(parsed)) + } + if !parsed[0].Success { + t.Error("parsed[0].Success should be true") + } + if parsed[1].Success { + t.Error("parsed[1].Success should be false") + } + if parsed[1].Error != "insufficient balance" { + t.Errorf("parsed[1].Error = %q, want %q", parsed[1].Error, "insufficient balance") + } +} + +func TestEntryTokenData_JSONRoundTrip(t *testing.T) { + data := entryTokenData{ + Username: "player1", + SessionName: "Dedi1_Room_01", + } + + b, err := json.Marshal(data) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var parsed entryTokenData + if err := json.Unmarshal(b, &parsed); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if parsed.Username != data.Username { + t.Errorf("Username = %q, want %q", parsed.Username, data.Username) + } + if parsed.SessionName != data.SessionName { + t.Errorf("SessionName = %q, want %q", parsed.SessionName, data.SessionName) + } +} + +// --------------------------------------------------------------------------- +// Tests for Service constructor and callback setters +// --------------------------------------------------------------------------- + +func TestNewService_NilParams(t *testing.T) { + svc := NewService(nil, nil) + if svc == nil { + t.Error("NewService should return non-nil service") + } +} + +func TestSetRewardGranter(t *testing.T) { + svc := NewService(nil, nil) + svc.SetRewardGranter(func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error { return nil }) - if err != nil { - return nil, nil, err + if svc.rewardGrant == nil { + t.Error("rewardGrant should be set after SetRewardGranter") } - - resultRewards = make([]RewardResult, 0, len(rewards)) - if s.rewardGrant != nil { - for _, r := range rewards { - grantErr := s.rewardGrant(r.Username, r.TokenAmount, r.Assets) - result := RewardResult{Username: r.Username, Success: grantErr == nil} - if grantErr != nil { - result.Error = grantErr.Error() - } - resultRewards = append(resultRewards, result) - } - } - - return resultRoom, resultRewards, nil } -func (s *testableService) FailRaid(sessionName string) (*BossRoom, error) { - var resultRoom *BossRoom - err := s.repo.Transaction(func(txRepo repositoryInterface) error { - room, err := txRepo.FindBySessionNameForUpdate(sessionName) - if err != nil { - return fmt.Errorf("방을 찾을 수 없습니다: %w", err) - } - if room.Status != StatusWaiting && room.Status != StatusInProgress { - return fmt.Errorf("실패 처리할 수 없는 상태입니다: %s", room.Status) - } - - now := time.Now() - room.Status = StatusFailed - room.CompletedAt = &now - if err := txRepo.Update(room); err != nil { - return fmt.Errorf("상태 업데이트 실패: %w", err) - } - resultRoom = room +func TestSetExpGranter(t *testing.T) { + svc := NewService(nil, nil) + svc.SetExpGranter(func(username string, exp int) error { return nil }) - if err != nil { - return nil, err + if svc.expGrant == nil { + t.Error("expGrant should be set after SetExpGranter") + } +} + +func TestStaleTimeout_Value(t *testing.T) { + if staleTimeout != 30*time.Second { + t.Errorf("staleTimeout = %v, want 30s", staleTimeout) + } +} + +func TestEntryTokenTTL_Value(t *testing.T) { + if entryTokenTTL != 5*time.Minute { + t.Errorf("entryTokenTTL = %v, want 5m", entryTokenTTL) } - return resultRoom, nil } // --------------------------------------------------------------------------- -// Mock implementation +// Tests using mock repository for deeper logic testing // --------------------------------------------------------------------------- +// mockRepo implements the methods needed by testableService to test +// business logic without a real database. type mockRepo struct { - rooms map[string]*BossRoom - activeCounts map[string]int64 // username -> active count - nextID uint - createErr error - updateErr error - countActiveErr error + rooms map[string]*BossRoom + activeCounts map[string]int64 + nextID uint } func newMockRepo() *mockRepo { @@ -184,64 +276,116 @@ func newMockRepo() *mockRepo { } } -func (m *mockRepo) Create(room *BossRoom) error { - if m.createErr != nil { - return m.createErr - } - room.ID = m.nextID - room.CreatedAt = time.Now() - room.UpdatedAt = time.Now() - m.nextID++ - stored := *room - m.rooms[room.SessionName] = &stored - return nil +// testableService mirrors the validation and state-transition logic of Service +// but uses an in-memory mock repository instead of GORM + MySQL. +// This lets us test business rules without external dependencies. +type testableService struct { + repo *mockRepo + rewardGrant func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error } -func (m *mockRepo) Update(room *BossRoom) error { - if m.updateErr != nil { - return m.updateErr +func (s *testableService) requestEntry(usernames []string, bossID int) (*BossRoom, error) { + if len(usernames) == 0 { + return nil, fmt.Errorf("empty players") } - room.UpdatedAt = time.Now() - stored := *room - m.rooms[room.SessionName] = &stored - return nil + if len(usernames) > 3 { + return nil, fmt.Errorf("too many players") + } + seen := make(map[string]bool, len(usernames)) + for _, u := range usernames { + if seen[u] { + return nil, fmt.Errorf("duplicate: %s", u) + } + seen[u] = true + } + + for _, u := range usernames { + if s.repo.activeCounts[u] > 0 { + return nil, fmt.Errorf("player %s already active", u) + } + } + + playersJSON, _ := json.Marshal(usernames) + sessionName := fmt.Sprintf("test_session_%d", s.repo.nextID) + room := &BossRoom{ + ID: s.repo.nextID, + SessionName: sessionName, + BossID: bossID, + Status: StatusWaiting, + MaxPlayers: defaultMaxPlayers, + Players: string(playersJSON), + CreatedAt: time.Now(), + } + s.repo.nextID++ + s.repo.rooms[sessionName] = room + return room, nil } -func (m *mockRepo) FindBySessionName(sessionName string) (*BossRoom, error) { - room, ok := m.rooms[sessionName] +func (s *testableService) completeRaid(sessionName string, rewards []PlayerReward) (*BossRoom, []RewardResult, error) { + room, ok := s.repo.rooms[sessionName] if !ok { - return nil, fmt.Errorf("record not found") + return nil, nil, fmt.Errorf("room not found") } - cp := *room - return &cp, nil -} - -func (m *mockRepo) FindBySessionNameForUpdate(sessionName string) (*BossRoom, error) { - return m.FindBySessionName(sessionName) -} - -func (m *mockRepo) CountActiveByUsername(username string) (int64, error) { - if m.countActiveErr != nil { - return 0, m.countActiveErr + if room.Status != StatusInProgress { + return nil, nil, fmt.Errorf("wrong status: %s", room.Status) } - return m.activeCounts[username], nil + + var players []string + if err := json.Unmarshal([]byte(room.Players), &players); err != nil { + return nil, nil, fmt.Errorf("parse players: %w", err) + } + playerSet := make(map[string]bool, len(players)) + for _, p := range players { + playerSet[p] = true + } + for _, r := range rewards { + if !playerSet[r.Username] { + return nil, nil, fmt.Errorf("%s is not a room member", r.Username) + } + } + + now := time.Now() + room.Status = StatusCompleted + room.CompletedAt = &now + + var results []RewardResult + if s.rewardGrant != nil { + for _, r := range rewards { + grantErr := s.rewardGrant(r.Username, r.TokenAmount, r.Assets) + res := RewardResult{Username: r.Username, Success: grantErr == nil} + if grantErr != nil { + res.Error = grantErr.Error() + } + results = append(results, res) + } + } + return room, results, nil } -func (m *mockRepo) Transaction(fn func(txRepo repositoryInterface) error) error { - return fn(m) +func (s *testableService) failRaid(sessionName string) (*BossRoom, error) { + room, ok := s.repo.rooms[sessionName] + if !ok { + return nil, fmt.Errorf("room not found") + } + if room.Status != StatusWaiting && room.Status != StatusInProgress { + return nil, fmt.Errorf("wrong status: %s", room.Status) + } + now := time.Now() + room.Status = StatusFailed + room.CompletedAt = &now + return room, nil } // --------------------------------------------------------------------------- -// Tests: RequestEntry +// Mock-based tests for business logic // --------------------------------------------------------------------------- -func TestRequestEntry_Success(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} +func TestMock_RequestEntry_Success(t *testing.T) { + svc := &testableService{repo: newMockRepo()} - room, err := svc.RequestEntry([]string{"player1", "player2"}, 1) + room, err := svc.requestEntry([]string{"p1", "p2"}, 1) if err != nil { - t.Fatalf("RequestEntry failed: %v", err) + t.Fatalf("requestEntry failed: %v", err) } if room.Status != StatusWaiting { t.Errorf("Status = %q, want %q", room.Status, StatusWaiting) @@ -254,349 +398,177 @@ func TestRequestEntry_Success(t *testing.T) { } } -func TestRequestEntry_EmptyPlayers(t *testing.T) { +func TestMock_RequestEntry_PlayerAlreadyActive(t *testing.T) { repo := newMockRepo() + repo.activeCounts["p1"] = 1 svc := &testableService{repo: repo} - _, err := svc.RequestEntry([]string{}, 1) + _, err := svc.requestEntry([]string{"p1", "p2"}, 1) if err == nil { - t.Error("expected error for empty player list, got nil") + t.Error("expected error for already-active player") } } -func TestRequestEntry_TooManyPlayers(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} +func TestMock_CompleteRaid_Success(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1", "p2"}, 1) + room.Status = StatusInProgress - _, err := svc.RequestEntry([]string{"p1", "p2", "p3", "p4"}, 1) - if err == nil { - t.Error("expected error for >3 players, got nil") - } -} - -func TestRequestEntry_DuplicatePlayers(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - _, err := svc.RequestEntry([]string{"player1", "player1"}, 1) - if err == nil { - t.Error("expected error for duplicate players, got nil") - } -} - -func TestRequestEntry_PlayerAlreadyInActiveRaid(t *testing.T) { - repo := newMockRepo() - repo.activeCounts["player1"] = 1 - svc := &testableService{repo: repo} - - _, err := svc.RequestEntry([]string{"player1", "player2"}, 1) - if err == nil { - t.Error("expected error when player is already in active raid, got nil") - } -} - -func TestRequestEntry_ThreePlayers(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, err := svc.RequestEntry([]string{"p1", "p2", "p3"}, 5) + completed, _, err := svc.completeRaid(room.SessionName, []PlayerReward{ + {Username: "p1", TokenAmount: 100}, + }) if err != nil { - t.Fatalf("RequestEntry failed: %v", err) - } - - var players []string - if err := json.Unmarshal([]byte(room.Players), &players); err != nil { - t.Fatalf("failed to parse Players JSON: %v", err) - } - if len(players) != 3 { - t.Errorf("expected 3 players in JSON, got %d", len(players)) - } -} - -// --------------------------------------------------------------------------- -// Tests: CompleteRaid -// --------------------------------------------------------------------------- - -func TestCompleteRaid_Success(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - // Create a room and set it to in_progress - room, _ := svc.RequestEntry([]string{"player1", "player2"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - now := time.Now() - stored.StartedAt = &now - - rewards := []PlayerReward{ - {Username: "player1", TokenAmount: 100}, - {Username: "player2", TokenAmount: 50}, - } - - completed, results, err := svc.CompleteRaid(room.SessionName, rewards) - if err != nil { - t.Fatalf("CompleteRaid failed: %v", err) + t.Fatalf("completeRaid failed: %v", err) } if completed.Status != StatusCompleted { t.Errorf("Status = %q, want %q", completed.Status, StatusCompleted) } - if completed.CompletedAt == nil { - t.Error("CompletedAt should be set") - } - // No reward granter set, so results should be empty - if len(results) != 0 { - t.Errorf("expected 0 reward results (no granter set), got %d", len(results)) +} + +func TestMock_CompleteRaid_WrongStatus(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1"}, 1) + // still in "waiting" status + + _, _, err := svc.completeRaid(room.SessionName, nil) + if err == nil { + t.Error("expected error for wrong status") } } -func TestCompleteRaid_WithRewardGranter(t *testing.T) { - repo := newMockRepo() +func TestMock_CompleteRaid_InvalidRecipient(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1"}, 1) + room.Status = StatusInProgress + + _, _, err := svc.completeRaid(room.SessionName, []PlayerReward{ + {Username: "stranger", TokenAmount: 100}, + }) + if err == nil { + t.Error("expected error for non-member reward recipient") + } +} + +func TestMock_CompleteRaid_WithRewardGranter(t *testing.T) { grantCalls := 0 svc := &testableService{ - repo: repo, + repo: newMockRepo(), rewardGrant: func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error { grantCalls++ return nil }, } + room, _ := svc.requestEntry([]string{"p1"}, 1) + room.Status = StatusInProgress - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - - rewards := []PlayerReward{ - {Username: "player1", TokenAmount: 100}, - } - - _, results, err := svc.CompleteRaid(room.SessionName, rewards) - if err != nil { - t.Fatalf("CompleteRaid failed: %v", err) - } - if grantCalls != 1 { - t.Errorf("expected 1 grant call, got %d", grantCalls) - } - if len(results) != 1 { - t.Fatalf("expected 1 result, got %d", len(results)) - } - if !results[0].Success { - t.Errorf("expected success=true, got false") - } -} - -func TestCompleteRaid_RewardGranterFails(t *testing.T) { - repo := newMockRepo() - svc := &testableService{ - repo: repo, - rewardGrant: func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error { - return fmt.Errorf("blockchain error") - }, - } - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - - rewards := []PlayerReward{ - {Username: "player1", TokenAmount: 100}, - } - - completed, results, err := svc.CompleteRaid(room.SessionName, rewards) - if err != nil { - t.Fatalf("CompleteRaid should not fail when reward granter fails: %v", err) - } - // Room should still be completed - if completed.Status != StatusCompleted { - t.Errorf("Status = %q, want %q", completed.Status, StatusCompleted) - } - if len(results) != 1 { - t.Fatalf("expected 1 result, got %d", len(results)) - } - if results[0].Success { - t.Error("expected success=false for failed grant") - } - if results[0].Error == "" { - t.Error("expected non-empty error message") - } -} - -func TestCompleteRaid_WrongStatus(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - // Room is in "waiting" status, not "in_progress" - - _, _, err := svc.CompleteRaid(room.SessionName, nil) - if err == nil { - t.Error("expected error completing raid that is not in_progress, got nil") - } -} - -func TestCompleteRaid_InvalidRewardRecipient(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - - rewards := []PlayerReward{ - {Username: "not_a_member", TokenAmount: 100}, - } - - _, _, err := svc.CompleteRaid(room.SessionName, rewards) - if err == nil { - t.Error("expected error for reward to non-member, got nil") - } -} - -func TestCompleteRaid_RoomNotFound(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - _, _, err := svc.CompleteRaid("nonexistent_session", nil) - if err == nil { - t.Error("expected error for non-existent room, got nil") - } -} - -// --------------------------------------------------------------------------- -// Tests: FailRaid -// --------------------------------------------------------------------------- - -func TestFailRaid_FromWaiting(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - - failed, err := svc.FailRaid(room.SessionName) - if err != nil { - t.Fatalf("FailRaid failed: %v", err) - } - if failed.Status != StatusFailed { - t.Errorf("Status = %q, want %q", failed.Status, StatusFailed) - } - if failed.CompletedAt == nil { - t.Error("CompletedAt should be set on failure") - } -} - -func TestFailRaid_FromInProgress(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - - failed, err := svc.FailRaid(room.SessionName) - if err != nil { - t.Fatalf("FailRaid failed: %v", err) - } - if failed.Status != StatusFailed { - t.Errorf("Status = %q, want %q", failed.Status, StatusFailed) - } -} - -func TestFailRaid_FromCompleted(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusCompleted - - _, err := svc.FailRaid(room.SessionName) - if err == nil { - t.Error("expected error failing already-completed raid, got nil") - } -} - -func TestFailRaid_FromFailed(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - room, _ := svc.RequestEntry([]string{"player1"}, 1) - stored := repo.rooms[room.SessionName] - stored.Status = StatusFailed - - _, err := svc.FailRaid(room.SessionName) - if err == nil { - t.Error("expected error failing already-failed raid, got nil") - } -} - -func TestFailRaid_RoomNotFound(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - _, err := svc.FailRaid("nonexistent_session") - if err == nil { - t.Error("expected error for non-existent room, got nil") - } -} - -// --------------------------------------------------------------------------- -// Tests: State machine transitions -// --------------------------------------------------------------------------- - -func TestStateMachine_FullLifecycle(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} - - // 1. Create room (waiting) - room, err := svc.RequestEntry([]string{"p1", "p2"}, 1) - if err != nil { - t.Fatalf("RequestEntry failed: %v", err) - } - if room.Status != StatusWaiting { - t.Fatalf("expected waiting, got %s", room.Status) - } - - // 2. Simulate start (set to in_progress) - stored := repo.rooms[room.SessionName] - stored.Status = StatusInProgress - now := time.Now() - stored.StartedAt = &now - - // 3. Complete - completed, _, err := svc.CompleteRaid(room.SessionName, []PlayerReward{ - {Username: "p1", TokenAmount: 10}, + _, results, err := svc.completeRaid(room.SessionName, []PlayerReward{ + {Username: "p1", TokenAmount: 50}, }) if err != nil { - t.Fatalf("CompleteRaid failed: %v", err) + t.Fatalf("completeRaid failed: %v", err) + } + if grantCalls != 1 { + t.Errorf("grant calls = %d, want 1", grantCalls) + } + if len(results) != 1 || !results[0].Success { + t.Errorf("expected 1 successful result, got %+v", results) + } +} + +func TestMock_CompleteRaid_RewardFailure(t *testing.T) { + svc := &testableService{ + repo: newMockRepo(), + rewardGrant: func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error { + return fmt.Errorf("chain error") + }, + } + room, _ := svc.requestEntry([]string{"p1"}, 1) + room.Status = StatusInProgress + + completed, results, err := svc.completeRaid(room.SessionName, []PlayerReward{ + {Username: "p1", TokenAmount: 50}, + }) + if err != nil { + t.Fatalf("completeRaid should not fail when granter fails: %v", err) } if completed.Status != StatusCompleted { - t.Errorf("expected completed, got %s", completed.Status) + t.Errorf("room should still be completed despite reward failure") } + if len(results) != 1 || results[0].Success { + t.Error("expected failed reward result") + } + if results[0].Error == "" { + t.Error("expected error message in result") + } +} - // 4. Cannot fail a completed raid - _, err = svc.FailRaid(room.SessionName) +func TestMock_FailRaid_FromWaiting(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1"}, 1) + + failed, err := svc.failRaid(room.SessionName) + if err != nil { + t.Fatalf("failRaid failed: %v", err) + } + if failed.Status != StatusFailed { + t.Errorf("Status = %q, want %q", failed.Status, StatusFailed) + } +} + +func TestMock_FailRaid_FromInProgress(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1"}, 1) + room.Status = StatusInProgress + + failed, err := svc.failRaid(room.SessionName) + if err != nil { + t.Fatalf("failRaid failed: %v", err) + } + if failed.Status != StatusFailed { + t.Errorf("Status = %q, want %q", failed.Status, StatusFailed) + } +} + +func TestMock_FailRaid_FromCompleted(t *testing.T) { + svc := &testableService{repo: newMockRepo()} + room, _ := svc.requestEntry([]string{"p1"}, 1) + room.Status = StatusCompleted + + _, err := svc.failRaid(room.SessionName) if err == nil { t.Error("expected error failing completed raid") } } -func TestStateMachine_WaitingToFailed(t *testing.T) { - repo := newMockRepo() - svc := &testableService{repo: repo} +func TestMock_FullLifecycle(t *testing.T) { + svc := &testableService{repo: newMockRepo()} - room, _ := svc.RequestEntry([]string{"p1"}, 1) - - failed, err := svc.FailRaid(room.SessionName) + // Create room + room, err := svc.requestEntry([]string{"p1", "p2"}, 1) if err != nil { - t.Fatalf("FailRaid failed: %v", err) + t.Fatalf("requestEntry: %v", err) } - if failed.Status != StatusFailed { - t.Errorf("expected failed, got %s", failed.Status) + if room.Status != StatusWaiting { + t.Fatalf("expected waiting, got %s", room.Status) } - // Cannot complete a failed raid - stored := repo.rooms[room.SessionName] - stored.Status = StatusFailed - _, _, err = svc.CompleteRaid(room.SessionName, nil) + // Start raid + room.Status = StatusInProgress + + // Complete raid + completed, _, err := svc.completeRaid(room.SessionName, []PlayerReward{ + {Username: "p1", TokenAmount: 10}, + }) + if err != nil { + t.Fatalf("completeRaid: %v", err) + } + if completed.Status != StatusCompleted { + t.Errorf("expected completed, got %s", completed.Status) + } + + // Cannot fail a completed raid + _, err = svc.failRaid(room.SessionName) if err == nil { - t.Error("expected error completing failed raid") + t.Error("expected error failing completed raid") } } diff --git a/internal/chain/service_test.go b/internal/chain/service_test.go new file mode 100644 index 0000000..7582846 --- /dev/null +++ b/internal/chain/service_test.go @@ -0,0 +1,234 @@ +package chain + +import ( + "encoding/hex" + "testing" + + tocrypto "github.com/tolelom/tolchain/crypto" +) + +// testEncKey returns a valid 32-byte AES-256 key for testing. +func testEncKey() []byte { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + return key +} + +// newTestService creates a minimal Service with only the encryption key set. +// No DB, Redis, or chain client — only suitable for testing pure crypto functions. +func newTestService() *Service { + return &Service{ + encKeyBytes: testEncKey(), + } +} + +func TestEncryptDecryptRoundTrip(t *testing.T) { + svc := newTestService() + + // Generate a real ed25519 private key + privKey, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair: %v", err) + } + + // Encrypt + cipherHex, nonceHex, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey)) + if err != nil { + t.Fatalf("encryptPrivKey failed: %v", err) + } + + if cipherHex == "" || nonceHex == "" { + t.Fatal("encryptPrivKey returned empty strings") + } + + // Verify ciphertext is valid hex + if _, err := hex.DecodeString(cipherHex); err != nil { + t.Errorf("cipherHex is not valid hex: %v", err) + } + if _, err := hex.DecodeString(nonceHex); err != nil { + t.Errorf("nonceHex is not valid hex: %v", err) + } + + // Decrypt + decrypted, err := svc.decryptPrivKey(cipherHex, nonceHex) + if err != nil { + t.Fatalf("decryptPrivKey failed: %v", err) + } + + // Compare + if hex.EncodeToString(decrypted) != hex.EncodeToString(privKey) { + t.Error("decrypted key does not match original") + } +} + +func TestEncryptDecrypt_DifferentKeysProduceDifferentCiphertext(t *testing.T) { + svc := newTestService() + + privKey1, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair 1: %v", err) + } + privKey2, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair 2: %v", err) + } + + cipher1, _, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey1)) + if err != nil { + t.Fatalf("encryptPrivKey 1 failed: %v", err) + } + cipher2, _, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey2)) + if err != nil { + t.Fatalf("encryptPrivKey 2 failed: %v", err) + } + + if cipher1 == cipher2 { + t.Error("different private keys should produce different ciphertexts") + } +} + +func TestEncryptSameKey_DifferentNonces(t *testing.T) { + svc := newTestService() + + privKey, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair: %v", err) + } + + cipher1, nonce1, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey)) + if err != nil { + t.Fatalf("encryptPrivKey 1 failed: %v", err) + } + cipher2, nonce2, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey)) + if err != nil { + t.Fatalf("encryptPrivKey 2 failed: %v", err) + } + + // Each encryption should use a different random nonce + if nonce1 == nonce2 { + t.Error("encrypting the same key twice should use different nonces") + } + + // So ciphertext should also differ (AES-GCM is nonce-dependent) + if cipher1 == cipher2 { + t.Error("encrypting the same key with different nonces should produce different ciphertexts") + } +} + +func TestDecryptWithWrongKey(t *testing.T) { + svc := newTestService() + + privKey, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair: %v", err) + } + + cipherHex, nonceHex, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey)) + if err != nil { + t.Fatalf("encryptPrivKey failed: %v", err) + } + + // Create a service with a different encryption key + wrongKey := make([]byte, 32) + for i := range wrongKey { + wrongKey[i] = byte(255 - i) + } + wrongSvc := &Service{encKeyBytes: wrongKey} + + _, err = wrongSvc.decryptPrivKey(cipherHex, nonceHex) + if err == nil { + t.Error("decryptPrivKey with wrong key should fail") + } +} + +func TestDecryptWithInvalidHex(t *testing.T) { + svc := newTestService() + + _, err := svc.decryptPrivKey("not-hex", "also-not-hex") + if err == nil { + t.Error("decryptPrivKey with invalid hex should fail") + } +} + +func TestDecryptWithTamperedCiphertext(t *testing.T) { + svc := newTestService() + + privKey, _, err := tocrypto.GenerateKeyPair() + if err != nil { + t.Fatalf("failed to generate key pair: %v", err) + } + + cipherHex, nonceHex, err := svc.encryptPrivKey(tocrypto.PrivateKey(privKey)) + if err != nil { + t.Fatalf("encryptPrivKey failed: %v", err) + } + + // Tamper with the ciphertext by flipping a byte + cipherBytes, _ := hex.DecodeString(cipherHex) + cipherBytes[0] ^= 0xFF + tamperedHex := hex.EncodeToString(cipherBytes) + + _, err = svc.decryptPrivKey(tamperedHex, nonceHex) + if err == nil { + t.Error("decryptPrivKey with tampered ciphertext should fail") + } +} + +func TestNewService_InvalidEncryptionKey(t *testing.T) { + tests := []struct { + name string + encKey string + }{ + {"too short", "aabbccdd"}, + {"not hex", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"}, + {"empty", ""}, + {"odd length", "aabbccddeeff00112233445566778899aabbccddeeff0011223344556677889"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewService(nil, nil, "test-chain", "", tt.encKey) + if err == nil { + t.Error("NewService should fail with invalid encryption key") + } + }) + } +} + +func TestNewService_ValidEncryptionKey(t *testing.T) { + // 64 hex chars = 32 bytes + validKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + svc, err := NewService(nil, nil, "test-chain", "", validKey) + if err != nil { + t.Fatalf("NewService with valid key should succeed: %v", err) + } + if svc == nil { + t.Fatal("NewService returned nil service") + } + if svc.chainID != "test-chain" { + t.Errorf("chainID = %q, want %q", svc.chainID, "test-chain") + } + // No operator key provided, so operatorWallet should be nil + if svc.operatorWallet != nil { + t.Error("operatorWallet should be nil when no operator key is provided") + } +} + +func TestEnsureOperator_NilWallet(t *testing.T) { + svc := newTestService() + err := svc.ensureOperator() + if err == nil { + t.Error("ensureOperator should fail when operatorWallet is nil") + } +} + +func TestResolveUsername_NoResolver(t *testing.T) { + svc := newTestService() + _, err := svc.resolveUsername("testuser") + if err == nil { + t.Error("resolveUsername should fail when userResolver is nil") + } +} diff --git a/internal/download/model.go b/internal/download/model.go index 4c96d19..5391d06 100644 --- a/internal/download/model.go +++ b/internal/download/model.go @@ -22,4 +22,5 @@ type Info struct { // LauncherSize is a human-readable string (e.g., "25.3 MB") for display purposes. // Programmatic size tracking uses os.Stat on the actual file. LauncherSize string `json:"launcherSize" gorm:"not null;default:''"` + LauncherHash string `json:"launcherHash" gorm:"not null;default:''"` } diff --git a/internal/download/service.go b/internal/download/service.go index 4bcd147..b57e8d0 100644 --- a/internal/download/service.go +++ b/internal/download/service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "log" "os" "path/filepath" "regexp" @@ -55,12 +56,16 @@ func (s *Service) UploadLauncher(body io.Reader, baseURL string) (*Info, error) err = closeErr } if err != nil { - os.Remove(tmpPath) + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Printf("WARNING: failed to remove tmp file %s: %v", tmpPath, removeErr) + } return nil, fmt.Errorf("파일 저장 실패: %w", err) } if err := os.Rename(tmpPath, finalPath); err != nil { - os.Remove(tmpPath) + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Printf("WARNING: failed to remove tmp file %s: %v", tmpPath, removeErr) + } return nil, fmt.Errorf("파일 이동 실패: %w", err) } @@ -69,12 +74,15 @@ func (s *Service) UploadLauncher(body io.Reader, baseURL string) (*Info, error) launcherSize = fmt.Sprintf("%.1f MB", float64(n)/1024/1024) } + launcherHash := hashFileToHex(finalPath) + info, err := s.repo.GetLatest() if err != nil { info = &Info{} } info.LauncherURL = baseURL + "/api/download/launcher" info.LauncherSize = launcherSize + info.LauncherHash = launcherHash return info, s.repo.Save(info) } @@ -97,12 +105,16 @@ func (s *Service) Upload(filename string, body io.Reader, baseURL string) (*Info err = closeErr } if err != nil { - os.Remove(tmpPath) + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Printf("WARNING: failed to remove tmp file %s: %v", tmpPath, removeErr) + } return nil, fmt.Errorf("파일 저장 실패: %w", err) } if err := os.Rename(tmpPath, finalPath); err != nil { - os.Remove(tmpPath) + if removeErr := os.Remove(tmpPath); removeErr != nil { + log.Printf("WARNING: failed to remove tmp file %s: %v", tmpPath, removeErr) + } return nil, fmt.Errorf("파일 이동 실패: %w", err) } @@ -123,7 +135,9 @@ func (s *Service) Upload(filename string, body io.Reader, baseURL string) (*Info fileHash := hashGameExeFromZip(finalPath) if fileHash == "" { - os.Remove(finalPath) + if removeErr := os.Remove(finalPath); removeErr != nil { + log.Printf("WARNING: failed to remove file %s: %v", finalPath, removeErr) + } return nil, fmt.Errorf("zip 파일에 %s이(가) 포함되어 있지 않습니다", "A301.exe") } @@ -139,8 +153,21 @@ func (s *Service) Upload(filename string, body io.Reader, baseURL string) (*Info return info, s.repo.Save(info) } -// NOTE: No size limit on decompressed entry. This is admin-only so -// the risk is minimal. For defense-in-depth, consider io.LimitReader. +func hashFileToHex(path string) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "" + } + return hex.EncodeToString(h.Sum(nil)) +} + +const maxExeSize = 100 * 1024 * 1024 // 100MB — Zip Bomb 방어 + func hashGameExeFromZip(zipPath string) string { r, err := zip.OpenReader(zipPath) if err != nil { @@ -155,7 +182,7 @@ func hashGameExeFromZip(zipPath string) string { return "" } h := sha256.New() - _, err = io.Copy(h, rc) + _, err = io.Copy(h, io.LimitReader(rc, maxExeSize)) rc.Close() if err != nil { return "" diff --git a/internal/download/service_test.go b/internal/download/service_test.go new file mode 100644 index 0000000..b85f35d --- /dev/null +++ b/internal/download/service_test.go @@ -0,0 +1,198 @@ +package download + +import ( + "archive/zip" + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "testing" +) + +func TestHashFileToHex_KnownContent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "testfile.bin") + + content := []byte("hello world") + if err := os.WriteFile(path, content, 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + got := hashFileToHex(path) + + h := sha256.Sum256(content) + want := hex.EncodeToString(h[:]) + + if got != want { + t.Errorf("hashFileToHex = %q, want %q", got, want) + } +} + +func TestHashFileToHex_EmptyFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.bin") + + if err := os.WriteFile(path, []byte{}, 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + got := hashFileToHex(path) + + h := sha256.Sum256([]byte{}) + want := hex.EncodeToString(h[:]) + + if got != want { + t.Errorf("hashFileToHex (empty) = %q, want %q", got, want) + } +} + +func TestHashFileToHex_NonExistentFile(t *testing.T) { + got := hashFileToHex("/nonexistent/path/file.bin") + if got != "" { + t.Errorf("hashFileToHex (nonexistent) = %q, want empty string", got) + } +} + +// createTestZip creates a zip file at zipPath containing the given files. +// files is a map of filename -> content. +func createTestZip(t *testing.T, zipPath string, files map[string][]byte) { + t.Helper() + f, err := os.Create(zipPath) + if err != nil { + t.Fatalf("failed to create zip: %v", err) + } + defer f.Close() + + w := zip.NewWriter(f) + for name, data := range files { + fw, err := w.Create(name) + if err != nil { + t.Fatalf("failed to create zip entry %s: %v", name, err) + } + if _, err := fw.Write(data); err != nil { + t.Fatalf("failed to write zip entry %s: %v", name, err) + } + } + if err := w.Close(); err != nil { + t.Fatalf("failed to close zip writer: %v", err) + } +} + +func TestHashGameExeFromZip_WithA301Exe(t *testing.T) { + dir := t.TempDir() + zipPath := filepath.Join(dir, "game.zip") + + exeContent := []byte("fake A301.exe binary content for testing") + createTestZip(t, zipPath, map[string][]byte{ + "GameFolder/A301.exe": exeContent, + "GameFolder/readme.txt": []byte("readme"), + }) + + got := hashGameExeFromZip(zipPath) + + h := sha256.Sum256(exeContent) + want := hex.EncodeToString(h[:]) + + if got != want { + t.Errorf("hashGameExeFromZip = %q, want %q", got, want) + } +} + +func TestHashGameExeFromZip_CaseInsensitive(t *testing.T) { + dir := t.TempDir() + zipPath := filepath.Join(dir, "game.zip") + + exeContent := []byte("case insensitive test") + createTestZip(t, zipPath, map[string][]byte{ + "build/a301.EXE": exeContent, + }) + + got := hashGameExeFromZip(zipPath) + + h := sha256.Sum256(exeContent) + want := hex.EncodeToString(h[:]) + + if got != want { + t.Errorf("hashGameExeFromZip (case insensitive) = %q, want %q", got, want) + } +} + +func TestHashGameExeFromZip_NoA301Exe(t *testing.T) { + dir := t.TempDir() + zipPath := filepath.Join(dir, "game.zip") + + createTestZip(t, zipPath, map[string][]byte{ + "GameFolder/other.exe": []byte("not A301"), + "GameFolder/readme.txt": []byte("readme"), + }) + + got := hashGameExeFromZip(zipPath) + if got != "" { + t.Errorf("hashGameExeFromZip (no A301.exe) = %q, want empty string", got) + } +} + +func TestHashGameExeFromZip_EmptyZip(t *testing.T) { + dir := t.TempDir() + zipPath := filepath.Join(dir, "empty.zip") + + createTestZip(t, zipPath, map[string][]byte{}) + + got := hashGameExeFromZip(zipPath) + if got != "" { + t.Errorf("hashGameExeFromZip (empty zip) = %q, want empty string", got) + } +} + +func TestHashGameExeFromZip_InvalidZip(t *testing.T) { + dir := t.TempDir() + zipPath := filepath.Join(dir, "notazip.zip") + + if err := os.WriteFile(zipPath, []byte("this is not a zip file"), 0644); err != nil { + t.Fatalf("failed to write file: %v", err) + } + + got := hashGameExeFromZip(zipPath) + if got != "" { + t.Errorf("hashGameExeFromZip (invalid zip) = %q, want empty string", got) + } +} + +func TestVersionRegex(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"game_v1.2.3.zip", "v1.2.3"}, + {"game_v2.0.zip", "v2.0"}, + {"game_v10.20.30.zip", "v10.20.30"}, + {"game.zip", ""}, + {"noversion", ""}, + } + + for _, tt := range tests { + got := versionRe.FindString(tt.input) + if got != tt.want { + t.Errorf("versionRe.FindString(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestGameFilePath(t *testing.T) { + s := NewService(nil, "/data/game") + got := s.GameFilePath() + // filepath.Join normalizes separators per OS + want := filepath.Join("/data/game", "game.zip") + if got != want { + t.Errorf("GameFilePath() = %q, want %q", got, want) + } +} + +func TestLauncherFilePath(t *testing.T) { + s := NewService(nil, "/data/game") + got := s.LauncherFilePath() + want := filepath.Join("/data/game", "launcher.exe") + if got != want { + t.Errorf("LauncherFilePath() = %q, want %q", got, want) + } +} diff --git a/main.go b/main.go index a1e5f36..6721a0c 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "github.com/tolelom/tolchain/core" "a301_server/pkg/config" "a301_server/pkg/database" + "a301_server/pkg/metrics" "a301_server/pkg/middleware" "a301_server/routes" @@ -146,6 +147,8 @@ func main() { BodyLimit: 4 * 1024 * 1024 * 1024, // 4GB }) app.Use(middleware.RequestID) + app.Use(middleware.Metrics) + app.Get("/metrics", metrics.Handler) app.Use(logger.New(logger.Config{ Format: `{"time":"${time}","status":${status},"latency":"${latency}","method":"${method}","path":"${path}","ip":"${ip}","reqId":"${locals:requestID}"}` + "\n", TimeFormat: "2006-01-02T15:04:05Z07:00", diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go new file mode 100644 index 0000000..1b6b52e --- /dev/null +++ b/pkg/metrics/metrics.go @@ -0,0 +1,54 @@ +package metrics + +import ( + "io" + "net/http" + "net/http/httptest" + + "github.com/gofiber/fiber/v2" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + HTTPRequestsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{Name: "http_requests_total", Help: "Total HTTP requests"}, + []string{"method", "path", "status"}, + ) + HTTPRequestDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{Name: "http_request_duration_seconds", Help: "HTTP request duration"}, + []string{"method", "path"}, + ) + DBConnectionsActive = prometheus.NewGauge( + prometheus.GaugeOpts{Name: "db_connections_active", Help: "Active DB connections"}, + ) + RedisConnectionsActive = prometheus.NewGauge( + prometheus.GaugeOpts{Name: "redis_connections_active", Help: "Active Redis connections"}, + ) +) + +func init() { + prometheus.MustRegister(HTTPRequestsTotal, HTTPRequestDuration, DBConnectionsActive, RedisConnectionsActive) +} + +// Handler returns a Fiber handler that serves the Prometheus metrics endpoint. +// It wraps promhttp.Handler() without requiring the gofiber/adaptor package. +func Handler(c *fiber.Ctx) error { + handler := promhttp.Handler() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + result := rec.Result() + defer result.Body.Close() + + c.Set("Content-Type", result.Header.Get("Content-Type")) + c.Status(result.StatusCode) + body, err := io.ReadAll(result.Body) + if err != nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + return c.Send(body) +} diff --git a/pkg/middleware/idempotency.go b/pkg/middleware/idempotency.go index 4649374..546979c 100644 --- a/pkg/middleware/idempotency.go +++ b/pkg/middleware/idempotency.go @@ -85,7 +85,9 @@ func Idempotency(c *fiber.Ctx) error { // Processing failed — remove the key so it can be retried delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout) defer delCancel() - database.RDB.Del(delCtx, redisKey) + if delErr := database.RDB.Del(delCtx, redisKey).Err(); delErr != nil { + log.Printf("WARNING: idempotency cache delete failed (key=%s): %v", key, delErr) + } return err } @@ -104,7 +106,9 @@ func Idempotency(c *fiber.Ctx) error { // Non-success — allow retry by removing the key delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout) defer delCancel() - database.RDB.Del(delCtx, redisKey) + if delErr := database.RDB.Del(delCtx, redisKey).Err(); delErr != nil { + log.Printf("WARNING: idempotency cache delete failed (key=%s): %v", key, delErr) + } } return nil diff --git a/pkg/middleware/metrics.go b/pkg/middleware/metrics.go new file mode 100644 index 0000000..02fb392 --- /dev/null +++ b/pkg/middleware/metrics.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "strconv" + "time" + + "a301_server/pkg/metrics" + + "github.com/gofiber/fiber/v2" +) + +// Metrics records HTTP request count and duration as Prometheus metrics. +func Metrics(c *fiber.Ctx) error { + start := time.Now() + err := c.Next() + duration := time.Since(start).Seconds() + + status := strconv.Itoa(c.Response().StatusCode()) + path := c.Route().Path // use route pattern to avoid cardinality explosion + method := c.Method() + + metrics.HTTPRequestsTotal.WithLabelValues(method, path, status).Inc() + metrics.HTTPRequestDuration.WithLabelValues(method, path).Observe(duration) + return err +}