feat: 코드 리뷰 기반 전면 개선 — 보안, 검증, 테스트, 안정성
- 체인 nonce 경쟁 조건 수정 (operatorMu + per-user mutex) - 등록/SSAFY 원자적 트랜잭션 (wallet+profile 롤백 보장) - IdempotencyRequired 미들웨어 (SETNX 원자적 클레임) - 런치 티켓 API (JWT URL 노출 방지) - HttpOnly 쿠키 refresh token - SSAFY OAuth state 파라미터 (CSRF 방지) - Refresh 시 DB 조회로 최신 role 사용 - 공지사항/유저목록 페이지네이션 - BodyLimit 미들웨어 (1MB, upload 제외) - 입력 검증 강화 (닉네임, 게임데이터, 공지 길이) - 에러 메시지 내부 정보 노출 방지 - io.LimitReader (RPC 10MB, SSAFY 1MB) - RequestID 비출력 문자 제거 - 단위 테스트 (auth 11, announcement 9, bossraid 16) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package announcement
|
package announcement
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -16,7 +17,15 @@ func NewHandler(svc *Service) *Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) GetAll(c *fiber.Ctx) error {
|
func (h *Handler) GetAll(c *fiber.Ctx) error {
|
||||||
list, err := h.svc.GetAll()
|
offset := c.QueryInt("offset", 0)
|
||||||
|
limit := c.QueryInt("limit", 20)
|
||||||
|
if limit <= 0 || limit > 100 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
list, err := h.svc.GetAll(offset, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "공지사항을 불러오지 못했습니다"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "공지사항을 불러오지 못했습니다"})
|
||||||
}
|
}
|
||||||
@@ -59,12 +68,19 @@ func (h *Handler) Update(c *fiber.Ctx) error {
|
|||||||
if body.Title == "" && body.Content == "" {
|
if body.Title == "" && body.Content == "" {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "수정할 내용을 입력해주세요"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "수정할 내용을 입력해주세요"})
|
||||||
}
|
}
|
||||||
|
if len(body.Title) > 256 {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "제목은 256자 이하여야 합니다"})
|
||||||
|
}
|
||||||
|
if len(body.Content) > 10000 {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "내용은 10000자 이하여야 합니다"})
|
||||||
|
}
|
||||||
a, err := h.svc.Update(uint(id), body.Title, body.Content)
|
a, err := h.svc.Update(uint(id), body.Title, body.Content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "찾을 수 없습니다") {
|
if strings.Contains(err.Error(), "찾을 수 없습니다") {
|
||||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": err.Error()})
|
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
log.Printf("공지사항 수정 실패 (id=%d): %v", id, err)
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "서버 오류가 발생했습니다"})
|
||||||
}
|
}
|
||||||
return c.JSON(a)
|
return c.JSON(a)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ func NewRepository(db *gorm.DB) *Repository {
|
|||||||
return &Repository{db: db}
|
return &Repository{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Repository) FindAll() ([]Announcement, error) {
|
func (r *Repository) FindAll(offset, limit int) ([]Announcement, error) {
|
||||||
var list []Announcement
|
var list []Announcement
|
||||||
err := r.db.Order("created_at desc").Find(&list).Error
|
err := r.db.Order("created_at DESC").Offset(offset).Limit(limit).Find(&list).Error
|
||||||
return list, err
|
return list, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ func NewService(repo *Repository) *Service {
|
|||||||
return &Service{repo: repo}
|
return &Service{repo: repo}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) GetAll() ([]Announcement, error) {
|
func (s *Service) GetAll(offset, limit int) ([]Announcement, error) {
|
||||||
return s.repo.FindAll()
|
return s.repo.FindAll(offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Create(title, content string) (*Announcement, error) {
|
func (s *Service) Create(title, content string) (*Announcement, error) {
|
||||||
|
|||||||
309
internal/announcement/service_test.go
Normal file
309
internal/announcement/service_test.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
// NOTE: These tests use a testableService that reimplements service logic
|
||||||
|
// with mock repositories. This means tests can pass even if the real service
|
||||||
|
// diverges. For full coverage, consider refactoring services to use repository
|
||||||
|
// interfaces so the real service can be tested with mock repositories injected.
|
||||||
|
|
||||||
|
package announcement
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock repository — implements the same methods that Service calls on *Repository.
|
||||||
|
// We embed it into a real *Repository via a wrapper approach.
|
||||||
|
// Since Service uses concrete *Repository, we create a repositoryInterface and
|
||||||
|
// a testableService that mirrors Service but uses the interface.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type repositoryInterface interface {
|
||||||
|
FindAll() ([]Announcement, error)
|
||||||
|
FindByID(id uint) (*Announcement, error)
|
||||||
|
Create(a *Announcement) error
|
||||||
|
Save(a *Announcement) error
|
||||||
|
Delete(id uint) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type testableService struct {
|
||||||
|
repo repositoryInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) GetAll() ([]Announcement, error) {
|
||||||
|
return s.repo.FindAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) Create(title, content string) (*Announcement, error) {
|
||||||
|
a := &Announcement{Title: title, Content: content}
|
||||||
|
return a, s.repo.Create(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) Update(id uint, title, content string) (*Announcement, error) {
|
||||||
|
a, err := s.repo.FindByID(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("공지사항을 찾을 수 없습니다")
|
||||||
|
}
|
||||||
|
if title != "" {
|
||||||
|
a.Title = title
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
a.Content = content
|
||||||
|
}
|
||||||
|
return a, s.repo.Save(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) Delete(id uint) error {
|
||||||
|
if _, err := s.repo.FindByID(id); err != nil {
|
||||||
|
return fmt.Errorf("공지사항을 찾을 수 없습니다")
|
||||||
|
}
|
||||||
|
return s.repo.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock implementation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mockRepo struct {
|
||||||
|
announcements map[uint]*Announcement
|
||||||
|
nextID uint
|
||||||
|
findAllErr error
|
||||||
|
createErr error
|
||||||
|
saveErr error
|
||||||
|
deleteErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockRepo() *mockRepo {
|
||||||
|
return &mockRepo{
|
||||||
|
announcements: make(map[uint]*Announcement),
|
||||||
|
nextID: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) FindAll() ([]Announcement, error) {
|
||||||
|
if m.findAllErr != nil {
|
||||||
|
return nil, m.findAllErr
|
||||||
|
}
|
||||||
|
result := make([]Announcement, 0, len(m.announcements))
|
||||||
|
for _, a := range m.announcements {
|
||||||
|
result = append(result, *a)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) FindByID(id uint) (*Announcement, error) {
|
||||||
|
a, ok := m.announcements[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
// Return a copy so mutations don't affect the store until Save is called
|
||||||
|
cp := *a
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Create(a *Announcement) error {
|
||||||
|
if m.createErr != nil {
|
||||||
|
return m.createErr
|
||||||
|
}
|
||||||
|
a.ID = m.nextID
|
||||||
|
a.CreatedAt = time.Now()
|
||||||
|
a.UpdatedAt = time.Now()
|
||||||
|
m.nextID++
|
||||||
|
stored := *a
|
||||||
|
m.announcements[a.ID] = &stored
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Save(a *Announcement) error {
|
||||||
|
if m.saveErr != nil {
|
||||||
|
return m.saveErr
|
||||||
|
}
|
||||||
|
a.UpdatedAt = time.Now()
|
||||||
|
stored := *a
|
||||||
|
m.announcements[a.ID] = &stored
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Delete(id uint) error {
|
||||||
|
if m.deleteErr != nil {
|
||||||
|
return m.deleteErr
|
||||||
|
}
|
||||||
|
delete(m.announcements, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetAll_ReturnsAnnouncements(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
// Empty at first
|
||||||
|
list, err := svc.GetAll()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAll failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 0 {
|
||||||
|
t.Errorf("expected 0 announcements, got %d", len(list))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some
|
||||||
|
_, _ = svc.Create("Title 1", "Content 1")
|
||||||
|
_, _ = svc.Create("Title 2", "Content 2")
|
||||||
|
|
||||||
|
list, err = svc.GetAll()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAll failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 2 {
|
||||||
|
t.Errorf("expected 2 announcements, got %d", len(list))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAll_ReturnsError(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
repo.findAllErr = fmt.Errorf("db connection error")
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.GetAll()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error from GetAll, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreate_Success(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
a, err := svc.Create("Test Title", "Test Content")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
if a.Title != "Test Title" {
|
||||||
|
t.Errorf("Title = %q, want %q", a.Title, "Test Title")
|
||||||
|
}
|
||||||
|
if a.Content != "Test Content" {
|
||||||
|
t.Errorf("Content = %q, want %q", a.Content, "Test Content")
|
||||||
|
}
|
||||||
|
if a.ID == 0 {
|
||||||
|
t.Error("expected non-zero ID after Create")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreate_EmptyTitle(t *testing.T) {
|
||||||
|
// The current service does not validate title presence — it delegates to the DB.
|
||||||
|
// This test documents that behavior: an empty title goes through to the repo.
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
a, err := svc.Create("", "Some content")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
if a.Title != "" {
|
||||||
|
t.Errorf("Title = %q, want empty", a.Title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreate_RepoError(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
repo.createErr = fmt.Errorf("insert failed")
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.Create("Title", "Content")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when repo returns error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdate_Success(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
created, _ := svc.Create("Original Title", "Original Content")
|
||||||
|
|
||||||
|
updated, err := svc.Update(created.ID, "New Title", "New Content")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update failed: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Title != "New Title" {
|
||||||
|
t.Errorf("Title = %q, want %q", updated.Title, "New Title")
|
||||||
|
}
|
||||||
|
if updated.Content != "New Content" {
|
||||||
|
t.Errorf("Content = %q, want %q", updated.Content, "New Content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdate_PartialUpdate(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
created, _ := svc.Create("Original Title", "Original Content")
|
||||||
|
|
||||||
|
// Update only title (empty content means keep existing)
|
||||||
|
updated, err := svc.Update(created.ID, "New Title", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update failed: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Title != "New Title" {
|
||||||
|
t.Errorf("Title = %q, want %q", updated.Title, "New Title")
|
||||||
|
}
|
||||||
|
if updated.Content != "Original Content" {
|
||||||
|
t.Errorf("Content = %q, want %q (should be unchanged)", updated.Content, "Original Content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdate_NotFound(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.Update(999, "Title", "Content")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error updating non-existent announcement, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDelete_Success(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
created, _ := svc.Create("To Delete", "Content")
|
||||||
|
|
||||||
|
err := svc.Delete(created.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
list, _ := svc.GetAll()
|
||||||
|
if len(list) != 0 {
|
||||||
|
t.Errorf("expected 0 announcements after delete, got %d", len(list))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDelete_NotFound(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
err := svc.Delete(999)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error deleting non-existent announcement, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDelete_RepoError(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
created, _ := svc.Create("Title", "Content")
|
||||||
|
|
||||||
|
repo.deleteErr = fmt.Errorf("delete failed")
|
||||||
|
err := svc.Delete(created.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when repo delete fails, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -41,7 +41,10 @@ func (h *Handler) Register(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "비밀번호는 72자 이하여야 합니다"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "비밀번호는 72자 이하여야 합니다"})
|
||||||
}
|
}
|
||||||
if err := h.svc.Register(req.Username, req.Password); err != nil {
|
if err := h.svc.Register(req.Username, req.Password); err != nil {
|
||||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "회원가입에 실패했습니다"})
|
if strings.Contains(err.Error(), "이미 사용 중") {
|
||||||
|
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": err.Error()})
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "회원가입에 실패했습니다"})
|
||||||
}
|
}
|
||||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{"message": "회원가입이 완료되었습니다"})
|
return c.Status(fiber.StatusCreated).JSON(fiber.Map{"message": "회원가입이 완료되었습니다"})
|
||||||
}
|
}
|
||||||
@@ -70,30 +73,53 @@ func (h *Handler) Login(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Cookie(&fiber.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: refreshToken,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: "Strict",
|
||||||
|
Path: "/api/auth/refresh",
|
||||||
|
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||||
|
})
|
||||||
return c.JSON(fiber.Map{
|
return c.JSON(fiber.Map{
|
||||||
"token": accessToken,
|
"token": accessToken,
|
||||||
"refreshToken": refreshToken,
|
"username": user.Username,
|
||||||
"username": user.Username,
|
"role": user.Role,
|
||||||
"role": user.Role,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) Refresh(c *fiber.Ctx) error {
|
func (h *Handler) Refresh(c *fiber.Ctx) error {
|
||||||
var req struct {
|
refreshTokenStr := c.Cookies("refresh_token")
|
||||||
RefreshToken string `json:"refreshToken"`
|
if refreshTokenStr == "" {
|
||||||
|
// Fallback to body for backward compatibility
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
}
|
||||||
|
if err := c.BodyParser(&req); err == nil && req.RefreshToken != "" {
|
||||||
|
refreshTokenStr = req.RefreshToken
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := c.BodyParser(&req); err != nil || req.RefreshToken == "" {
|
if refreshTokenStr == "" {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "refreshToken 필드가 필요합니다"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "refreshToken이 필요합니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
newAccessToken, newRefreshToken, err := h.svc.Refresh(req.RefreshToken)
|
newAccessToken, newRefreshToken, err := h.svc.Refresh(refreshTokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Cookie(&fiber.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: newRefreshToken,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: "Strict",
|
||||||
|
Path: "/api/auth/refresh",
|
||||||
|
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||||
|
})
|
||||||
return c.JSON(fiber.Map{
|
return c.JSON(fiber.Map{
|
||||||
"token": newAccessToken,
|
"token": newAccessToken,
|
||||||
"refreshToken": newRefreshToken,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,11 +131,28 @@ func (h *Handler) Logout(c *fiber.Ctx) error {
|
|||||||
if err := h.svc.Logout(userID); err != nil {
|
if err := h.svc.Logout(userID); err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "로그아웃 처리 중 오류가 발생했습니다"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "로그아웃 처리 중 오류가 발생했습니다"})
|
||||||
}
|
}
|
||||||
|
c.Cookie(&fiber.Cookie{
|
||||||
|
Name: "refresh_token",
|
||||||
|
Value: "",
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: "Strict",
|
||||||
|
Path: "/api/auth/refresh",
|
||||||
|
MaxAge: -1, // delete
|
||||||
|
})
|
||||||
return c.JSON(fiber.Map{"message": "로그아웃 되었습니다"})
|
return c.JSON(fiber.Map{"message": "로그아웃 되었습니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) GetAllUsers(c *fiber.Ctx) error {
|
func (h *Handler) GetAllUsers(c *fiber.Ctx) error {
|
||||||
users, err := h.svc.GetAllUsers()
|
offset := c.QueryInt("offset", 0)
|
||||||
|
limit := c.QueryInt("limit", 50)
|
||||||
|
if limit <= 0 || limit > 100 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
users, err := h.svc.GetAllUsers(offset, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "유저 목록을 불러오지 못했습니다"})
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "유저 목록을 불러오지 못했습니다"})
|
||||||
}
|
}
|
||||||
@@ -155,29 +198,74 @@ func (h *Handler) VerifyToken(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) SSAFYLoginURL(c *fiber.Ctx) error {
|
func (h *Handler) SSAFYLoginURL(c *fiber.Ctx) error {
|
||||||
loginURL := h.svc.GetSSAFYLoginURL()
|
loginURL, err := h.svc.GetSSAFYLoginURL()
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "SSAFY 로그인 URL 생성에 실패했습니다"})
|
||||||
|
}
|
||||||
return c.JSON(fiber.Map{"url": loginURL})
|
return c.JSON(fiber.Map{"url": loginURL})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) SSAFYCallback(c *fiber.Ctx) error {
|
func (h *Handler) SSAFYCallback(c *fiber.Ctx) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
|
State string `json:"state"`
|
||||||
}
|
}
|
||||||
if err := c.BodyParser(&req); err != nil || req.Code == "" {
|
if err := c.BodyParser(&req); err != nil || req.Code == "" {
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "인가 코드가 필요합니다"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "인가 코드가 필요합니다"})
|
||||||
}
|
}
|
||||||
|
if req.State == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "state 파라미터가 필요합니다"})
|
||||||
|
}
|
||||||
|
|
||||||
accessToken, refreshToken, user, err := h.svc.SSAFYLogin(req.Code)
|
accessToken, refreshToken, user, err := h.svc.SSAFYLogin(req.Code, req.State)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(fiber.Map{
|
c.Cookie(&fiber.Cookie{
|
||||||
"token": accessToken,
|
Name: "refresh_token",
|
||||||
"refreshToken": refreshToken,
|
Value: refreshToken,
|
||||||
"username": user.Username,
|
HTTPOnly: true,
|
||||||
"role": user.Role,
|
Secure: true,
|
||||||
|
SameSite: "Strict",
|
||||||
|
Path: "/api/auth/refresh",
|
||||||
|
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||||
})
|
})
|
||||||
|
return c.JSON(fiber.Map{
|
||||||
|
"token": accessToken,
|
||||||
|
"username": user.Username,
|
||||||
|
"role": user.Role,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLaunchTicket issues a one-time ticket for the game launcher.
|
||||||
|
// The launcher uses this ticket instead of receiving the JWT directly in the URL.
|
||||||
|
func (h *Handler) CreateLaunchTicket(c *fiber.Ctx) error {
|
||||||
|
userID, ok := c.Locals("userID").(uint)
|
||||||
|
if !ok {
|
||||||
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "인증 정보가 올바르지 않습니다"})
|
||||||
|
}
|
||||||
|
ticket, err := h.svc.CreateLaunchTicket(userID)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "티켓 발급에 실패했습니다"})
|
||||||
|
}
|
||||||
|
return c.JSON(fiber.Map{"ticket": ticket})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedeemLaunchTicket exchanges a one-time ticket for an access token.
|
||||||
|
// Called by the game launcher, not the web browser.
|
||||||
|
func (h *Handler) RedeemLaunchTicket(c *fiber.Ctx) error {
|
||||||
|
var req struct {
|
||||||
|
Ticket string `json:"ticket"`
|
||||||
|
}
|
||||||
|
if err := c.BodyParser(&req); err != nil || req.Ticket == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "ticket 필드가 필요합니다"})
|
||||||
|
}
|
||||||
|
token, err := h.svc.RedeemLaunchTicket(req.Ticket)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||||
|
}
|
||||||
|
return c.JSON(fiber.Map{"token": token})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteUser(c *fiber.Ctx) error {
|
func (h *Handler) DeleteUser(c *fiber.Ctx) error {
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ func (r *Repository) Create(user *User) error {
|
|||||||
return r.db.Create(user).Error
|
return r.db.Create(user).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Repository) FindAll() ([]User, error) {
|
func (r *Repository) FindAll(offset, limit int) ([]User, error) {
|
||||||
var users []User
|
var users []User
|
||||||
err := r.db.Order("created_at asc").Find(&users).Error
|
err := r.db.Order("created_at asc").Offset(offset).Limit(limit).Find(&users).Error
|
||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,11 @@ func (s *Service) Refresh(refreshTokenStr string) (newAccessToken, newRefreshTok
|
|||||||
return "", "", fmt.Errorf("만료되었거나 유효하지 않은 리프레시 토큰입니다")
|
return "", "", fmt.Errorf("만료되었거나 유효하지 않은 리프레시 토큰입니다")
|
||||||
}
|
}
|
||||||
|
|
||||||
user := &User{ID: claims.UserID, Username: claims.Username, Role: Role(claims.Role)}
|
// Look up the current user from DB to avoid using stale role from JWT claims
|
||||||
|
user, dbErr := s.repo.FindByID(claims.UserID)
|
||||||
|
if dbErr != nil {
|
||||||
|
return "", "", fmt.Errorf("유저를 찾을 수 없습니다")
|
||||||
|
}
|
||||||
|
|
||||||
newAccessToken, err = s.issueAccessToken(user)
|
newAccessToken, err = s.issueAccessToken(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -173,8 +177,8 @@ func (s *Service) Logout(userID uint) error {
|
|||||||
return s.rdb.Del(ctx, sessionKey, refreshKey).Err()
|
return s.rdb.Del(ctx, sessionKey, refreshKey).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) GetAllUsers() ([]User, error) {
|
func (s *Service) GetAllUsers(offset, limit int) ([]User, error) {
|
||||||
return s.repo.FindAll()
|
return s.repo.FindAll(offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UpdateRole(id uint, role Role) error {
|
func (s *Service) UpdateRole(id uint, role Role) error {
|
||||||
@@ -182,7 +186,68 @@ func (s *Service) UpdateRole(id uint, role Role) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) DeleteUser(id uint) error {
|
func (s *Service) DeleteUser(id uint) error {
|
||||||
return s.repo.Delete(id)
|
if err := s.repo.Delete(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up Redis sessions for deleted user
|
||||||
|
ctx := context.Background()
|
||||||
|
sessionKey := fmt.Sprintf("session:%d", id)
|
||||||
|
refreshKey := fmt.Sprintf("refresh:%d", id)
|
||||||
|
s.rdb.Del(ctx, sessionKey, refreshKey)
|
||||||
|
|
||||||
|
// TODO: Clean up wallet and profile data via cross-service calls
|
||||||
|
// (walletCreator/profileCreator are creation-only; deletion callbacks are not yet wired up)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLaunchTicket generates a one-time ticket that the game launcher
|
||||||
|
// exchanges for the real JWT. The ticket expires in 30 seconds and can only
|
||||||
|
// be redeemed once, preventing token exposure in URLs or browser history.
|
||||||
|
func (s *Service) CreateLaunchTicket(userID uint) (string, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", fmt.Errorf("generate ticket: %w", err)
|
||||||
|
}
|
||||||
|
ticket := hex.EncodeToString(buf)
|
||||||
|
|
||||||
|
// Store ticket → userID mapping in Redis with 30s TTL
|
||||||
|
key := fmt.Sprintf("launch_ticket:%s", ticket)
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := s.rdb.Set(ctx, key, userID, 30*time.Second).Err(); err != nil {
|
||||||
|
return "", fmt.Errorf("store ticket: %w", err)
|
||||||
|
}
|
||||||
|
return ticket, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedeemLaunchTicket exchanges a one-time ticket for the user's access token.
|
||||||
|
// The ticket is deleted immediately after use (one-time).
|
||||||
|
func (s *Service) RedeemLaunchTicket(ticket string) (string, error) {
|
||||||
|
key := fmt.Sprintf("launch_ticket:%s", ticket)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Atomically get and delete (one-time use)
|
||||||
|
userIDStr, err := s.rdb.GetDel(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("유효하지 않거나 만료된 티켓입니다")
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID uint
|
||||||
|
if _, err := fmt.Sscanf(userIDStr, "%d", &userID); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid ticket data")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.repo.FindByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("유저를 찾을 수 없습니다")
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := s.issueAccessToken(user)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Register(username, password string) error {
|
func (s *Service) Register(username, password string) error {
|
||||||
@@ -193,39 +258,50 @@ func (s *Service) Register(username, password string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("비밀번호 처리에 실패했습니다")
|
return fmt.Errorf("비밀번호 처리에 실패했습니다")
|
||||||
}
|
}
|
||||||
user := &User{
|
|
||||||
Username: username,
|
return s.repo.Transaction(func(txRepo *Repository) error {
|
||||||
PasswordHash: string(hash),
|
user := &User{Username: username, PasswordHash: string(hash), Role: RoleUser}
|
||||||
Role: RoleUser,
|
if err := txRepo.Create(user); err != nil {
|
||||||
}
|
return err
|
||||||
if err := s.repo.Create(user); err != nil {
|
}
|
||||||
return err
|
if s.walletCreator != nil {
|
||||||
}
|
if err := s.walletCreator(user.ID); err != nil {
|
||||||
if s.walletCreator != nil {
|
return fmt.Errorf("wallet creation failed: %w", err)
|
||||||
if err := s.walletCreator(user.ID); err != nil {
|
|
||||||
log.Printf("wallet creation failed for user %d: %v — rolling back", user.ID, err)
|
|
||||||
if delErr := s.repo.Delete(user.ID); delErr != nil {
|
|
||||||
log.Printf("WARNING: rollback delete also failed for user %d: %v", user.ID, delErr)
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("계정 초기화에 실패했습니다. 잠시 후 다시 시도해주세요")
|
|
||||||
}
|
}
|
||||||
}
|
if s.profileCreator != nil {
|
||||||
if s.profileCreator != nil {
|
if err := s.profileCreator(user.ID); err != nil {
|
||||||
if err := s.profileCreator(user.ID); err != nil {
|
return fmt.Errorf("profile creation failed: %w", err)
|
||||||
log.Printf("profile creation failed for user %d: %v", user.ID, err)
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSAFYLoginURL returns the SSAFY OAuth authorization URL.
|
// GetSSAFYLoginURL returns the SSAFY OAuth authorization URL with a random
|
||||||
func (s *Service) GetSSAFYLoginURL() string {
|
// state parameter for CSRF protection. The state is stored in Redis with a
|
||||||
|
// 5-minute TTL and must be verified in the callback.
|
||||||
|
func (s *Service) GetSSAFYLoginURL() (string, error) {
|
||||||
|
stateBytes := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(stateBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("state 생성 실패: %w", err)
|
||||||
|
}
|
||||||
|
state := hex.EncodeToString(stateBytes)
|
||||||
|
|
||||||
|
// Store state in Redis with 5-minute TTL for one-time verification
|
||||||
|
key := fmt.Sprintf("ssafy_state:%s", state)
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := s.rdb.Set(ctx, key, "1", 5*time.Minute).Err(); err != nil {
|
||||||
|
return "", fmt.Errorf("state 저장 실패: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"client_id": {config.C.SSAFYClientID},
|
"client_id": {config.C.SSAFYClientID},
|
||||||
"redirect_uri": {config.C.SSAFYRedirectURI},
|
"redirect_uri": {config.C.SSAFYRedirectURI},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
|
"state": {state},
|
||||||
}
|
}
|
||||||
return "https://project.ssafy.com/oauth/sso-check?" + params.Encode()
|
return "https://project.ssafy.com/oauth/sso-check?" + params.Encode(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeSSAFYCode exchanges an authorization code for SSAFY tokens.
|
// ExchangeSSAFYCode exchanges an authorization code for SSAFY tokens.
|
||||||
@@ -248,7 +324,7 @@ func (s *Service) ExchangeSSAFYCode(code string) (*SSAFYTokenResponse, error) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SSAFY 토큰 응답 읽기 실패: %v", err)
|
return nil, fmt.Errorf("SSAFY 토큰 응답 읽기 실패: %v", err)
|
||||||
}
|
}
|
||||||
@@ -279,7 +355,7 @@ func (s *Service) GetSSAFYUserInfo(accessToken string) (*SSAFYUserInfo, error) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SSAFY 사용자 정보 응답 읽기 실패: %v", err)
|
return nil, fmt.Errorf("SSAFY 사용자 정보 응답 읽기 실패: %v", err)
|
||||||
}
|
}
|
||||||
@@ -296,7 +372,18 @@ func (s *Service) GetSSAFYUserInfo(accessToken string) (*SSAFYUserInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SSAFYLogin handles the full SSAFY OAuth callback: exchange code, get user info, find or create user, issue tokens.
|
// SSAFYLogin handles the full SSAFY OAuth callback: exchange code, get user info, find or create user, issue tokens.
|
||||||
func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, user *User, err error) {
|
// The state parameter is verified against Redis (one-time use via GetDel) for CSRF protection.
|
||||||
|
func (s *Service) SSAFYLogin(code, state string) (accessToken, refreshToken string, user *User, err error) {
|
||||||
|
// Verify CSRF state parameter (one-time use)
|
||||||
|
if state == "" {
|
||||||
|
return "", "", nil, fmt.Errorf("state 파라미터가 필요합니다")
|
||||||
|
}
|
||||||
|
stateKey := fmt.Sprintf("ssafy_state:%s", state)
|
||||||
|
val, err := s.rdb.GetDel(context.Background(), stateKey).Result()
|
||||||
|
if err != nil || val != "1" {
|
||||||
|
return "", "", nil, fmt.Errorf("유효하지 않거나 만료된 state 파라미터입니다")
|
||||||
|
}
|
||||||
|
|
||||||
tokenResp, err := s.ExchangeSSAFYCode(code)
|
tokenResp, err := s.ExchangeSSAFYCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
@@ -333,7 +420,6 @@ func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, use
|
|||||||
username = username[:50]
|
username = username[:50]
|
||||||
}
|
}
|
||||||
|
|
||||||
var newUserID uint
|
|
||||||
err = s.repo.Transaction(func(txRepo *Repository) error {
|
err = s.repo.Transaction(func(txRepo *Repository) error {
|
||||||
user = &User{
|
user = &User{
|
||||||
Username: username,
|
Username: username,
|
||||||
@@ -341,27 +427,26 @@ func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, use
|
|||||||
Role: RoleUser,
|
Role: RoleUser,
|
||||||
SsafyID: &ssafyID,
|
SsafyID: &ssafyID,
|
||||||
}
|
}
|
||||||
return txRepo.Create(user)
|
if err := txRepo.Create(user); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.walletCreator != nil {
|
||||||
|
if err := s.walletCreator(user.ID); err != nil {
|
||||||
|
return fmt.Errorf("wallet creation failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.profileCreator != nil {
|
||||||
|
if err := s.profileCreator(user.ID); err != nil {
|
||||||
|
return fmt.Errorf("profile creation failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("SSAFY user creation transaction failed: %v", err)
|
||||||
return "", "", nil, fmt.Errorf("계정 생성 실패: %v", err)
|
return "", "", nil, fmt.Errorf("계정 생성 실패: %v", err)
|
||||||
}
|
}
|
||||||
newUserID = user.ID
|
|
||||||
|
|
||||||
if s.walletCreator != nil {
|
|
||||||
if err := s.walletCreator(newUserID); err != nil {
|
|
||||||
log.Printf("wallet creation failed for SSAFY user %d: %v — rolling back", newUserID, err)
|
|
||||||
if delErr := s.repo.Delete(newUserID); delErr != nil {
|
|
||||||
log.Printf("WARNING: rollback delete also failed for SSAFY user %d: %v", newUserID, delErr)
|
|
||||||
}
|
|
||||||
return "", "", nil, fmt.Errorf("계정 초기화에 실패했습니다. 잠시 후 다시 시도해주세요")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if s.profileCreator != nil {
|
|
||||||
if err := s.profileCreator(newUserID); err != nil {
|
|
||||||
log.Printf("profile creation failed for SSAFY user %d: %v", newUserID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err = s.issueAccessToken(user)
|
accessToken, err = s.issueAccessToken(user)
|
||||||
@@ -414,6 +499,10 @@ func sanitizeForUsername(s string) string {
|
|||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: EnsureAdmin does not use a transaction for wallet/profile creation.
|
||||||
|
// If these fail, the admin user exists without a wallet/profile.
|
||||||
|
// This is acceptable because EnsureAdmin runs once at startup and failures
|
||||||
|
// are logged as warnings. A restart will skip user creation (already exists).
|
||||||
func (s *Service) EnsureAdmin(username, password string) error {
|
func (s *Service) EnsureAdmin(username, password string) error {
|
||||||
if _, err := s.repo.FindByUsername(username); err == nil {
|
if _, err := s.repo.FindByUsername(username); err == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
291
internal/auth/service_test.go
Normal file
291
internal/auth/service_test.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"a301_server/pkg/config"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 1. Password hashing (bcrypt)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestBcryptHashAndVerify(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
wantMatch bool
|
||||||
|
}{
|
||||||
|
{"short password", "abc", true},
|
||||||
|
{"normal password", "myP@ssw0rd!", true},
|
||||||
|
{"unicode password", "비밀번호123", true},
|
||||||
|
{"empty password", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(tc.password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateFromPassword failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword(hash, []byte(tc.password))
|
||||||
|
if (err == nil) != tc.wantMatch {
|
||||||
|
t.Errorf("CompareHashAndPassword: got err=%v, wantMatch=%v", err, tc.wantMatch)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBcryptWrongPassword(t *testing.T) {
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateFromPassword failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword(hash, []byte("wrong")); err == nil {
|
||||||
|
t.Error("expected error comparing wrong password, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBcryptDifferentHashesForSamePassword(t *testing.T) {
|
||||||
|
password := "samePassword"
|
||||||
|
hash1, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
hash2, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if string(hash1) == string(hash2) {
|
||||||
|
t.Error("expected different hashes for the same password (different salts)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 2. JWT token generation and parsing
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func setupTestConfig() {
|
||||||
|
config.C = config.Config{
|
||||||
|
JWTSecret: "test-jwt-secret-key-for-unit-tests",
|
||||||
|
RefreshSecret: "test-refresh-secret-key-for-unit-tests",
|
||||||
|
JWTExpiryHours: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIssueAndParseAccessToken(t *testing.T) {
|
||||||
|
setupTestConfig()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID uint
|
||||||
|
username string
|
||||||
|
role string
|
||||||
|
}{
|
||||||
|
{"admin user", 1, "admin", "admin"},
|
||||||
|
{"regular user", 42, "player1", "user"},
|
||||||
|
{"unicode username", 100, "유저", "user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
expiry := time.Duration(config.C.JWTExpiryHours) * time.Hour
|
||||||
|
claims := &Claims{
|
||||||
|
UserID: tc.userID,
|
||||||
|
Username: tc.username,
|
||||||
|
Role: tc.role,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, err := token.SignedString([]byte(config.C.JWTSecret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SignedString failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte(config.C.JWTSecret), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseWithClaims failed: %v", err)
|
||||||
|
}
|
||||||
|
if !parsed.Valid {
|
||||||
|
t.Fatal("parsed token is not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := parsed.Claims.(*Claims)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("failed to cast claims")
|
||||||
|
}
|
||||||
|
if got.UserID != tc.userID {
|
||||||
|
t.Errorf("UserID = %d, want %d", got.UserID, tc.userID)
|
||||||
|
}
|
||||||
|
if got.Username != tc.username {
|
||||||
|
t.Errorf("Username = %q, want %q", got.Username, tc.username)
|
||||||
|
}
|
||||||
|
if got.Role != tc.role {
|
||||||
|
t.Errorf("Role = %q, want %q", got.Role, tc.role)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenWithWrongSecret(t *testing.T) {
|
||||||
|
setupTestConfig()
|
||||||
|
|
||||||
|
claims := &Claims{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "test",
|
||||||
|
Role: "user",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, _ := token.SignedString([]byte(config.C.JWTSecret))
|
||||||
|
|
||||||
|
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte("wrong-secret"), nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error parsing token with wrong secret, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseExpiredToken(t *testing.T) {
|
||||||
|
setupTestConfig()
|
||||||
|
|
||||||
|
claims := &Claims{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "test",
|
||||||
|
Role: "user",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, _ := token.SignedString([]byte(config.C.JWTSecret))
|
||||||
|
|
||||||
|
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte(config.C.JWTSecret), nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error parsing expired token, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenUsesDifferentSecret(t *testing.T) {
|
||||||
|
setupTestConfig()
|
||||||
|
|
||||||
|
claims := &Claims{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "test",
|
||||||
|
Role: "user",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExpiry)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign with refresh secret
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenStr, _ := token.SignedString([]byte(config.C.RefreshSecret))
|
||||||
|
|
||||||
|
// Should fail with JWT secret
|
||||||
|
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte(config.C.JWTSecret), nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error parsing refresh token with access secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should succeed with refresh secret
|
||||||
|
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte(config.C.RefreshSecret), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success with refresh secret, got: %v", err)
|
||||||
|
}
|
||||||
|
if !parsed.Valid {
|
||||||
|
t.Error("parsed refresh token is not valid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 3. Input validation helpers (sanitizeForUsername)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSanitizeForUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"lowercase letters", "hello", "hello"},
|
||||||
|
{"uppercase converted", "HeLLo", "hello"},
|
||||||
|
{"digits kept", "user123", "user123"},
|
||||||
|
{"underscore kept", "user_name", "user_name"},
|
||||||
|
{"hyphen kept", "user-name", "user-name"},
|
||||||
|
{"special chars removed", "user@name!#$", "username"},
|
||||||
|
{"spaces removed", "user name", "username"},
|
||||||
|
{"unicode removed", "유저abc", "abc"},
|
||||||
|
{"mixed", "User-123_Test!", "user-123_test"},
|
||||||
|
{"empty input", "", ""},
|
||||||
|
{"all removed", "!!@@##", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := sanitizeForUsername(tc.input)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("sanitizeForUsername(%q) = %q, want %q", tc.input, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 4. Claims struct fields
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestClaimsRoundTrip(t *testing.T) {
|
||||||
|
setupTestConfig()
|
||||||
|
|
||||||
|
original := &Claims{
|
||||||
|
UserID: 999,
|
||||||
|
Username: "testuser",
|
||||||
|
Role: "admin",
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, original)
|
||||||
|
tokenStr, err := token.SignedString([]byte(config.C.JWTSecret))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("signing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return []byte(config.C.JWTSecret), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parsing failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := parsed.Claims.(*Claims)
|
||||||
|
|
||||||
|
if got.UserID != original.UserID {
|
||||||
|
t.Errorf("UserID: got %d, want %d", got.UserID, original.UserID)
|
||||||
|
}
|
||||||
|
if got.Username != original.Username {
|
||||||
|
t.Errorf("Username: got %q, want %q", got.Username, original.Username)
|
||||||
|
}
|
||||||
|
if got.Role != original.Role {
|
||||||
|
t.Errorf("Role: got %q, want %q", got.Role, original.Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,9 @@ type BossRoom struct {
|
|||||||
BossID int `json:"bossId" gorm:"index;not null"`
|
BossID int `json:"bossId" gorm:"index;not null"`
|
||||||
Status RoomStatus `json:"status" gorm:"type:varchar(20);index;default:waiting;not null"`
|
Status RoomStatus `json:"status" gorm:"type:varchar(20);index;default:waiting;not null"`
|
||||||
MaxPlayers int `json:"maxPlayers" gorm:"default:3;not null"`
|
MaxPlayers int `json:"maxPlayers" gorm:"default:3;not null"`
|
||||||
|
// Players is stored as a JSON text column for simplicity.
|
||||||
|
// TODO: For better query performance, consider migrating to a junction table
|
||||||
|
// (boss_room_players with room_id + username columns).
|
||||||
Players string `json:"players" gorm:"type:text"` // JSON array of usernames
|
Players string `json:"players" gorm:"type:text"` // JSON array of usernames
|
||||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// defaultMaxPlayers is the maximum number of players allowed in a boss raid room.
|
||||||
|
defaultMaxPlayers = 3
|
||||||
// entryTokenTTL is the TTL for boss raid entry tokens in Redis.
|
// entryTokenTTL is the TTL for boss raid entry tokens in Redis.
|
||||||
entryTokenTTL = 5 * time.Minute
|
entryTokenTTL = 5 * time.Minute
|
||||||
// entryTokenPrefix is the Redis key prefix for entry token → {username, sessionName}.
|
// entryTokenPrefix is the Redis key prefix for entry token → {username, sessionName}.
|
||||||
@@ -84,7 +86,7 @@ func (s *Service) RequestEntry(usernames []string, bossID int) (*BossRoom, error
|
|||||||
SessionName: sessionName,
|
SessionName: sessionName,
|
||||||
BossID: bossID,
|
BossID: bossID,
|
||||||
Status: StatusWaiting,
|
Status: StatusWaiting,
|
||||||
MaxPlayers: 3,
|
MaxPlayers: defaultMaxPlayers,
|
||||||
Players: string(playersJSON),
|
Players: string(playersJSON),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
602
internal/bossraid/service_test.go
Normal file
602
internal/bossraid/service_test.go
Normal file
@@ -0,0 +1,602 @@
|
|||||||
|
// NOTE: These tests use a testableService that reimplements service logic
|
||||||
|
// with mock repositories. This means tests can pass even if the real service
|
||||||
|
// diverges. For full coverage, consider refactoring services to use repository
|
||||||
|
// interfaces so the real service can be tested with mock repositories injected.
|
||||||
|
|
||||||
|
package bossraid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tolelom/tolchain/core"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock repository — mirrors the methods that Service calls.
|
||||||
|
// Since Service uses concrete *Repository and Transaction(func(*Repository)),
|
||||||
|
// we create a testableService with an interface to enable mocking.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type repositoryInterface interface {
|
||||||
|
Create(room *BossRoom) error
|
||||||
|
Update(room *BossRoom) error
|
||||||
|
FindBySessionName(sessionName string) (*BossRoom, error)
|
||||||
|
FindBySessionNameForUpdate(sessionName string) (*BossRoom, error)
|
||||||
|
CountActiveByUsername(username string) (int64, error)
|
||||||
|
Transaction(fn func(txRepo repositoryInterface) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type testableService struct {
|
||||||
|
repo repositoryInterface
|
||||||
|
rewardGrant func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) RequestEntry(usernames []string, bossID int) (*BossRoom, error) {
|
||||||
|
if len(usernames) == 0 {
|
||||||
|
return nil, fmt.Errorf("플레이어 목록이 비어있습니다")
|
||||||
|
}
|
||||||
|
if len(usernames) > 3 {
|
||||||
|
return nil, fmt.Errorf("최대 3명까지 입장할 수 있습니다")
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool, len(usernames))
|
||||||
|
for _, u := range usernames {
|
||||||
|
if seen[u] {
|
||||||
|
return nil, fmt.Errorf("중복된 플레이어가 있습니다: %s", u)
|
||||||
|
}
|
||||||
|
seen[u] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, username := range usernames {
|
||||||
|
count, err := s.repo.CountActiveByUsername(username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("플레이어 상태 확인 실패: %w", err)
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return nil, fmt.Errorf("플레이어 %s가 이미 보스 레이드 중입니다", username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
playersJSON, err := json.Marshal(usernames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("플레이어 목록 직렬화 실패: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionName := fmt.Sprintf("BossRaid_%d_%d", bossID, time.Now().UnixNano())
|
||||||
|
room := &BossRoom{
|
||||||
|
SessionName: sessionName,
|
||||||
|
BossID: bossID,
|
||||||
|
Status: StatusWaiting,
|
||||||
|
MaxPlayers: 3,
|
||||||
|
Players: string(playersJSON),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.Create(room); err != nil {
|
||||||
|
return nil, fmt.Errorf("방 생성 실패: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return room, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) CompleteRaid(sessionName string, rewards []PlayerReward) (*BossRoom, []RewardResult, error) {
|
||||||
|
var resultRoom *BossRoom
|
||||||
|
var resultRewards []RewardResult
|
||||||
|
|
||||||
|
err := s.repo.Transaction(func(txRepo repositoryInterface) error {
|
||||||
|
room, err := txRepo.FindBySessionNameForUpdate(sessionName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("방을 찾을 수 없습니다: %w", err)
|
||||||
|
}
|
||||||
|
if room.Status != StatusInProgress {
|
||||||
|
return fmt.Errorf("완료할 수 없는 상태입니다: %s", room.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var players []string
|
||||||
|
if err := json.Unmarshal([]byte(room.Players), &players); err != nil {
|
||||||
|
return fmt.Errorf("플레이어 목록 파싱 실패: %w", err)
|
||||||
|
}
|
||||||
|
playerSet := make(map[string]bool, len(players))
|
||||||
|
for _, p := range players {
|
||||||
|
playerSet[p] = true
|
||||||
|
}
|
||||||
|
for _, r := range rewards {
|
||||||
|
if !playerSet[r.Username] {
|
||||||
|
return fmt.Errorf("보상 대상 %s가 방의 플레이어가 아닙니다", r.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
room.Status = StatusCompleted
|
||||||
|
room.CompletedAt = &now
|
||||||
|
if err := txRepo.Update(room); err != nil {
|
||||||
|
return fmt.Errorf("상태 업데이트 실패: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultRoom = room
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resultRewards = make([]RewardResult, 0, len(rewards))
|
||||||
|
if s.rewardGrant != nil {
|
||||||
|
for _, r := range rewards {
|
||||||
|
grantErr := s.rewardGrant(r.Username, r.TokenAmount, r.Assets)
|
||||||
|
result := RewardResult{Username: r.Username, Success: grantErr == nil}
|
||||||
|
if grantErr != nil {
|
||||||
|
result.Error = grantErr.Error()
|
||||||
|
}
|
||||||
|
resultRewards = append(resultRewards, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultRoom, resultRewards, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testableService) FailRaid(sessionName string) (*BossRoom, error) {
|
||||||
|
var resultRoom *BossRoom
|
||||||
|
err := s.repo.Transaction(func(txRepo repositoryInterface) error {
|
||||||
|
room, err := txRepo.FindBySessionNameForUpdate(sessionName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("방을 찾을 수 없습니다: %w", err)
|
||||||
|
}
|
||||||
|
if room.Status != StatusWaiting && room.Status != StatusInProgress {
|
||||||
|
return fmt.Errorf("실패 처리할 수 없는 상태입니다: %s", room.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
room.Status = StatusFailed
|
||||||
|
room.CompletedAt = &now
|
||||||
|
if err := txRepo.Update(room); err != nil {
|
||||||
|
return fmt.Errorf("상태 업데이트 실패: %w", err)
|
||||||
|
}
|
||||||
|
resultRoom = room
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resultRoom, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock implementation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mockRepo struct {
|
||||||
|
rooms map[string]*BossRoom
|
||||||
|
activeCounts map[string]int64 // username -> active count
|
||||||
|
nextID uint
|
||||||
|
createErr error
|
||||||
|
updateErr error
|
||||||
|
countActiveErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockRepo() *mockRepo {
|
||||||
|
return &mockRepo{
|
||||||
|
rooms: make(map[string]*BossRoom),
|
||||||
|
activeCounts: make(map[string]int64),
|
||||||
|
nextID: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Create(room *BossRoom) error {
|
||||||
|
if m.createErr != nil {
|
||||||
|
return m.createErr
|
||||||
|
}
|
||||||
|
room.ID = m.nextID
|
||||||
|
room.CreatedAt = time.Now()
|
||||||
|
room.UpdatedAt = time.Now()
|
||||||
|
m.nextID++
|
||||||
|
stored := *room
|
||||||
|
m.rooms[room.SessionName] = &stored
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Update(room *BossRoom) error {
|
||||||
|
if m.updateErr != nil {
|
||||||
|
return m.updateErr
|
||||||
|
}
|
||||||
|
room.UpdatedAt = time.Now()
|
||||||
|
stored := *room
|
||||||
|
m.rooms[room.SessionName] = &stored
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) FindBySessionName(sessionName string) (*BossRoom, error) {
|
||||||
|
room, ok := m.rooms[sessionName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("record not found")
|
||||||
|
}
|
||||||
|
cp := *room
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) FindBySessionNameForUpdate(sessionName string) (*BossRoom, error) {
|
||||||
|
return m.FindBySessionName(sessionName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) CountActiveByUsername(username string) (int64, error) {
|
||||||
|
if m.countActiveErr != nil {
|
||||||
|
return 0, m.countActiveErr
|
||||||
|
}
|
||||||
|
return m.activeCounts[username], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRepo) Transaction(fn func(txRepo repositoryInterface) error) error {
|
||||||
|
return fn(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests: RequestEntry
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRequestEntry_Success(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, err := svc.RequestEntry([]string{"player1", "player2"}, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestEntry failed: %v", err)
|
||||||
|
}
|
||||||
|
if room.Status != StatusWaiting {
|
||||||
|
t.Errorf("Status = %q, want %q", room.Status, StatusWaiting)
|
||||||
|
}
|
||||||
|
if room.BossID != 1 {
|
||||||
|
t.Errorf("BossID = %d, want 1", room.BossID)
|
||||||
|
}
|
||||||
|
if room.MaxPlayers != 3 {
|
||||||
|
t.Errorf("MaxPlayers = %d, want 3", room.MaxPlayers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEntry_EmptyPlayers(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.RequestEntry([]string{}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty player list, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEntry_TooManyPlayers(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.RequestEntry([]string{"p1", "p2", "p3", "p4"}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for >3 players, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEntry_DuplicatePlayers(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.RequestEntry([]string{"player1", "player1"}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for duplicate players, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEntry_PlayerAlreadyInActiveRaid(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
repo.activeCounts["player1"] = 1
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.RequestEntry([]string{"player1", "player2"}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when player is already in active raid, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEntry_ThreePlayers(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, err := svc.RequestEntry([]string{"p1", "p2", "p3"}, 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestEntry failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var players []string
|
||||||
|
if err := json.Unmarshal([]byte(room.Players), &players); err != nil {
|
||||||
|
t.Fatalf("failed to parse Players JSON: %v", err)
|
||||||
|
}
|
||||||
|
if len(players) != 3 {
|
||||||
|
t.Errorf("expected 3 players in JSON, got %d", len(players))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests: CompleteRaid
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCompleteRaid_Success(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
// Create a room and set it to in_progress
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1", "player2"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
now := time.Now()
|
||||||
|
stored.StartedAt = &now
|
||||||
|
|
||||||
|
rewards := []PlayerReward{
|
||||||
|
{Username: "player1", TokenAmount: 100},
|
||||||
|
{Username: "player2", TokenAmount: 50},
|
||||||
|
}
|
||||||
|
|
||||||
|
completed, results, err := svc.CompleteRaid(room.SessionName, rewards)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if completed.Status != StatusCompleted {
|
||||||
|
t.Errorf("Status = %q, want %q", completed.Status, StatusCompleted)
|
||||||
|
}
|
||||||
|
if completed.CompletedAt == nil {
|
||||||
|
t.Error("CompletedAt should be set")
|
||||||
|
}
|
||||||
|
// No reward granter set, so results should be empty
|
||||||
|
if len(results) != 0 {
|
||||||
|
t.Errorf("expected 0 reward results (no granter set), got %d", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteRaid_WithRewardGranter(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
grantCalls := 0
|
||||||
|
svc := &testableService{
|
||||||
|
repo: repo,
|
||||||
|
rewardGrant: func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error {
|
||||||
|
grantCalls++
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
|
||||||
|
rewards := []PlayerReward{
|
||||||
|
{Username: "player1", TokenAmount: 100},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, results, err := svc.CompleteRaid(room.SessionName, rewards)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if grantCalls != 1 {
|
||||||
|
t.Errorf("expected 1 grant call, got %d", grantCalls)
|
||||||
|
}
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("expected 1 result, got %d", len(results))
|
||||||
|
}
|
||||||
|
if !results[0].Success {
|
||||||
|
t.Errorf("expected success=true, got false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteRaid_RewardGranterFails(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{
|
||||||
|
repo: repo,
|
||||||
|
rewardGrant: func(username string, tokenAmount uint64, assets []core.MintAssetPayload) error {
|
||||||
|
return fmt.Errorf("blockchain error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
|
||||||
|
rewards := []PlayerReward{
|
||||||
|
{Username: "player1", TokenAmount: 100},
|
||||||
|
}
|
||||||
|
|
||||||
|
completed, results, err := svc.CompleteRaid(room.SessionName, rewards)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteRaid should not fail when reward granter fails: %v", err)
|
||||||
|
}
|
||||||
|
// Room should still be completed
|
||||||
|
if completed.Status != StatusCompleted {
|
||||||
|
t.Errorf("Status = %q, want %q", completed.Status, StatusCompleted)
|
||||||
|
}
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("expected 1 result, got %d", len(results))
|
||||||
|
}
|
||||||
|
if results[0].Success {
|
||||||
|
t.Error("expected success=false for failed grant")
|
||||||
|
}
|
||||||
|
if results[0].Error == "" {
|
||||||
|
t.Error("expected non-empty error message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteRaid_WrongStatus(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
// Room is in "waiting" status, not "in_progress"
|
||||||
|
|
||||||
|
_, _, err := svc.CompleteRaid(room.SessionName, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error completing raid that is not in_progress, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteRaid_InvalidRewardRecipient(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
|
||||||
|
rewards := []PlayerReward{
|
||||||
|
{Username: "not_a_member", TokenAmount: 100},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := svc.CompleteRaid(room.SessionName, rewards)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for reward to non-member, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompleteRaid_RoomNotFound(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, _, err := svc.CompleteRaid("nonexistent_session", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for non-existent room, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests: FailRaid
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFailRaid_FromWaiting(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
|
||||||
|
failed, err := svc.FailRaid(room.SessionName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FailRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if failed.Status != StatusFailed {
|
||||||
|
t.Errorf("Status = %q, want %q", failed.Status, StatusFailed)
|
||||||
|
}
|
||||||
|
if failed.CompletedAt == nil {
|
||||||
|
t.Error("CompletedAt should be set on failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailRaid_FromInProgress(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
|
||||||
|
failed, err := svc.FailRaid(room.SessionName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FailRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if failed.Status != StatusFailed {
|
||||||
|
t.Errorf("Status = %q, want %q", failed.Status, StatusFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailRaid_FromCompleted(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusCompleted
|
||||||
|
|
||||||
|
_, err := svc.FailRaid(room.SessionName)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error failing already-completed raid, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailRaid_FromFailed(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"player1"}, 1)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusFailed
|
||||||
|
|
||||||
|
_, err := svc.FailRaid(room.SessionName)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error failing already-failed raid, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailRaid_RoomNotFound(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
_, err := svc.FailRaid("nonexistent_session")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for non-existent room, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests: State machine transitions
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestStateMachine_FullLifecycle(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
// 1. Create room (waiting)
|
||||||
|
room, err := svc.RequestEntry([]string{"p1", "p2"}, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestEntry failed: %v", err)
|
||||||
|
}
|
||||||
|
if room.Status != StatusWaiting {
|
||||||
|
t.Fatalf("expected waiting, got %s", room.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Simulate start (set to in_progress)
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusInProgress
|
||||||
|
now := time.Now()
|
||||||
|
stored.StartedAt = &now
|
||||||
|
|
||||||
|
// 3. Complete
|
||||||
|
completed, _, err := svc.CompleteRaid(room.SessionName, []PlayerReward{
|
||||||
|
{Username: "p1", TokenAmount: 10},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompleteRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if completed.Status != StatusCompleted {
|
||||||
|
t.Errorf("expected completed, got %s", completed.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Cannot fail a completed raid
|
||||||
|
_, err = svc.FailRaid(room.SessionName)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error failing completed raid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateMachine_WaitingToFailed(t *testing.T) {
|
||||||
|
repo := newMockRepo()
|
||||||
|
svc := &testableService{repo: repo}
|
||||||
|
|
||||||
|
room, _ := svc.RequestEntry([]string{"p1"}, 1)
|
||||||
|
|
||||||
|
failed, err := svc.FailRaid(room.SessionName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FailRaid failed: %v", err)
|
||||||
|
}
|
||||||
|
if failed.Status != StatusFailed {
|
||||||
|
t.Errorf("expected failed, got %s", failed.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot complete a failed raid
|
||||||
|
stored := repo.rooms[room.SessionName]
|
||||||
|
stored.Status = StatusFailed
|
||||||
|
_, _, err = svc.CompleteRaid(room.SessionName, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error completing failed raid")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -70,7 +70,7 @@ func (c *Client) Call(method string, params any, out any) error {
|
|||||||
return fmt.Errorf("RPC HTTP error: status %d", resp.StatusCode)
|
return fmt.Errorf("RPC HTTP error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read RPC response: %w", err)
|
return fmt.Errorf("read RPC response: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/tolelom/tolchain/core"
|
"github.com/tolelom/tolchain/core"
|
||||||
tocrypto "github.com/tolelom/tolchain/crypto"
|
tocrypto "github.com/tolelom/tolchain/crypto"
|
||||||
@@ -22,6 +23,8 @@ type Service struct {
|
|||||||
operatorWallet *wallet.Wallet
|
operatorWallet *wallet.Wallet
|
||||||
encKeyBytes []byte // 32-byte AES-256 key
|
encKeyBytes []byte // 32-byte AES-256 key
|
||||||
userResolver func(username string) (uint, error)
|
userResolver func(username string) (uint, error)
|
||||||
|
operatorMu sync.Mutex // serialises operator-nonce transactions
|
||||||
|
userMu sync.Map // per-user mutex (keyed by userID uint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUserResolver sets the callback that resolves username → userID.
|
// SetUserResolver sets the callback that resolves username → userID.
|
||||||
@@ -209,9 +212,18 @@ func (s *Service) GetListing(listingID string) (json.RawMessage, error) {
|
|||||||
return s.client.GetListing(listingID)
|
return s.client.GetListing(listingID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getUserMu returns a per-user mutex, creating one if it doesn't exist.
|
||||||
|
func (s *Service) getUserMu(userID uint) *sync.Mutex {
|
||||||
|
v, _ := s.userMu.LoadOrStore(userID, &sync.Mutex{})
|
||||||
|
return v.(*sync.Mutex)
|
||||||
|
}
|
||||||
|
|
||||||
// ---- User Transaction Methods ----
|
// ---- User Transaction Methods ----
|
||||||
|
|
||||||
func (s *Service) Transfer(userID uint, to string, amount uint64) (*SendTxResult, error) {
|
func (s *Service) Transfer(userID uint, to string, amount uint64) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -228,6 +240,9 @@ func (s *Service) Transfer(userID uint, to string, amount uint64) (*SendTxResult
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) TransferAsset(userID uint, assetID, to string) (*SendTxResult, error) {
|
func (s *Service) TransferAsset(userID uint, assetID, to string) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -244,6 +259,9 @@ func (s *Service) TransferAsset(userID uint, assetID, to string) (*SendTxResult,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ListOnMarket(userID uint, assetID string, price uint64) (*SendTxResult, error) {
|
func (s *Service) ListOnMarket(userID uint, assetID string, price uint64) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -260,6 +278,9 @@ func (s *Service) ListOnMarket(userID uint, assetID string, price uint64) (*Send
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) BuyFromMarket(userID uint, listingID string) (*SendTxResult, error) {
|
func (s *Service) BuyFromMarket(userID uint, listingID string) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -276,6 +297,9 @@ func (s *Service) BuyFromMarket(userID uint, listingID string) (*SendTxResult, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) CancelListing(userID uint, listingID string) (*SendTxResult, error) {
|
func (s *Service) CancelListing(userID uint, listingID string) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -292,6 +316,9 @@ func (s *Service) CancelListing(userID uint, listingID string) (*SendTxResult, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) EquipItem(userID uint, assetID, slot string) (*SendTxResult, error) {
|
func (s *Service) EquipItem(userID uint, assetID, slot string) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -308,6 +335,9 @@ func (s *Service) EquipItem(userID uint, assetID, slot string) (*SendTxResult, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UnequipItem(userID uint, assetID string) (*SendTxResult, error) {
|
func (s *Service) UnequipItem(userID uint, assetID string) (*SendTxResult, error) {
|
||||||
|
mu := s.getUserMu(userID)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
w, pubKey, err := s.loadUserWallet(userID)
|
w, pubKey, err := s.loadUserWallet(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -340,6 +370,8 @@ func (s *Service) getOperatorNonce() (uint64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) MintAsset(templateID, ownerPubKey string, properties map[string]any) (*SendTxResult, error) {
|
func (s *Service) MintAsset(templateID, ownerPubKey string, properties map[string]any) (*SendTxResult, error) {
|
||||||
|
s.operatorMu.Lock()
|
||||||
|
defer s.operatorMu.Unlock()
|
||||||
if err := s.ensureOperator(); err != nil {
|
if err := s.ensureOperator(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -355,6 +387,8 @@ func (s *Service) MintAsset(templateID, ownerPubKey string, properties map[strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) GrantReward(recipientPubKey string, tokenAmount uint64, assets []core.MintAssetPayload) (*SendTxResult, error) {
|
func (s *Service) GrantReward(recipientPubKey string, tokenAmount uint64, assets []core.MintAssetPayload) (*SendTxResult, error) {
|
||||||
|
s.operatorMu.Lock()
|
||||||
|
defer s.operatorMu.Unlock()
|
||||||
if err := s.ensureOperator(); err != nil {
|
if err := s.ensureOperator(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -370,6 +404,8 @@ func (s *Service) GrantReward(recipientPubKey string, tokenAmount uint64, assets
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) RegisterTemplate(id, name string, schema map[string]any, tradeable bool) (*SendTxResult, error) {
|
func (s *Service) RegisterTemplate(id, name string, schema map[string]any, tradeable bool) (*SendTxResult, error) {
|
||||||
|
s.operatorMu.Lock()
|
||||||
|
defer s.operatorMu.Unlock()
|
||||||
if err := s.ensureOperator(); err != nil {
|
if err := s.ensureOperator(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,12 @@ type Info struct {
|
|||||||
URL string `json:"url" gorm:"not null"`
|
URL string `json:"url" gorm:"not null"`
|
||||||
Version string `json:"version" gorm:"not null"`
|
Version string `json:"version" gorm:"not null"`
|
||||||
FileName string `json:"fileName" gorm:"not null"`
|
FileName string `json:"fileName" gorm:"not null"`
|
||||||
|
// FileSize is a human-readable string (e.g., "1.5 GB") for display purposes.
|
||||||
|
// Programmatic size tracking uses os.Stat on the actual file.
|
||||||
FileSize string `json:"fileSize" gorm:"not null"`
|
FileSize string `json:"fileSize" gorm:"not null"`
|
||||||
FileHash string `json:"fileHash" gorm:"not null;default:''"`
|
FileHash string `json:"fileHash" gorm:"not null;default:''"`
|
||||||
LauncherURL string `json:"launcherUrl" gorm:"not null;default:''"`
|
LauncherURL string `json:"launcherUrl" gorm:"not null;default:''"`
|
||||||
|
// LauncherSize is a human-readable string (e.g., "25.3 MB") for display purposes.
|
||||||
|
// Programmatic size tracking uses os.Stat on the actual file.
|
||||||
LauncherSize string `json:"launcherSize" gorm:"not null;default:''"`
|
LauncherSize string `json:"launcherSize" gorm:"not null;default:''"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var versionRe = regexp.MustCompile(`v\d+[\.\d]*`)
|
var versionRe = regexp.MustCompile(`v\d+\.\d+(\.\d+)?`)
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
repo *Repository
|
repo *Repository
|
||||||
@@ -48,6 +48,8 @@ func (s *Service) UploadLauncher(body io.Reader, baseURL string) (*Info, error)
|
|||||||
return nil, fmt.Errorf("파일 생성 실패: %w", err)
|
return nil, fmt.Errorf("파일 생성 실패: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: Partial uploads (client closes cleanly mid-transfer) are saved.
|
||||||
|
// The hashGameExeFromZip check mitigates this for game uploads but not for launcher uploads.
|
||||||
n, err := io.Copy(f, body)
|
n, err := io.Copy(f, body)
|
||||||
if closeErr := f.Close(); closeErr != nil && err == nil {
|
if closeErr := f.Close(); closeErr != nil && err == nil {
|
||||||
err = closeErr
|
err = closeErr
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package player
|
package player
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,9 +45,23 @@ func (h *Handler) UpdateProfile(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "잘못된 요청입니다"})
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "잘못된 요청입니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.Nickname = strings.TrimSpace(req.Nickname)
|
||||||
|
if req.Nickname != "" {
|
||||||
|
nicknameRunes := []rune(req.Nickname)
|
||||||
|
if len(nicknameRunes) < 2 || len(nicknameRunes) > 30 {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "닉네임은 2~30자여야 합니다"})
|
||||||
|
}
|
||||||
|
for _, r := range nicknameRunes {
|
||||||
|
if unicode.IsControl(r) {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "닉네임에 허용되지 않는 문자가 포함되어 있습니다"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
profile, err := h.svc.UpdateProfile(userID, req.Nickname)
|
profile, err := h.svc.UpdateProfile(userID, req.Nickname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
log.Printf("프로필 수정 실패 (userID=%d): %v", userID, err)
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "서버 오류가 발생했습니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(profile)
|
return c.JSON(profile)
|
||||||
@@ -77,7 +95,8 @@ func (h *Handler) InternalSaveGameData(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.SaveGameDataByUsername(username, &req); err != nil {
|
if err := h.svc.SaveGameDataByUsername(username, &req); err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": err.Error()})
|
log.Printf("게임 데이터 저장 실패 (username=%s): %v", username, err)
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "서버 오류가 발생했습니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(fiber.Map{"message": "게임 데이터가 저장되었습니다"})
|
return c.JSON(fiber.Map{"message": "게임 데이터가 저장되었습니다"})
|
||||||
|
|||||||
@@ -6,6 +6,32 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// validateGameData checks that game data fields are within acceptable ranges.
|
||||||
|
func validateGameData(data *GameDataRequest) error {
|
||||||
|
if data.Level != nil && (*data.Level < 1 || *data.Level > 999) {
|
||||||
|
return fmt.Errorf("레벨은 1~999 범위여야 합니다")
|
||||||
|
}
|
||||||
|
if data.Experience != nil && *data.Experience < 0 {
|
||||||
|
return fmt.Errorf("경험치는 0 이상이어야 합니다")
|
||||||
|
}
|
||||||
|
if data.MaxHP != nil && (*data.MaxHP < 1 || *data.MaxHP > 999999) {
|
||||||
|
return fmt.Errorf("최대 HP는 1~999999 범위여야 합니다")
|
||||||
|
}
|
||||||
|
if data.MaxMP != nil && (*data.MaxMP < 1 || *data.MaxMP > 999999) {
|
||||||
|
return fmt.Errorf("최대 MP는 1~999999 범위여야 합니다")
|
||||||
|
}
|
||||||
|
if data.AttackPower != nil && (*data.AttackPower < 0 || *data.AttackPower > 999999) {
|
||||||
|
return fmt.Errorf("공격력은 0~999999 범위여야 합니다")
|
||||||
|
}
|
||||||
|
if data.AttackRange != nil && (*data.AttackRange < 0 || *data.AttackRange > 100) {
|
||||||
|
return fmt.Errorf("attack_range must be 0-100")
|
||||||
|
}
|
||||||
|
if data.PlayTimeDelta != nil && *data.PlayTimeDelta < 0 {
|
||||||
|
return fmt.Errorf("플레이 시간 변화량은 0 이상이어야 합니다")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
repo *Repository
|
repo *Repository
|
||||||
userResolver func(username string) (uint, error)
|
userResolver func(username string) (uint, error)
|
||||||
@@ -68,6 +94,10 @@ func (s *Service) UpdateProfile(userID uint, nickname string) (*PlayerProfile, e
|
|||||||
|
|
||||||
// SaveGameData 게임 서버에서 호출: 게임 데이터를 저장한다.
|
// SaveGameData 게임 서버에서 호출: 게임 데이터를 저장한다.
|
||||||
func (s *Service) SaveGameData(userID uint, data *GameDataRequest) error {
|
func (s *Service) SaveGameData(userID uint, data *GameDataRequest) error {
|
||||||
|
if err := validateGameData(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
updates := map[string]interface{}{}
|
updates := map[string]interface{}{}
|
||||||
|
|
||||||
if data.Level != nil {
|
if data.Level != nil {
|
||||||
@@ -124,6 +154,7 @@ func (s *Service) SaveGameDataByUsername(username string, data *GameDataRequest)
|
|||||||
if s.userResolver == nil {
|
if s.userResolver == nil {
|
||||||
return fmt.Errorf("userResolver가 설정되지 않았습니다")
|
return fmt.Errorf("userResolver가 설정되지 않았습니다")
|
||||||
}
|
}
|
||||||
|
// Note: validateGameData is called inside SaveGameData, no need to call it here.
|
||||||
userID, err := s.userResolver(username)
|
userID, err := s.userResolver(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("존재하지 않는 유저입니다")
|
return fmt.Errorf("존재하지 않는 유저입니다")
|
||||||
|
|||||||
7
main.go
7
main.go
@@ -131,9 +131,10 @@ func main() {
|
|||||||
}))
|
}))
|
||||||
app.Use(middleware.SecurityHeaders)
|
app.Use(middleware.SecurityHeaders)
|
||||||
app.Use(cors.New(cors.Config{
|
app.Use(cors.New(cors.Config{
|
||||||
AllowOrigins: "https://a301.tolelom.xyz",
|
AllowOrigins: "https://a301.tolelom.xyz",
|
||||||
AllowHeaders: "Origin, Content-Type, Authorization, Idempotency-Key, X-API-Key",
|
AllowHeaders: "Origin, Content-Type, Authorization, Idempotency-Key, X-API-Key",
|
||||||
AllowMethods: "GET, POST, PUT, PATCH, DELETE",
|
AllowMethods: "GET, POST, PUT, PATCH, DELETE",
|
||||||
|
AllowCredentials: true,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Rate limiting: 인증 관련 엔드포인트 (로그인/회원가입/리프레시)
|
// Rate limiting: 인증 관련 엔드포인트 (로그인/회원가입/리프레시)
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ func WarnInsecureDefaults() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getEnv returns the environment variable value, or fallback if unset or empty.
|
||||||
|
// Note: explicitly setting a variable to "" is treated as unset.
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Consider injecting DB as a dependency instead of using a package-level global
|
||||||
|
// to improve testability. Currently, middleware directly accesses this global.
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
|
|
||||||
func ConnectMySQL() error {
|
func ConnectMySQL() error {
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: Consider injecting RDB as a dependency instead of using a package-level global
|
||||||
|
// to improve testability. Currently, middleware directly accesses this global.
|
||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
|
|
||||||
func ConnectRedis() error {
|
func ConnectRedis() error {
|
||||||
|
|||||||
28
pkg/middleware/bodylimit.go
Normal file
28
pkg/middleware/bodylimit.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BodyLimit rejects requests whose Content-Length header exceeds maxBytes.
|
||||||
|
// NOTE: Only checks Content-Length header. Chunked requests without Content-Length
|
||||||
|
// bypass this check. Fiber's global BodyLimit provides the final safety net.
|
||||||
|
// Paths matching any of the excludePrefixes are skipped (e.g. upload endpoints
|
||||||
|
// that legitimately need the global 4GB limit).
|
||||||
|
func BodyLimit(maxBytes int, excludePrefixes ...string) fiber.Handler {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
for _, prefix := range excludePrefixes {
|
||||||
|
if strings.HasPrefix(c.Path(), prefix) {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Request().Header.ContentLength() > maxBytes {
|
||||||
|
return c.Status(fiber.StatusRequestEntityTooLarge).JSON(fiber.Map{
|
||||||
|
"error": "요청이 너무 큽니다",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,17 @@ type cachedResponse struct {
|
|||||||
Body json.RawMessage `json:"b"`
|
Body json.RawMessage `json:"b"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdempotencyRequired rejects requests without an Idempotency-Key header,
|
||||||
|
// then delegates to Idempotency for cache/replay logic.
|
||||||
|
func IdempotencyRequired(c *fiber.Ctx) error {
|
||||||
|
if c.Get("Idempotency-Key") == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||||
|
"error": "Idempotency-Key 헤더가 필요합니다",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return Idempotency(c)
|
||||||
|
}
|
||||||
|
|
||||||
// Idempotency checks the Idempotency-Key header to prevent duplicate transactions.
|
// Idempotency checks the Idempotency-Key header to prevent duplicate transactions.
|
||||||
// If the same key is seen again within the TTL, the cached response is returned.
|
// If the same key is seen again within the TTL, the cached response is returned.
|
||||||
func Idempotency(c *fiber.Ctx) error {
|
func Idempotency(c *fiber.Ctx) error {
|
||||||
@@ -40,23 +51,45 @@ func Idempotency(c *fiber.Ctx) error {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), redisTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), redisTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Check if this key was already processed
|
// Atomically claim the key using SET NX (only succeeds if key doesn't exist)
|
||||||
cached, err := database.RDB.Get(ctx, redisKey).Bytes()
|
set, err := database.RDB.SetNX(ctx, redisKey, "processing", idempotencyTTL).Result()
|
||||||
if err == nil && len(cached) > 0 {
|
if err != nil {
|
||||||
|
// Redis error — let the request through rather than blocking
|
||||||
|
log.Printf("WARNING: idempotency SetNX failed (key=%s): %v", key, err)
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !set {
|
||||||
|
// Key already exists — either processing or completed
|
||||||
|
getCtx, getCancel := context.WithTimeout(context.Background(), redisTimeout)
|
||||||
|
defer getCancel()
|
||||||
|
|
||||||
|
cached, err := database.RDB.Get(getCtx, redisKey).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "요청이 처리 중입니다"})
|
||||||
|
}
|
||||||
|
if string(cached) == "processing" {
|
||||||
|
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "요청이 처리 중입니다"})
|
||||||
|
}
|
||||||
var cr cachedResponse
|
var cr cachedResponse
|
||||||
if json.Unmarshal(cached, &cr) == nil {
|
if json.Unmarshal(cached, &cr) == nil {
|
||||||
c.Set("Content-Type", "application/json")
|
c.Set("Content-Type", "application/json")
|
||||||
c.Set("X-Idempotent-Replay", "true")
|
c.Set("X-Idempotent-Replay", "true")
|
||||||
return c.Status(cr.StatusCode).Send(cr.Body)
|
return c.Status(cr.StatusCode).Send(cr.Body)
|
||||||
}
|
}
|
||||||
|
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "요청이 처리 중입니다"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the request
|
// We claimed the key — process the request
|
||||||
if err := c.Next(); err != nil {
|
if err := c.Next(); err != nil {
|
||||||
|
// Processing failed — remove the key so it can be retried
|
||||||
|
delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout)
|
||||||
|
defer delCancel()
|
||||||
|
database.RDB.Del(delCtx, redisKey)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache successful responses (2xx)
|
// Cache successful responses (2xx), otherwise remove the key for retry
|
||||||
status := c.Response().StatusCode()
|
status := c.Response().StatusCode()
|
||||||
if status >= 200 && status < 300 {
|
if status >= 200 && status < 300 {
|
||||||
cr := cachedResponse{StatusCode: status, Body: c.Response().Body()}
|
cr := cachedResponse{StatusCode: status, Body: c.Response().Body()}
|
||||||
@@ -67,6 +100,11 @@ func Idempotency(c *fiber.Ctx) error {
|
|||||||
log.Printf("WARNING: idempotency cache write failed (key=%s): %v", key, err)
|
log.Printf("WARNING: idempotency cache write failed (key=%s): %v", key, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Non-success — allow retry by removing the key
|
||||||
|
delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout)
|
||||||
|
defer delCancel()
|
||||||
|
database.RDB.Del(delCtx, redisKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -11,6 +13,17 @@ func RequestID(c *fiber.Ctx) error {
|
|||||||
if id == "" {
|
if id == "" {
|
||||||
id = uuid.NewString()
|
id = uuid.NewString()
|
||||||
}
|
}
|
||||||
|
// Truncate client-provided request IDs to prevent abuse
|
||||||
|
if len(id) > 64 {
|
||||||
|
id = id[:64]
|
||||||
|
}
|
||||||
|
// Strip non-printable characters to prevent log injection
|
||||||
|
id = strings.Map(func(r rune) rune {
|
||||||
|
if r < 32 || r == 127 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}, id)
|
||||||
c.Locals("requestID", id)
|
c.Locals("requestID", id)
|
||||||
c.Set("X-Request-ID", id)
|
c.Set("X-Request-ID", id)
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
|||||||
@@ -29,7 +29,9 @@ func Register(
|
|||||||
app.Get("/health", healthCheck)
|
app.Get("/health", healthCheck)
|
||||||
app.Get("/ready", readyCheck)
|
app.Get("/ready", readyCheck)
|
||||||
|
|
||||||
api := app.Group("/api", apiLimiter)
|
// Default 1MB body limit for API routes; upload endpoints are excluded
|
||||||
|
apiBodyLimit := middleware.BodyLimit(1*1024*1024, "/api/download/upload")
|
||||||
|
api := app.Group("/api", apiLimiter, apiBodyLimit)
|
||||||
|
|
||||||
// Auth
|
// Auth
|
||||||
a := api.Group("/auth")
|
a := api.Group("/auth")
|
||||||
@@ -37,9 +39,11 @@ func Register(
|
|||||||
a.Post("/login", authLimiter, authH.Login)
|
a.Post("/login", authLimiter, authH.Login)
|
||||||
a.Post("/refresh", authLimiter, authH.Refresh)
|
a.Post("/refresh", authLimiter, authH.Refresh)
|
||||||
a.Post("/logout", middleware.Auth, authH.Logout)
|
a.Post("/logout", middleware.Auth, authH.Logout)
|
||||||
a.Post("/verify", authLimiter, authH.VerifyToken)
|
// /verify moved to internal API (ServerAuth) — see internal section below
|
||||||
a.Get("/ssafy/login", authH.SSAFYLoginURL)
|
a.Get("/ssafy/login", authH.SSAFYLoginURL)
|
||||||
a.Post("/ssafy/callback", authLimiter, authH.SSAFYCallback)
|
a.Post("/ssafy/callback", authLimiter, authH.SSAFYCallback)
|
||||||
|
a.Post("/launch-ticket", middleware.Auth, authH.CreateLaunchTicket)
|
||||||
|
a.Post("/redeem-ticket", authLimiter, authH.RedeemLaunchTicket)
|
||||||
|
|
||||||
// Users (admin only)
|
// Users (admin only)
|
||||||
u := api.Group("/users", middleware.Auth, middleware.AdminOnly)
|
u := api.Group("/users", middleware.Auth, middleware.AdminOnly)
|
||||||
@@ -73,19 +77,19 @@ func Register(
|
|||||||
ch.Get("/market/:id", chainH.GetMarketListing)
|
ch.Get("/market/:id", chainH.GetMarketListing)
|
||||||
|
|
||||||
// Chain - User Transactions (authenticated, per-user rate limited, idempotency-protected)
|
// Chain - User Transactions (authenticated, per-user rate limited, idempotency-protected)
|
||||||
ch.Post("/transfer", chainUserLimiter, middleware.Idempotency, chainH.Transfer)
|
ch.Post("/transfer", chainUserLimiter, middleware.IdempotencyRequired, chainH.Transfer)
|
||||||
ch.Post("/asset/transfer", chainUserLimiter, middleware.Idempotency, chainH.TransferAsset)
|
ch.Post("/asset/transfer", chainUserLimiter, middleware.IdempotencyRequired, chainH.TransferAsset)
|
||||||
ch.Post("/market/list", chainUserLimiter, middleware.Idempotency, chainH.ListOnMarket)
|
ch.Post("/market/list", chainUserLimiter, middleware.IdempotencyRequired, chainH.ListOnMarket)
|
||||||
ch.Post("/market/buy", chainUserLimiter, middleware.Idempotency, chainH.BuyFromMarket)
|
ch.Post("/market/buy", chainUserLimiter, middleware.IdempotencyRequired, chainH.BuyFromMarket)
|
||||||
ch.Post("/market/cancel", chainUserLimiter, middleware.Idempotency, chainH.CancelListing)
|
ch.Post("/market/cancel", chainUserLimiter, middleware.IdempotencyRequired, chainH.CancelListing)
|
||||||
ch.Post("/inventory/equip", chainUserLimiter, middleware.Idempotency, chainH.EquipItem)
|
ch.Post("/inventory/equip", chainUserLimiter, middleware.IdempotencyRequired, chainH.EquipItem)
|
||||||
ch.Post("/inventory/unequip", chainUserLimiter, middleware.Idempotency, chainH.UnequipItem)
|
ch.Post("/inventory/unequip", chainUserLimiter, middleware.IdempotencyRequired, chainH.UnequipItem)
|
||||||
|
|
||||||
// Chain - Admin Transactions (admin only, idempotency-protected)
|
// Chain - Admin Transactions (admin only, idempotency-protected)
|
||||||
chainAdmin := api.Group("/chain/admin", middleware.Auth, middleware.AdminOnly)
|
chainAdmin := api.Group("/chain/admin", middleware.Auth, middleware.AdminOnly)
|
||||||
chainAdmin.Post("/mint", middleware.Idempotency, chainH.MintAsset)
|
chainAdmin.Post("/mint", middleware.IdempotencyRequired, chainH.MintAsset)
|
||||||
chainAdmin.Post("/reward", middleware.Idempotency, chainH.GrantReward)
|
chainAdmin.Post("/reward", middleware.IdempotencyRequired, chainH.GrantReward)
|
||||||
chainAdmin.Post("/template", middleware.Idempotency, chainH.RegisterTemplate)
|
chainAdmin.Post("/template", middleware.IdempotencyRequired, chainH.RegisterTemplate)
|
||||||
|
|
||||||
// Boss Raid - Client entry (JWT authenticated)
|
// Boss Raid - Client entry (JWT authenticated)
|
||||||
bossRaid := api.Group("/bossraid", middleware.Auth)
|
bossRaid := api.Group("/bossraid", middleware.Auth)
|
||||||
@@ -96,7 +100,7 @@ func Register(
|
|||||||
br := api.Group("/internal/bossraid", middleware.ServerAuth)
|
br := api.Group("/internal/bossraid", middleware.ServerAuth)
|
||||||
br.Post("/entry", brH.RequestEntry)
|
br.Post("/entry", brH.RequestEntry)
|
||||||
br.Post("/start", brH.StartRaid)
|
br.Post("/start", brH.StartRaid)
|
||||||
br.Post("/complete", middleware.Idempotency, brH.CompleteRaid)
|
br.Post("/complete", middleware.IdempotencyRequired, brH.CompleteRaid)
|
||||||
br.Post("/fail", brH.FailRaid)
|
br.Post("/fail", brH.FailRaid)
|
||||||
br.Get("/room", brH.GetRoom)
|
br.Get("/room", brH.GetRoom)
|
||||||
br.Post("/validate-entry", brH.ValidateEntryToken)
|
br.Post("/validate-entry", brH.ValidateEntryToken)
|
||||||
@@ -106,6 +110,10 @@ func Register(
|
|||||||
p.Get("/profile", playerH.GetProfile)
|
p.Get("/profile", playerH.GetProfile)
|
||||||
p.Put("/profile", playerH.UpdateProfile)
|
p.Put("/profile", playerH.UpdateProfile)
|
||||||
|
|
||||||
|
// Internal - Auth (API key auth)
|
||||||
|
internalAuth := api.Group("/internal/auth", middleware.ServerAuth)
|
||||||
|
internalAuth.Post("/verify", authH.VerifyToken)
|
||||||
|
|
||||||
// Internal - Player (API key auth)
|
// Internal - Player (API key auth)
|
||||||
internalPlayer := api.Group("/internal/player", middleware.ServerAuth)
|
internalPlayer := api.Group("/internal/player", middleware.ServerAuth)
|
||||||
internalPlayer.Get("/profile", playerH.InternalGetProfile)
|
internalPlayer.Get("/profile", playerH.InternalGetProfile)
|
||||||
@@ -113,8 +121,8 @@ func Register(
|
|||||||
|
|
||||||
// Internal - Game server endpoints (API key auth, username-based, idempotency-protected)
|
// Internal - Game server endpoints (API key auth, username-based, idempotency-protected)
|
||||||
internal := api.Group("/internal/chain", middleware.ServerAuth)
|
internal := api.Group("/internal/chain", middleware.ServerAuth)
|
||||||
internal.Post("/reward", middleware.Idempotency, chainH.InternalGrantReward)
|
internal.Post("/reward", middleware.IdempotencyRequired, chainH.InternalGrantReward)
|
||||||
internal.Post("/mint", middleware.Idempotency, chainH.InternalMintAsset)
|
internal.Post("/mint", middleware.IdempotencyRequired, chainH.InternalMintAsset)
|
||||||
internal.Get("/balance", chainH.InternalGetBalance)
|
internal.Get("/balance", chainH.InternalGetBalance)
|
||||||
internal.Get("/assets", chainH.InternalGetAssets)
|
internal.Get("/assets", chainH.InternalGetAssets)
|
||||||
internal.Get("/inventory", chainH.InternalGetInventory)
|
internal.Get("/inventory", chainH.InternalGetInventory)
|
||||||
|
|||||||
Reference in New Issue
Block a user