From 3f69cade21f2c440f93c0ebd92f78e8214554804 Mon Sep 17 00:00:00 2001 From: Tiago Yamamoto Date: Fri, 2 Jan 2026 11:01:56 -0300 Subject: [PATCH] test: expand backend coverage --- backend/internal/http/handler/dto_test.go | 236 ++++++++++++++++++ .../http/middleware/middleware_test.go | 66 +++++ backend/internal/server/server_test.go | 38 +++ backend/internal/usecase/usecase_test.go | 122 +++++++++ 4 files changed, 462 insertions(+) create mode 100644 backend/internal/http/handler/dto_test.go diff --git a/backend/internal/http/handler/dto_test.go b/backend/internal/http/handler/dto_test.go new file mode 100644 index 0000000..c2de9e0 --- /dev/null +++ b/backend/internal/http/handler/dto_test.go @@ -0,0 +1,236 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/golang-jwt/jwt/v5" + "github.com/saveinmed/backend-go/internal/http/middleware" +) + +func TestWriteJSONAndError(t *testing.T) { + rec := httptest.NewRecorder() + writeJSON(rec, http.StatusCreated, map[string]string{"ok": "true"}) + + if rec.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected Content-Type application/json, got %q", ct) + } + + var payload map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if payload["ok"] != "true" { + t.Fatalf("expected payload ok=true, got %v", payload) + } + + rec = httptest.NewRecorder() + writeError(rec, http.StatusBadRequest, errors.New("boom")) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", rec.Code) + } + var errPayload map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &errPayload); err != nil { + t.Fatalf("failed to decode error response: %v", err) + } + if errPayload["error"] != "boom" { + t.Fatalf("expected error boom, got %v", errPayload) + } +} + +func TestDecodeJSONUnknownField(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"unknown":1}`)) + var payload struct { + Known string `json:"known"` + } + err := decodeJSON(context.Background(), req, &payload) + if err == nil { + t.Fatal("expected error for unknown field") + } +} + +func TestDecodeJSONCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"known":"ok"}`)) + var payload struct { + Known string `json:"known"` + } + err := decodeJSON(ctx, req, &payload) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } +} + +func TestParseUUIDFromPath(t *testing.T) { + id := uuid.Must(uuid.NewV7()) + got, err := parseUUIDFromPath("/api/v1/companies/" + id.String() + "/rating") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != id { + t.Fatalf("expected %s, got %s", id, got) + } + if _, err := parseUUIDFromPath("/api/v1/companies/rating"); err == nil { + t.Fatal("expected error for missing UUID") + } +} + +func TestSplitPath(t *testing.T) { + parts := splitPath("/api/v1/companies/123") + if len(parts) != 4 { + t.Fatalf("expected 4 parts, got %d", len(parts)) + } + if parts[0] != "api" || parts[3] != "123" { + t.Fatalf("unexpected parts: %v", parts) + } +} + +func TestIsValidStatus(t *testing.T) { + if !isValidStatus("pending") { + t.Fatal("expected pending to be valid") + } + if isValidStatus("cancelled") { + t.Fatal("expected cancelled to be invalid") + } +} + +func TestParsePagination(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/?page=2&page_size=30", nil) + page, size := parsePagination(req) + if page != 2 || size != 30 { + t.Fatalf("expected page 2 size 30, got %d %d", page, size) + } + + req = httptest.NewRequest(http.MethodGet, "/?page=-1&page_size=0", nil) + page, size = parsePagination(req) + if page != 1 || size != 20 { + t.Fatalf("expected defaults, got %d %d", page, size) + } +} + +func TestGetRequesterFromHeaders(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + got, err := getRequester(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Role != "Admin" || got.CompanyID != nil { + t.Fatalf("unexpected requester: %+v", got) + } + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-User-Role", "Seller") + companyID := uuid.Must(uuid.NewV7()) + req.Header.Set("X-Company-ID", companyID.String()) + got, err = getRequester(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Role != "Seller" || got.CompanyID == nil || *got.CompanyID != companyID { + t.Fatalf("unexpected requester: %+v", got) + } + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Company-ID", "invalid") + if _, err := getRequester(req); err == nil { + t.Fatal("expected error for invalid company id header") + } +} + +func TestGetRequesterFromClaims(t *testing.T) { + secret := []byte("secret") + userID := uuid.Must(uuid.NewV7()) + companyID := uuid.Must(uuid.NewV7()) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID.String(), + "role": "Owner", + "company_id": companyID.String(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenStr, _ := token.SignedString(secret) + + var got requester + handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + got, err = getRequester(r) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if got.Role != "Owner" || got.CompanyID == nil || *got.CompanyID != companyID { + t.Fatalf("unexpected requester: %+v", got) + } +} + +func TestParseBearerToken(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if _, err := parseBearerToken(req); err == nil { + t.Fatal("expected error for missing header") + } + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Token abc") + if _, err := parseBearerToken(req); err == nil { + t.Fatal("expected error for invalid header") + } + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer ") + if _, err := parseBearerToken(req); err == nil { + t.Fatal("expected error for empty token") + } + + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer token123") + if token, err := parseBearerToken(req); err != nil || token != "token123" { + t.Fatalf("unexpected token result: %v %v", token, err) + } +} + +func TestGetUserFromContext(t *testing.T) { + secret := []byte("secret") + userID := uuid.Must(uuid.NewV7()) + companyID := uuid.Must(uuid.NewV7()) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID.String(), + "role": "Admin", + "company_id": companyID.String(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenStr, _ := token.SignedString(secret) + + handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := &Handler{} + user, err := h.getUserFromContext(r.Context()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.ID != userID || user.Role != "Admin" || user.CompanyID != companyID { + t.Fatalf("unexpected user: %+v", user) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) +} diff --git a/backend/internal/http/middleware/middleware_test.go b/backend/internal/http/middleware/middleware_test.go index 9efeb69..8fb9ea0 100644 --- a/backend/internal/http/middleware/middleware_test.go +++ b/backend/internal/http/middleware/middleware_test.go @@ -392,6 +392,72 @@ func TestRequireAuthWrongSigningMethod(t *testing.T) { } } +func TestRequireAuthInvalidSubject(t *testing.T) { + secret := "secret" + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": "not-a-uuid", + "role": "Admin", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenStr, _ := token.SignedString([]byte(secret)) + + handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called for invalid subject") + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) + } +} + +func TestOptionalAuthMissingToken(t *testing.T) { + var gotClaims bool + handler := OptionalAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, gotClaims = GetClaims(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if gotClaims { + t.Error("expected no claims for missing token") + } +} + +func TestOptionalAuthValidToken(t *testing.T) { + secret := "secret" + userID, _ := uuid.NewV7() + tokenStr := createTestToken(secret, userID, "Admin", nil) + + var gotClaims Claims + handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotClaims, _ = GetClaims(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if gotClaims.UserID != userID { + t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID) + } +} + // --- Security Headers Tests --- func TestSecurityHeaders(t *testing.T) { diff --git a/backend/internal/server/server_test.go b/backend/internal/server/server_test.go index 09ca024..2b27904 100644 --- a/backend/internal/server/server_test.go +++ b/backend/internal/server/server_test.go @@ -141,3 +141,41 @@ func TestDatabaseConnectionAndPing(t *testing.T) { t.Errorf("Expected 1, got %d", result) } } + +func TestChainMiddlewareOrder(t *testing.T) { + order := make([]string, 0, 3) + mw1 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + order = append(order, "mw1") + next.ServeHTTP(w, r) + }) + } + mw2 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + order = append(order, "mw2") + next.ServeHTTP(w, r) + }) + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + order = append(order, "handler") + w.WriteHeader(http.StatusOK) + }) + + chained := chain(handler, mw1, mw2) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + chained.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + expected := []string{"mw1", "mw2", "handler"} + if len(order) != len(expected) { + t.Fatalf("expected order %v, got %v", expected, order) + } + for i := range expected { + if order[i] != expected[i] { + t.Fatalf("expected order %v, got %v", expected, order) + } + } +} diff --git a/backend/internal/usecase/usecase_test.go b/backend/internal/usecase/usecase_test.go index 5c0f139..1903d9d 100644 --- a/backend/internal/usecase/usecase_test.go +++ b/backend/internal/usecase/usecase_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gofrs/uuid/v5" + "github.com/golang-jwt/jwt/v5" "github.com/saveinmed/backend-go/internal/domain" ) @@ -1337,3 +1338,124 @@ func TestRegisterAccount(t *testing.T) { t.Error("expected user to be linked to company") } } + +func TestRefreshTokenValid(t *testing.T) { + svc, repo := newTestService() + ctx := context.Background() + + user := &domain.User{ + ID: uuid.Must(uuid.NewV7()), + Role: "admin", + CompanyID: uuid.Must(uuid.NewV7()), + } + repo.users = append(repo.users, *user) + + tokenStr, err := svc.signToken(jwt.MapClaims{ + "sub": user.ID.String(), + }, time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + token, expiresAt, err := svc.RefreshToken(ctx, tokenStr) + if err != nil { + t.Fatalf("failed to refresh token: %v", err) + } + if token == "" { + t.Error("expected new token") + } + if expiresAt.Before(time.Now()) { + t.Error("expected expiration in the future") + } +} + +func TestRefreshTokenInvalidScope(t *testing.T) { + svc, _ := newTestService() + + tokenStr, err := svc.signToken(jwt.MapClaims{ + "sub": uuid.Must(uuid.NewV7()).String(), + "scope": "password_reset", + }, time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + if _, _, err := svc.RefreshToken(context.Background(), tokenStr); err == nil { + t.Error("expected invalid scope error") + } +} + +func TestCreatePasswordResetTokenAndResetPassword(t *testing.T) { + svc, repo := newTestService() + ctx := context.Background() + + user := &domain.User{ + ID: uuid.Must(uuid.NewV7()), + Email: "reset@example.com", + } + repo.users = append(repo.users, *user) + + token, expiresAt, err := svc.CreatePasswordResetToken(ctx, user.Email) + if err != nil { + t.Fatalf("failed to create reset token: %v", err) + } + if token == "" { + t.Error("expected reset token") + } + if expiresAt.Before(time.Now()) { + t.Error("expected expiration in the future") + } + + if err := svc.ResetPassword(ctx, token, "newpass123"); err != nil { + t.Fatalf("failed to reset password: %v", err) + } +} + +func TestResetPasswordInvalidScope(t *testing.T) { + svc, _ := newTestService() + + tokenStr, err := svc.signToken(jwt.MapClaims{ + "sub": uuid.Must(uuid.NewV7()).String(), + }, time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + if err := svc.ResetPassword(context.Background(), tokenStr, "newpass"); err == nil { + t.Error("expected invalid token scope error") + } +} + +func TestVerifyEmailMarksVerified(t *testing.T) { + svc, repo := newTestService() + ctx := context.Background() + + user := &domain.User{ + ID: uuid.Must(uuid.NewV7()), + EmailVerified: false, + } + repo.users = append(repo.users, *user) + + tokenStr, err := svc.signToken(jwt.MapClaims{ + "sub": user.ID.String(), + "scope": "email_verify", + }, time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + updated, err := svc.VerifyEmail(ctx, tokenStr) + if err != nil { + t.Fatalf("failed to verify email: %v", err) + } + if !updated.EmailVerified { + t.Error("expected email to be verified") + } +} + +func TestParseTokenEmpty(t *testing.T) { + svc, _ := newTestService() + if _, err := svc.parseToken(" "); err == nil { + t.Error("expected error for empty token") + } +}