package handler

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gofrs/uuid/v5"
	"github.com/golang-jwt/jwt/v5"
	"github.com/saveinmed/backend-go/internal/http/middleware"
)

// TestWriteJSONAndError checks that writeJSON writes the given status code,
// an "application/json" Content-Type and the JSON-encoded payload, and that
// writeError writes the status with the error message under the "error" key.
func TestWriteJSONAndError(t *testing.T) {
	rec := httptest.NewRecorder()
	writeJSON(rec, http.StatusCreated, map[string]string{"ok": "true"})

	if rec.Code != http.StatusCreated {
		t.Fatalf("expected status 201, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Fatalf("expected Content-Type application/json, got %q", ct)
	}
	var payload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	if payload["ok"] != "true" {
		t.Fatalf("expected payload ok=true, got %v", payload)
	}

	rec = httptest.NewRecorder()
	writeError(rec, http.StatusBadRequest, errors.New("boom"))

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", rec.Code)
	}
	var errPayload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &errPayload); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if errPayload["error"] != "boom" {
		t.Fatalf("expected error boom, got %v", errPayload)
	}
}

// TestDecodeJSONUnknownField checks that decodeJSON rejects a body containing
// a field the destination struct does not declare.
func TestDecodeJSONUnknownField(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"unknown":1}`))
	var payload struct {
		Known string `json:"known"`
	}
	if err := decodeJSON(context.Background(), req, &payload); err == nil {
		t.Fatal("expected error for unknown field")
	}
}

// TestDecodeJSONCanceledContext checks that decodeJSON reports an error
// matching context.Canceled when given an already-canceled context, even if
// the body itself is valid.
func TestDecodeJSONCanceledContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"known":"ok"}`))
	var payload struct {
		Known string `json:"known"`
	}
	err := decodeJSON(ctx, req, &payload)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context canceled, got %v", err)
	}
}

// TestParseUUIDFromPath checks that parseUUIDFromPath extracts a UUID segment
// from a URL path and fails when no UUID segment is present.
func TestParseUUIDFromPath(t *testing.T) {
	id := uuid.Must(uuid.NewV7())

	got, err := parseUUIDFromPath("/api/v1/companies/" + id.String() + "/rating")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != id {
		t.Fatalf("expected %s, got %s", id, got)
	}

	if _, err := parseUUIDFromPath("/api/v1/companies/rating"); err == nil {
		t.Fatal("expected error for missing UUID")
	}
}

// TestSplitPath checks that splitPath drops the leading slash and returns the
// path segments in order.
func TestSplitPath(t *testing.T) {
	parts := splitPath("/api/v1/companies/123")
	if len(parts) != 4 {
		t.Fatalf("expected 4 parts, got %d", len(parts))
	}
	if parts[0] != "api" || parts[3] != "123" {
		t.Fatalf("unexpected parts: %v", parts)
	}
}

// TestIsValidStatus checks one accepted status ("pending") and one rejected
// spelling ("cancelled").
func TestIsValidStatus(t *testing.T) {
	if !isValidStatus("pending") {
		t.Fatal("expected pending to be valid")
	}
	if isValidStatus("cancelled") {
		t.Fatal("expected cancelled to be invalid")
	}
}

// TestParsePagination checks that explicit page/page_size query values are
// honored and that out-of-range values fall back to page 1, size 20.
func TestParsePagination(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?page=2&page_size=30", nil)
	page, size := parsePagination(req)
	if page != 2 || size != 30 {
		t.Fatalf("expected page 2 size 30, got %d %d", page, size)
	}

	req = httptest.NewRequest(http.MethodGet, "/?page=-1&page_size=0", nil)
	page, size = parsePagination(req)
	if page != 1 || size != 20 {
		t.Fatalf("expected defaults, got %d %d", page, size)
	}
}

// TestGetRequesterFromHeaders checks getRequester when no auth claims are in
// the request context: with no headers it yields role "Admin" and no company,
// the X-User-Role / X-Company-ID headers are honored when present, and a
// malformed X-Company-ID is rejected.
func TestGetRequesterFromHeaders(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	got, err := getRequester(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got.Role != "Admin" || got.CompanyID != nil {
		t.Fatalf("unexpected requester: %+v", got)
	}

	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("X-User-Role", "Seller")
	companyID := uuid.Must(uuid.NewV7())
	req.Header.Set("X-Company-ID", companyID.String())
	got, err = getRequester(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got.Role != "Seller" || got.CompanyID == nil || *got.CompanyID != companyID {
		t.Fatalf("unexpected requester: %+v", got)
	}

	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("X-Company-ID", "invalid")
	if _, err := getRequester(req); err == nil {
		t.Fatal("expected error for invalid company id header")
	}
}

// TestGetRequesterFromClaims checks that, behind middleware.RequireAuth,
// getRequester takes the role and company from the validated JWT claims.
func TestGetRequesterFromClaims(t *testing.T) {
	secret := []byte("secret")
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":        userID.String(),
		"role":       "Owner",
		"company_id": companyID.String(),
		"exp":        time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString(secret)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}

	var (
		got    requester
		called bool
	)
	handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		var reqErr error
		got, reqErr = getRequester(r)
		if reqErr != nil {
			// ServeHTTP is invoked synchronously below, so this runs on the
			// test goroutine and t.Fatalf is safe here.
			t.Fatalf("unexpected error: %v", reqErr)
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Without this, a rejection by the middleware would surface only as a
	// confusing zero-value requester mismatch.
	if !called {
		t.Fatalf("handler not reached; middleware responded with status %d", rec.Code)
	}
	if got.Role != "Owner" || got.CompanyID == nil || *got.CompanyID != companyID {
		t.Fatalf("unexpected requester: %+v", got)
	}
}

// TestParseBearerToken checks that parseBearerToken rejects a missing header,
// a non-Bearer scheme and an empty token, and returns the raw token otherwise.
func TestParseBearerToken(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	if _, err := parseBearerToken(req); err == nil {
		t.Fatal("expected error for missing header")
	}

	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Token abc")
	if _, err := parseBearerToken(req); err == nil {
		t.Fatal("expected error for invalid header")
	}

	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer ")
	if _, err := parseBearerToken(req); err == nil {
		t.Fatal("expected error for empty token")
	}

	req = httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer token123")
	if token, err := parseBearerToken(req); err != nil || token != "token123" {
		t.Fatalf("unexpected token result: %q %v", token, err)
	}
}

// TestGetUserFromContext checks that, behind middleware.RequireAuth,
// Handler.getUserFromContext returns the user ID, role and company ID carried
// in the JWT claims.
func TestGetUserFromContext(t *testing.T) {
	secret := []byte("secret")
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":        userID.String(),
		"role":       "Admin",
		"company_id": companyID.String(),
		"exp":        time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString(secret)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}

	called := false
	handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		h := &Handler{}
		user, userErr := h.getUserFromContext(r.Context())
		if userErr != nil {
			t.Fatalf("unexpected error: %v", userErr)
		}
		if user.ID != userID || user.Role != "Admin" || user.CompanyID != companyID {
			t.Fatalf("unexpected user: %+v", user)
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// All user assertions live inside the handler; if the middleware rejected
	// the request they would never run and the test would pass vacuously.
	if !called {
		t.Fatalf("handler not reached; middleware responded with status %d", rec.Code)
	}
}