package main import ( "archive/zip" "crypto/sha256" "encoding/hex" "fmt" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "testing" ) // ── extractZip tests ───────────────────────────────────────────────────────── // createTestZip creates a zip file at zipPath with the given entries. // Each entry is a path → content pair. Directories have empty content and end with "/". func createTestZip(t *testing.T, zipPath string, entries map[string]string) { t.Helper() f, err := os.Create(zipPath) if err != nil { t.Fatal(err) } w := zip.NewWriter(f) for name, content := range entries { fw, err := w.Create(name) if err != nil { t.Fatal(err) } if content != "" { if _, err := fw.Write([]byte(content)); err != nil { t.Fatal(err) } } } if err := w.Close(); err != nil { t.Fatal(err) } if err := f.Close(); err != nil { t.Fatal(err) } } func TestExtractZip_Normal(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "test.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) // zip 내 최상위 디렉토리 제거 동작 검증 (A301/hello.txt → hello.txt) createTestZip(t, zipPath, map[string]string{ "A301/hello.txt": "world", "A301/sub/nested.txt": "deep", }) if err := extractZip(zipPath, destDir); err != nil { t.Fatal(err) } // hello.txt가 destDir에 직접 존재해야 함 content, err := os.ReadFile(filepath.Join(destDir, "hello.txt")) if err != nil { t.Fatalf("hello.txt 읽기 실패: %v", err) } if string(content) != "world" { t.Errorf("hello.txt 내용 불일치: got %q, want %q", string(content), "world") } content, err = os.ReadFile(filepath.Join(destDir, "sub", "nested.txt")) if err != nil { t.Fatalf("sub/nested.txt 읽기 실패: %v", err) } if string(content) != "deep" { t.Errorf("sub/nested.txt 내용 불일치: got %q, want %q", string(content), "deep") } } func TestExtractZip_FlatZip(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "flat.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) // 디렉토리 없이 최상위에 직접 파일이 있는 zip createTestZip(t, zipPath, map[string]string{ "readme.txt": "flat file", }) if err := extractZip(zipPath, destDir); err != nil { t.Fatal(err) } content, err := os.ReadFile(filepath.Join(destDir, "readme.txt")) if err != nil { t.Fatalf("readme.txt 읽기 실패: %v", err) } if string(content) != "flat file" { t.Errorf("내용 불일치: got %q", string(content)) } } func TestExtractZip_ZipSlip(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "evil.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) // Zip Slip: 경로 탈출 시도 f, err := os.Create(zipPath) if err != nil { t.Fatal(err) } w := zip.NewWriter(f) // A301/../../../etc/passwd → 최상위 제거 후 ../../etc/passwd fw, _ := w.Create("A301/../../../etc/passwd") fw.Write([]byte("evil")) w.Close() f.Close() err = extractZip(zipPath, destDir) if err == nil { t.Fatal("Zip Slip 공격이 차단되지 않음") } } func TestExtractZip_NTFS_ADS(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "ads.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) // NTFS ADS: 콜론 포함 경로 createTestZip(t, zipPath, map[string]string{ "A301/file.txt:hidden": "ads data", }) err := extractZip(zipPath, destDir) if err == nil { t.Fatal("NTFS ADS 공격이 차단되지 않음") } } func TestExtractZip_Empty(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "empty.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) // 빈 zip createTestZip(t, zipPath, map[string]string{}) if err := extractZip(zipPath, destDir); err != nil { t.Fatalf("빈 zip 처리 실패: %v", err) } // destDir에 아무것도 없어야 함 entries, _ := os.ReadDir(destDir) if len(entries) != 0 { t.Errorf("빈 zip인데 파일이 추출됨: %d개", len(entries)) } } func TestExtractZip_NestedDirs(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "nested.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) createTestZip(t, zipPath, map[string]string{ "root/a/b/c/deep.txt": "deep content", "root/a/b/mid.txt": "mid content", }) if err := extractZip(zipPath, destDir); err != nil { t.Fatal(err) } content, err := os.ReadFile(filepath.Join(destDir, "a", "b", "c", "deep.txt")) if err != nil { t.Fatal(err) } if string(content) != "deep content" { t.Errorf("deep.txt 내용 불일치: got %q", string(content)) } } func TestExtractZip_AbsolutePath(t *testing.T) { tmpDir := t.TempDir() zipPath := filepath.Join(tmpDir, "abs.zip") destDir := filepath.Join(tmpDir, "out") os.MkdirAll(destDir, 0755) f, err := os.Create(zipPath) if err != nil { t.Fatal(err) } w := zip.NewWriter(f) // 절대 경로 시도 fw, _ := w.Create("A301/C:\\Windows\\evil.txt") fw.Write([]byte("evil")) w.Close() f.Close() err = extractZip(zipPath, destDir) // Windows에서 C: 포함은 ADS로도 잡히지만, 절대 경로로도 잡혀야 함 if err == nil { t.Fatal("절대 경로 공격이 차단되지 않음") } } // ── hashFile tests ─────────────────────────────────────────────────────────── func TestHashFile_Normal(t *testing.T) { tmpDir := t.TempDir() path := filepath.Join(tmpDir, "test.bin") content := []byte("hello world") os.WriteFile(path, content, 0644) got, err := hashFile(path) if err != nil { t.Fatal(err) } h := sha256.Sum256(content) want := hex.EncodeToString(h[:]) if got != want { t.Errorf("해시 불일치: got %s, want %s", got, want) } } func TestHashFile_Empty(t *testing.T) { tmpDir := t.TempDir() path := filepath.Join(tmpDir, "empty.bin") os.WriteFile(path, []byte{}, 0644) got, err := hashFile(path) if err != nil { t.Fatal(err) } h := sha256.Sum256([]byte{}) want := hex.EncodeToString(h[:]) if got != want { t.Errorf("빈 파일 해시 불일치: got %s, want %s", got, want) } } func TestHashFile_NotExist(t *testing.T) { _, err := hashFile("/nonexistent/path/to/file") if err == nil { t.Fatal("존재하지 않는 파일에 에러가 발생하지 않음") } } // ── redeemTicket tests (httptest) ──────────────────────────────────────────── func TestRedeemTicket_Success(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("예상 메서드 POST, got %s", r.Method) } if r.Header.Get("Content-Type") != "application/json" { t.Errorf("Content-Type이 application/json이 아님") } w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"token":"eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.abc123"}`) })) defer srv.Close() origURL := redeemTicketURL redeemTicketURL = srv.URL defer func() { redeemTicketURL = origURL }() token, err := redeemTicket("test-ticket") if err != nil { t.Fatal(err) } if token != "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.abc123" { t.Errorf("토큰 불일치: got %s", token) } } func TestRedeemTicket_ServerError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) fmt.Fprint(w, `{"error":"invalid ticket"}`) })) defer srv.Close() origURL := redeemTicketURL redeemTicketURL = srv.URL defer func() { redeemTicketURL = origURL }() _, err := redeemTicket("bad-ticket") if err == nil { t.Fatal("서버 에러 시 에러가 반환되지 않음") } } func TestRedeemTicket_InvalidJSON(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `not json`) })) defer srv.Close() origURL := redeemTicketURL redeemTicketURL = srv.URL defer func() { redeemTicketURL = origURL }() _, err := redeemTicket("ticket") if err == nil { t.Fatal("잘못된 JSON에 에러가 반환되지 않음") } } func TestRedeemTicket_EmptyToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"token":""}`) })) defer srv.Close() origURL := redeemTicketURL redeemTicketURL = srv.URL defer func() { redeemTicketURL = origURL }() _, err := redeemTicket("ticket") if err == nil { t.Fatal("빈 토큰에 에러가 반환되지 않음") } } func TestRedeemTicket_Unreachable(t *testing.T) { origURL := redeemTicketURL redeemTicketURL = "http://127.0.0.1:1" defer func() { redeemTicketURL = origURL }() _, err := redeemTicket("ticket") if err == nil { t.Fatal("연결 불가 시 에러가 반환되지 않음") } } // ── URL parsing tests ──────────────────────────────────────────────────────── func TestParseURI_ValidToken(t *testing.T) { raw := "a301://launch?token=test-ticket-123" parsed, err := url.Parse(raw) if err != nil { t.Fatal(err) } if parsed.Scheme != protocolName { t.Errorf("스킴 불일치: got %s, want %s", parsed.Scheme, protocolName) } token := parsed.Query().Get("token") if token != "test-ticket-123" { t.Errorf("토큰 불일치: got %s", token) } } func TestParseURI_MissingToken(t *testing.T) { raw := "a301://launch" parsed, err := url.Parse(raw) if err != nil { t.Fatal(err) } token := parsed.Query().Get("token") if token != "" { t.Errorf("토큰이 비어있어야 함: got %s", token) } } func TestParseURI_WrongScheme(t *testing.T) { raw := "http://launch?token=xxx" parsed, err := url.Parse(raw) if err != nil { t.Fatal(err) } if parsed.Scheme == protocolName { t.Error("잘못된 스킴이 허용됨") } } func TestParseURI_EncodedToken(t *testing.T) { // URL 인코딩된 토큰 raw := "a301://launch?token=abc%2Bdef%3Dghi" parsed, err := url.Parse(raw) if err != nil { t.Fatal(err) } token := parsed.Query().Get("token") if token != "abc+def=ghi" { t.Errorf("URL 디코딩 불일치: got %s, want abc+def=ghi", token) } } func TestParseURI_MultipleParams(t *testing.T) { raw := "a301://launch?token=myticket&extra=ignored" parsed, err := url.Parse(raw) if err != nil { t.Fatal(err) } token := parsed.Query().Get("token") if token != "myticket" { t.Errorf("토큰 불일치: got %s", token) } }