feat: 코드 리뷰 기반 전면 개선 — 보안, 검증, 테스트, 안정성
- 체인 nonce 경쟁 조건 수정 (operatorMu + per-user mutex) - 등록/SSAFY 원자적 트랜잭션 (wallet+profile 롤백 보장) - IdempotencyRequired 미들웨어 (SETNX 원자적 클레임) - 런치 티켓 API (JWT URL 노출 방지) - HttpOnly 쿠키 refresh token - SSAFY OAuth state 파라미터 (CSRF 방지) - Refresh 시 DB 조회로 최신 role 사용 - 공지사항/유저목록 페이지네이션 - BodyLimit 미들웨어 (1MB, upload 제외) - 입력 검증 강화 (닉네임, 게임데이터, 공지 길이) - 에러 메시지 내부 정보 노출 방지 - io.LimitReader (RPC 10MB, SSAFY 1MB) - RequestID 비출력 문자 제거 - 단위 테스트 (auth 11, announcement 9, bossraid 16) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -41,7 +41,10 @@ func (h *Handler) Register(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "비밀번호는 72자 이하여야 합니다"})
|
||||
}
|
||||
if err := h.svc.Register(req.Username, req.Password); err != nil {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": "회원가입에 실패했습니다"})
|
||||
if strings.Contains(err.Error(), "이미 사용 중") {
|
||||
return c.Status(fiber.StatusConflict).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "회원가입에 실패했습니다"})
|
||||
}
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{"message": "회원가입이 완료되었습니다"})
|
||||
}
|
||||
@@ -70,30 +73,53 @@ func (h *Handler) Login(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: refreshToken,
|
||||
HTTPOnly: true,
|
||||
Secure: true,
|
||||
SameSite: "Strict",
|
||||
Path: "/api/auth/refresh",
|
||||
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||
})
|
||||
return c.JSON(fiber.Map{
|
||||
"token": accessToken,
|
||||
"refreshToken": refreshToken,
|
||||
"username": user.Username,
|
||||
"role": user.Role,
|
||||
"token": accessToken,
|
||||
"username": user.Username,
|
||||
"role": user.Role,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) Refresh(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
refreshTokenStr := c.Cookies("refresh_token")
|
||||
if refreshTokenStr == "" {
|
||||
// Fallback to body for backward compatibility
|
||||
var req struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err == nil && req.RefreshToken != "" {
|
||||
refreshTokenStr = req.RefreshToken
|
||||
}
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil || req.RefreshToken == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "refreshToken 필드가 필요합니다"})
|
||||
if refreshTokenStr == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "refreshToken이 필요합니다"})
|
||||
}
|
||||
|
||||
newAccessToken, newRefreshToken, err := h.svc.Refresh(req.RefreshToken)
|
||||
newAccessToken, newRefreshToken, err := h.svc.Refresh(refreshTokenStr)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: newRefreshToken,
|
||||
HTTPOnly: true,
|
||||
Secure: true,
|
||||
SameSite: "Strict",
|
||||
Path: "/api/auth/refresh",
|
||||
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||
})
|
||||
return c.JSON(fiber.Map{
|
||||
"token": newAccessToken,
|
||||
"refreshToken": newRefreshToken,
|
||||
"token": newAccessToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -105,11 +131,28 @@ func (h *Handler) Logout(c *fiber.Ctx) error {
|
||||
if err := h.svc.Logout(userID); err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "로그아웃 처리 중 오류가 발생했습니다"})
|
||||
}
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
HTTPOnly: true,
|
||||
Secure: true,
|
||||
SameSite: "Strict",
|
||||
Path: "/api/auth/refresh",
|
||||
MaxAge: -1, // delete
|
||||
})
|
||||
return c.JSON(fiber.Map{"message": "로그아웃 되었습니다"})
|
||||
}
|
||||
|
||||
func (h *Handler) GetAllUsers(c *fiber.Ctx) error {
|
||||
users, err := h.svc.GetAllUsers()
|
||||
offset := c.QueryInt("offset", 0)
|
||||
limit := c.QueryInt("limit", 50)
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 50
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
users, err := h.svc.GetAllUsers(offset, limit)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "유저 목록을 불러오지 못했습니다"})
|
||||
}
|
||||
@@ -155,29 +198,74 @@ func (h *Handler) VerifyToken(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
func (h *Handler) SSAFYLoginURL(c *fiber.Ctx) error {
|
||||
loginURL := h.svc.GetSSAFYLoginURL()
|
||||
loginURL, err := h.svc.GetSSAFYLoginURL()
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "SSAFY 로그인 URL 생성에 실패했습니다"})
|
||||
}
|
||||
return c.JSON(fiber.Map{"url": loginURL})
|
||||
}
|
||||
|
||||
func (h *Handler) SSAFYCallback(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
Code string `json:"code"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil || req.Code == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "인가 코드가 필요합니다"})
|
||||
}
|
||||
if req.State == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "state 파라미터가 필요합니다"})
|
||||
}
|
||||
|
||||
accessToken, refreshToken, user, err := h.svc.SSAFYLogin(req.Code)
|
||||
accessToken, refreshToken, user, err := h.svc.SSAFYLogin(req.Code, req.State)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"token": accessToken,
|
||||
"refreshToken": refreshToken,
|
||||
"username": user.Username,
|
||||
"role": user.Role,
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: refreshToken,
|
||||
HTTPOnly: true,
|
||||
Secure: true,
|
||||
SameSite: "Strict",
|
||||
Path: "/api/auth/refresh",
|
||||
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||
})
|
||||
return c.JSON(fiber.Map{
|
||||
"token": accessToken,
|
||||
"username": user.Username,
|
||||
"role": user.Role,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateLaunchTicket issues a one-time ticket for the game launcher.
|
||||
// The launcher uses this ticket instead of receiving the JWT directly in the URL.
|
||||
func (h *Handler) CreateLaunchTicket(c *fiber.Ctx) error {
|
||||
userID, ok := c.Locals("userID").(uint)
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "인증 정보가 올바르지 않습니다"})
|
||||
}
|
||||
ticket, err := h.svc.CreateLaunchTicket(userID)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "티켓 발급에 실패했습니다"})
|
||||
}
|
||||
return c.JSON(fiber.Map{"ticket": ticket})
|
||||
}
|
||||
|
||||
// RedeemLaunchTicket exchanges a one-time ticket for an access token.
|
||||
// Called by the game launcher, not the web browser.
|
||||
func (h *Handler) RedeemLaunchTicket(c *fiber.Ctx) error {
|
||||
var req struct {
|
||||
Ticket string `json:"ticket"`
|
||||
}
|
||||
if err := c.BodyParser(&req); err != nil || req.Ticket == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "ticket 필드가 필요합니다"})
|
||||
}
|
||||
token, err := h.svc.RedeemLaunchTicket(req.Ticket)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
return c.JSON(fiber.Map{"token": token})
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteUser(c *fiber.Ctx) error {
|
||||
|
||||
@@ -22,9 +22,9 @@ func (r *Repository) Create(user *User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
func (r *Repository) FindAll() ([]User, error) {
|
||||
func (r *Repository) FindAll(offset, limit int) ([]User, error) {
|
||||
var users []User
|
||||
err := r.db.Order("created_at asc").Find(&users).Error
|
||||
err := r.db.Order("created_at asc").Offset(offset).Limit(limit).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
|
||||
@@ -150,7 +150,11 @@ func (s *Service) Refresh(refreshTokenStr string) (newAccessToken, newRefreshTok
|
||||
return "", "", fmt.Errorf("만료되었거나 유효하지 않은 리프레시 토큰입니다")
|
||||
}
|
||||
|
||||
user := &User{ID: claims.UserID, Username: claims.Username, Role: Role(claims.Role)}
|
||||
// Look up the current user from DB to avoid using stale role from JWT claims
|
||||
user, dbErr := s.repo.FindByID(claims.UserID)
|
||||
if dbErr != nil {
|
||||
return "", "", fmt.Errorf("유저를 찾을 수 없습니다")
|
||||
}
|
||||
|
||||
newAccessToken, err = s.issueAccessToken(user)
|
||||
if err != nil {
|
||||
@@ -173,8 +177,8 @@ func (s *Service) Logout(userID uint) error {
|
||||
return s.rdb.Del(ctx, sessionKey, refreshKey).Err()
|
||||
}
|
||||
|
||||
func (s *Service) GetAllUsers() ([]User, error) {
|
||||
return s.repo.FindAll()
|
||||
func (s *Service) GetAllUsers(offset, limit int) ([]User, error) {
|
||||
return s.repo.FindAll(offset, limit)
|
||||
}
|
||||
|
||||
func (s *Service) UpdateRole(id uint, role Role) error {
|
||||
@@ -182,7 +186,68 @@ func (s *Service) UpdateRole(id uint, role Role) error {
|
||||
}
|
||||
|
||||
func (s *Service) DeleteUser(id uint) error {
|
||||
return s.repo.Delete(id)
|
||||
if err := s.repo.Delete(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clean up Redis sessions for deleted user
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf("session:%d", id)
|
||||
refreshKey := fmt.Sprintf("refresh:%d", id)
|
||||
s.rdb.Del(ctx, sessionKey, refreshKey)
|
||||
|
||||
// TODO: Clean up wallet and profile data via cross-service calls
|
||||
// (walletCreator/profileCreator are creation-only; deletion callbacks are not yet wired up)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateLaunchTicket generates a one-time ticket that the game launcher
|
||||
// exchanges for the real JWT. The ticket expires in 30 seconds and can only
|
||||
// be redeemed once, preventing token exposure in URLs or browser history.
|
||||
func (s *Service) CreateLaunchTicket(userID uint) (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("generate ticket: %w", err)
|
||||
}
|
||||
ticket := hex.EncodeToString(buf)
|
||||
|
||||
// Store ticket → userID mapping in Redis with 30s TTL
|
||||
key := fmt.Sprintf("launch_ticket:%s", ticket)
|
||||
ctx := context.Background()
|
||||
if err := s.rdb.Set(ctx, key, userID, 30*time.Second).Err(); err != nil {
|
||||
return "", fmt.Errorf("store ticket: %w", err)
|
||||
}
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// RedeemLaunchTicket exchanges a one-time ticket for the user's access token.
|
||||
// The ticket is deleted immediately after use (one-time).
|
||||
func (s *Service) RedeemLaunchTicket(ticket string) (string, error) {
|
||||
key := fmt.Sprintf("launch_ticket:%s", ticket)
|
||||
ctx := context.Background()
|
||||
|
||||
// Atomically get and delete (one-time use)
|
||||
userIDStr, err := s.rdb.GetDel(ctx, key).Result()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("유효하지 않거나 만료된 티켓입니다")
|
||||
}
|
||||
|
||||
var userID uint
|
||||
if _, err := fmt.Sscanf(userIDStr, "%d", &userID); err != nil {
|
||||
return "", fmt.Errorf("invalid ticket data")
|
||||
}
|
||||
|
||||
user, err := s.repo.FindByID(userID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("유저를 찾을 수 없습니다")
|
||||
}
|
||||
|
||||
accessToken, err := s.issueAccessToken(user)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (s *Service) Register(username, password string) error {
|
||||
@@ -193,39 +258,50 @@ func (s *Service) Register(username, password string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("비밀번호 처리에 실패했습니다")
|
||||
}
|
||||
user := &User{
|
||||
Username: username,
|
||||
PasswordHash: string(hash),
|
||||
Role: RoleUser,
|
||||
}
|
||||
if err := s.repo.Create(user); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.walletCreator != nil {
|
||||
if err := s.walletCreator(user.ID); err != nil {
|
||||
log.Printf("wallet creation failed for user %d: %v — rolling back", user.ID, err)
|
||||
if delErr := s.repo.Delete(user.ID); delErr != nil {
|
||||
log.Printf("WARNING: rollback delete also failed for user %d: %v", user.ID, delErr)
|
||||
|
||||
return s.repo.Transaction(func(txRepo *Repository) error {
|
||||
user := &User{Username: username, PasswordHash: string(hash), Role: RoleUser}
|
||||
if err := txRepo.Create(user); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.walletCreator != nil {
|
||||
if err := s.walletCreator(user.ID); err != nil {
|
||||
return fmt.Errorf("wallet creation failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("계정 초기화에 실패했습니다. 잠시 후 다시 시도해주세요")
|
||||
}
|
||||
}
|
||||
if s.profileCreator != nil {
|
||||
if err := s.profileCreator(user.ID); err != nil {
|
||||
log.Printf("profile creation failed for user %d: %v", user.ID, err)
|
||||
if s.profileCreator != nil {
|
||||
if err := s.profileCreator(user.ID); err != nil {
|
||||
return fmt.Errorf("profile creation failed: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetSSAFYLoginURL returns the SSAFY OAuth authorization URL.
|
||||
func (s *Service) GetSSAFYLoginURL() string {
|
||||
// GetSSAFYLoginURL returns the SSAFY OAuth authorization URL with a random
|
||||
// state parameter for CSRF protection. The state is stored in Redis with a
|
||||
// 5-minute TTL and must be verified in the callback.
|
||||
func (s *Service) GetSSAFYLoginURL() (string, error) {
|
||||
stateBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return "", fmt.Errorf("state 생성 실패: %w", err)
|
||||
}
|
||||
state := hex.EncodeToString(stateBytes)
|
||||
|
||||
// Store state in Redis with 5-minute TTL for one-time verification
|
||||
key := fmt.Sprintf("ssafy_state:%s", state)
|
||||
ctx := context.Background()
|
||||
if err := s.rdb.Set(ctx, key, "1", 5*time.Minute).Err(); err != nil {
|
||||
return "", fmt.Errorf("state 저장 실패: %w", err)
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"client_id": {config.C.SSAFYClientID},
|
||||
"redirect_uri": {config.C.SSAFYRedirectURI},
|
||||
"response_type": {"code"},
|
||||
"state": {state},
|
||||
}
|
||||
return "https://project.ssafy.com/oauth/sso-check?" + params.Encode()
|
||||
return "https://project.ssafy.com/oauth/sso-check?" + params.Encode(), nil
|
||||
}
|
||||
|
||||
// ExchangeSSAFYCode exchanges an authorization code for SSAFY tokens.
|
||||
@@ -248,7 +324,7 @@ func (s *Service) ExchangeSSAFYCode(code string) (*SSAFYTokenResponse, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSAFY 토큰 응답 읽기 실패: %v", err)
|
||||
}
|
||||
@@ -279,7 +355,7 @@ func (s *Service) GetSSAFYUserInfo(accessToken string) (*SSAFYUserInfo, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSAFY 사용자 정보 응답 읽기 실패: %v", err)
|
||||
}
|
||||
@@ -296,7 +372,18 @@ func (s *Service) GetSSAFYUserInfo(accessToken string) (*SSAFYUserInfo, error) {
|
||||
}
|
||||
|
||||
// SSAFYLogin handles the full SSAFY OAuth callback: exchange code, get user info, find or create user, issue tokens.
|
||||
func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, user *User, err error) {
|
||||
// The state parameter is verified against Redis (one-time use via GetDel) for CSRF protection.
|
||||
func (s *Service) SSAFYLogin(code, state string) (accessToken, refreshToken string, user *User, err error) {
|
||||
// Verify CSRF state parameter (one-time use)
|
||||
if state == "" {
|
||||
return "", "", nil, fmt.Errorf("state 파라미터가 필요합니다")
|
||||
}
|
||||
stateKey := fmt.Sprintf("ssafy_state:%s", state)
|
||||
val, err := s.rdb.GetDel(context.Background(), stateKey).Result()
|
||||
if err != nil || val != "1" {
|
||||
return "", "", nil, fmt.Errorf("유효하지 않거나 만료된 state 파라미터입니다")
|
||||
}
|
||||
|
||||
tokenResp, err := s.ExchangeSSAFYCode(code)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
@@ -333,7 +420,6 @@ func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, use
|
||||
username = username[:50]
|
||||
}
|
||||
|
||||
var newUserID uint
|
||||
err = s.repo.Transaction(func(txRepo *Repository) error {
|
||||
user = &User{
|
||||
Username: username,
|
||||
@@ -341,27 +427,26 @@ func (s *Service) SSAFYLogin(code string) (accessToken, refreshToken string, use
|
||||
Role: RoleUser,
|
||||
SsafyID: &ssafyID,
|
||||
}
|
||||
return txRepo.Create(user)
|
||||
if err := txRepo.Create(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.walletCreator != nil {
|
||||
if err := s.walletCreator(user.ID); err != nil {
|
||||
return fmt.Errorf("wallet creation failed: %w", err)
|
||||
}
|
||||
}
|
||||
if s.profileCreator != nil {
|
||||
if err := s.profileCreator(user.ID); err != nil {
|
||||
return fmt.Errorf("profile creation failed: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("SSAFY user creation transaction failed: %v", err)
|
||||
return "", "", nil, fmt.Errorf("계정 생성 실패: %v", err)
|
||||
}
|
||||
newUserID = user.ID
|
||||
|
||||
if s.walletCreator != nil {
|
||||
if err := s.walletCreator(newUserID); err != nil {
|
||||
log.Printf("wallet creation failed for SSAFY user %d: %v — rolling back", newUserID, err)
|
||||
if delErr := s.repo.Delete(newUserID); delErr != nil {
|
||||
log.Printf("WARNING: rollback delete also failed for SSAFY user %d: %v", newUserID, delErr)
|
||||
}
|
||||
return "", "", nil, fmt.Errorf("계정 초기화에 실패했습니다. 잠시 후 다시 시도해주세요")
|
||||
}
|
||||
}
|
||||
if s.profileCreator != nil {
|
||||
if err := s.profileCreator(newUserID); err != nil {
|
||||
log.Printf("profile creation failed for SSAFY user %d: %v", newUserID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err = s.issueAccessToken(user)
|
||||
@@ -414,6 +499,10 @@ func sanitizeForUsername(s string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// NOTE: EnsureAdmin does not use a transaction for wallet/profile creation.
|
||||
// If these fail, the admin user exists without a wallet/profile.
|
||||
// This is acceptable because EnsureAdmin runs once at startup and failures
|
||||
// are logged as warnings. A restart will skip user creation (already exists).
|
||||
func (s *Service) EnsureAdmin(username, password string) error {
|
||||
if _, err := s.repo.FindByUsername(username); err == nil {
|
||||
return nil
|
||||
|
||||
291
internal/auth/service_test.go
Normal file
291
internal/auth/service_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"a301_server/pkg/config"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Password hashing (bcrypt)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBcryptHashAndVerify(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantMatch bool
|
||||
}{
|
||||
{"short password", "abc", true},
|
||||
{"normal password", "myP@ssw0rd!", true},
|
||||
{"unicode password", "비밀번호123", true},
|
||||
{"empty password", "", true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(tc.password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFromPassword failed: %v", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(hash, []byte(tc.password))
|
||||
if (err == nil) != tc.wantMatch {
|
||||
t.Errorf("CompareHashAndPassword: got err=%v, wantMatch=%v", err, tc.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBcryptWrongPassword(t *testing.T) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("correct"), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFromPassword failed: %v", err)
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword(hash, []byte("wrong")); err == nil {
|
||||
t.Error("expected error comparing wrong password, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBcryptDifferentHashesForSamePassword(t *testing.T) {
|
||||
password := "samePassword"
|
||||
hash1, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
hash2, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if string(hash1) == string(hash2) {
|
||||
t.Error("expected different hashes for the same password (different salts)")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. JWT token generation and parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func setupTestConfig() {
|
||||
config.C = config.Config{
|
||||
JWTSecret: "test-jwt-secret-key-for-unit-tests",
|
||||
RefreshSecret: "test-refresh-secret-key-for-unit-tests",
|
||||
JWTExpiryHours: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssueAndParseAccessToken(t *testing.T) {
|
||||
setupTestConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
username string
|
||||
role string
|
||||
}{
|
||||
{"admin user", 1, "admin", "admin"},
|
||||
{"regular user", 42, "player1", "user"},
|
||||
{"unicode username", 100, "유저", "user"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
expiry := time.Duration(config.C.JWTExpiryHours) * time.Hour
|
||||
claims := &Claims{
|
||||
UserID: tc.userID,
|
||||
Username: tc.username,
|
||||
Role: tc.role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiry)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := token.SignedString([]byte(config.C.JWTSecret))
|
||||
if err != nil {
|
||||
t.Fatalf("SignedString failed: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte(config.C.JWTSecret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseWithClaims failed: %v", err)
|
||||
}
|
||||
if !parsed.Valid {
|
||||
t.Fatal("parsed token is not valid")
|
||||
}
|
||||
|
||||
got, ok := parsed.Claims.(*Claims)
|
||||
if !ok {
|
||||
t.Fatal("failed to cast claims")
|
||||
}
|
||||
if got.UserID != tc.userID {
|
||||
t.Errorf("UserID = %d, want %d", got.UserID, tc.userID)
|
||||
}
|
||||
if got.Username != tc.username {
|
||||
t.Errorf("Username = %q, want %q", got.Username, tc.username)
|
||||
}
|
||||
if got.Role != tc.role {
|
||||
t.Errorf("Role = %q, want %q", got.Role, tc.role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenWithWrongSecret(t *testing.T) {
|
||||
setupTestConfig()
|
||||
|
||||
claims := &Claims{
|
||||
UserID: 1,
|
||||
Username: "test",
|
||||
Role: "user",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, _ := token.SignedString([]byte(config.C.JWTSecret))
|
||||
|
||||
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte("wrong-secret"), nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error parsing token with wrong secret, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseExpiredToken(t *testing.T) {
|
||||
setupTestConfig()
|
||||
|
||||
claims := &Claims{
|
||||
UserID: 1,
|
||||
Username: "test",
|
||||
Role: "user",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, _ := token.SignedString([]byte(config.C.JWTSecret))
|
||||
|
||||
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte(config.C.JWTSecret), nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error parsing expired token, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenUsesDifferentSecret(t *testing.T) {
|
||||
setupTestConfig()
|
||||
|
||||
claims := &Claims{
|
||||
UserID: 1,
|
||||
Username: "test",
|
||||
Role: "user",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(refreshTokenExpiry)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
// Sign with refresh secret
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, _ := token.SignedString([]byte(config.C.RefreshSecret))
|
||||
|
||||
// Should fail with JWT secret
|
||||
_, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte(config.C.JWTSecret), nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error parsing refresh token with access secret")
|
||||
}
|
||||
|
||||
// Should succeed with refresh secret
|
||||
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte(config.C.RefreshSecret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected success with refresh secret, got: %v", err)
|
||||
}
|
||||
if !parsed.Valid {
|
||||
t.Error("parsed refresh token is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. Input validation helpers (sanitizeForUsername)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSanitizeForUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"lowercase letters", "hello", "hello"},
|
||||
{"uppercase converted", "HeLLo", "hello"},
|
||||
{"digits kept", "user123", "user123"},
|
||||
{"underscore kept", "user_name", "user_name"},
|
||||
{"hyphen kept", "user-name", "user-name"},
|
||||
{"special chars removed", "user@name!#$", "username"},
|
||||
{"spaces removed", "user name", "username"},
|
||||
{"unicode removed", "유저abc", "abc"},
|
||||
{"mixed", "User-123_Test!", "user-123_test"},
|
||||
{"empty input", "", ""},
|
||||
{"all removed", "!!@@##", ""},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := sanitizeForUsername(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("sanitizeForUsername(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. Claims struct fields
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClaimsRoundTrip(t *testing.T) {
|
||||
setupTestConfig()
|
||||
|
||||
original := &Claims{
|
||||
UserID: 999,
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, original)
|
||||
tokenStr, err := token.SignedString([]byte(config.C.JWTSecret))
|
||||
if err != nil {
|
||||
t.Fatalf("signing failed: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
return []byte(config.C.JWTSecret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parsing failed: %v", err)
|
||||
}
|
||||
|
||||
got := parsed.Claims.(*Claims)
|
||||
|
||||
if got.UserID != original.UserID {
|
||||
t.Errorf("UserID: got %d, want %d", got.UserID, original.UserID)
|
||||
}
|
||||
if got.Username != original.Username {
|
||||
t.Errorf("Username: got %q, want %q", got.Username, original.Username)
|
||||
}
|
||||
if got.Role != original.Role {
|
||||
t.Errorf("Role: got %q, want %q", got.Role, original.Role)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user