package middleware import ( "context" "encoding/json" "fmt" "log" "time" "a301_server/pkg/apperror" "github.com/gofiber/fiber/v2" "github.com/redis/go-redis/v9" ) const idempotencyTTL = 10 * time.Minute const redisTimeout = 5 * time.Second type cachedResponse struct { StatusCode int `json:"s"` Body json.RawMessage `json:"b"` } // IdempotencyRequired returns a middleware that rejects requests without an Idempotency-Key header, // then delegates to idempotency cache/replay logic. func IdempotencyRequired(rdb *redis.Client) fiber.Handler { idempotency := Idempotency(rdb) return func(c *fiber.Ctx) error { if c.Get("Idempotency-Key") == "" { return apperror.BadRequest("Idempotency-Key 헤더가 필요합니다") } return idempotency(c) } } // Idempotency returns a middleware that checks the Idempotency-Key header to prevent duplicate transactions. // If the same key is seen again within the TTL, the cached response is returned. func Idempotency(rdb *redis.Client) fiber.Handler { return func(c *fiber.Ctx) error { key := c.Get("Idempotency-Key") if key == "" { return c.Next() } if len(key) > 256 { return apperror.BadRequest("Idempotency-Key가 너무 깁니다") } // userID가 있으면 키에 포함하여 사용자 간 캐시 충돌 방지 redisKey := "idempotency:" if uid, ok := c.Locals("userID").(uint); ok { redisKey += fmt.Sprintf("u%d:", uid) } redisKey += c.Method() + ":" + c.Route().Path + ":" + key ctx, cancel := context.WithTimeout(context.Background(), redisTimeout) defer cancel() // Atomically claim the key using SET NX (only succeeds if key doesn't exist) set, err := rdb.SetNX(ctx, redisKey, "processing", idempotencyTTL).Result() if err != nil { // Redis error — reject to prevent duplicate transactions log.Printf("ERROR: idempotency SetNX failed (key=%s): %v", key, err) return apperror.New("internal_error", "서버 오류가 발생했습니다. 잠시 후 다시 시도해주세요", 503) } if !set { // Key already exists — either processing or completed getCtx, getCancel := context.WithTimeout(context.Background(), redisTimeout) defer getCancel() cached, err := rdb.Get(getCtx, redisKey).Bytes() if err != nil { return apperror.Conflict("요청이 처리 중입니다") } if string(cached) == "processing" { return apperror.Conflict("요청이 처리 중입니다") } var cr cachedResponse if json.Unmarshal(cached, &cr) == nil { c.Set("Content-Type", "application/json") c.Set("X-Idempotent-Replay", "true") return c.Status(cr.StatusCode).Send(cr.Body) } return apperror.Conflict("요청이 처리 중입니다") } // We claimed the key — process the request if err := c.Next(); err != nil { // Processing failed — remove the key so it can be retried delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout) defer delCancel() if delErr := rdb.Del(delCtx, redisKey).Err(); delErr != nil { log.Printf("WARNING: idempotency cache delete failed (key=%s): %v", key, delErr) } return err } // Cache successful responses (2xx), otherwise remove the key for retry status := c.Response().StatusCode() if status >= 200 && status < 300 { cr := cachedResponse{StatusCode: status, Body: c.Response().Body()} if data, err := json.Marshal(cr); err == nil { writeCtx, writeCancel := context.WithTimeout(context.Background(), redisTimeout) defer writeCancel() if err := rdb.Set(writeCtx, redisKey, data, idempotencyTTL).Err(); err != nil { log.Printf("WARNING: idempotency cache write failed (key=%s): %v", key, err) } } } else { // Non-success — allow retry by removing the key delCtx, delCancel := context.WithTimeout(context.Background(), redisTimeout) defer delCancel() if delErr := rdb.Del(delCtx, redisKey).Err(); delErr != nil { log.Printf("WARNING: idempotency cache delete failed (key=%s): %v", key, delErr) } } return nil } }