Merge pull request #59 from rede5/codex/increase-test-coverage-in-go-backend
test: expand backend coverage with middleware, handler DTO and usecase tests
This commit is contained in:
commit
2a73598d6f
4 changed files with 462 additions and 0 deletions
236
backend/internal/http/handler/dto_test.go
Normal file
236
backend/internal/http/handler/dto_test.go
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/saveinmed/backend-go/internal/http/middleware"
|
||||
)
|
||||
|
||||
func TestWriteJSONAndError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeJSON(rec, http.StatusCreated, map[string]string{"ok": "true"})
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("expected status 201, got %d", rec.Code)
|
||||
}
|
||||
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("expected Content-Type application/json, got %q", ct)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if payload["ok"] != "true" {
|
||||
t.Fatalf("expected payload ok=true, got %v", payload)
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
writeError(rec, http.StatusBadRequest, errors.New("boom"))
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
var errPayload map[string]string
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &errPayload); err != nil {
|
||||
t.Fatalf("failed to decode error response: %v", err)
|
||||
}
|
||||
if errPayload["error"] != "boom" {
|
||||
t.Fatalf("expected error boom, got %v", errPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeJSONUnknownField(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"unknown":1}`))
|
||||
var payload struct {
|
||||
Known string `json:"known"`
|
||||
}
|
||||
err := decodeJSON(context.Background(), req, &payload)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeJSONCanceledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"known":"ok"}`))
|
||||
var payload struct {
|
||||
Known string `json:"known"`
|
||||
}
|
||||
err := decodeJSON(ctx, req, &payload)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUUIDFromPath(t *testing.T) {
|
||||
id := uuid.Must(uuid.NewV7())
|
||||
got, err := parseUUIDFromPath("/api/v1/companies/" + id.String() + "/rating")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != id {
|
||||
t.Fatalf("expected %s, got %s", id, got)
|
||||
}
|
||||
if _, err := parseUUIDFromPath("/api/v1/companies/rating"); err == nil {
|
||||
t.Fatal("expected error for missing UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitPath(t *testing.T) {
|
||||
parts := splitPath("/api/v1/companies/123")
|
||||
if len(parts) != 4 {
|
||||
t.Fatalf("expected 4 parts, got %d", len(parts))
|
||||
}
|
||||
if parts[0] != "api" || parts[3] != "123" {
|
||||
t.Fatalf("unexpected parts: %v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidStatus(t *testing.T) {
|
||||
if !isValidStatus("pending") {
|
||||
t.Fatal("expected pending to be valid")
|
||||
}
|
||||
if isValidStatus("cancelled") {
|
||||
t.Fatal("expected cancelled to be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePagination(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/?page=2&page_size=30", nil)
|
||||
page, size := parsePagination(req)
|
||||
if page != 2 || size != 30 {
|
||||
t.Fatalf("expected page 2 size 30, got %d %d", page, size)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/?page=-1&page_size=0", nil)
|
||||
page, size = parsePagination(req)
|
||||
if page != 1 || size != 20 {
|
||||
t.Fatalf("expected defaults, got %d %d", page, size)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequesterFromHeaders(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
got, err := getRequester(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got.Role != "Admin" || got.CompanyID != nil {
|
||||
t.Fatalf("unexpected requester: %+v", got)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-User-Role", "Seller")
|
||||
companyID := uuid.Must(uuid.NewV7())
|
||||
req.Header.Set("X-Company-ID", companyID.String())
|
||||
got, err = getRequester(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got.Role != "Seller" || got.CompanyID == nil || *got.CompanyID != companyID {
|
||||
t.Fatalf("unexpected requester: %+v", got)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Company-ID", "invalid")
|
||||
if _, err := getRequester(req); err == nil {
|
||||
t.Fatal("expected error for invalid company id header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequesterFromClaims(t *testing.T) {
|
||||
secret := []byte("secret")
|
||||
userID := uuid.Must(uuid.NewV7())
|
||||
companyID := uuid.Must(uuid.NewV7())
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID.String(),
|
||||
"role": "Owner",
|
||||
"company_id": companyID.String(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenStr, _ := token.SignedString(secret)
|
||||
|
||||
var got requester
|
||||
handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
got, err = getRequester(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if got.Role != "Owner" || got.CompanyID == nil || *got.CompanyID != companyID {
|
||||
t.Fatalf("unexpected requester: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBearerToken(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if _, err := parseBearerToken(req); err == nil {
|
||||
t.Fatal("expected error for missing header")
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Token abc")
|
||||
if _, err := parseBearerToken(req); err == nil {
|
||||
t.Fatal("expected error for invalid header")
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
if _, err := parseBearerToken(req); err == nil {
|
||||
t.Fatal("expected error for empty token")
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer token123")
|
||||
if token, err := parseBearerToken(req); err != nil || token != "token123" {
|
||||
t.Fatalf("unexpected token result: %v %v", token, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserFromContext(t *testing.T) {
|
||||
secret := []byte("secret")
|
||||
userID := uuid.Must(uuid.NewV7())
|
||||
companyID := uuid.Must(uuid.NewV7())
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID.String(),
|
||||
"role": "Admin",
|
||||
"company_id": companyID.String(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenStr, _ := token.SignedString(secret)
|
||||
|
||||
handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := &Handler{}
|
||||
user, err := h.getUserFromContext(r.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if user.ID != userID || user.Role != "Admin" || user.CompanyID != companyID {
|
||||
t.Fatalf("unexpected user: %+v", user)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
|
@ -392,6 +392,72 @@ func TestRequireAuthWrongSigningMethod(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthInvalidSubject(t *testing.T) {
|
||||
secret := "secret"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "not-a-uuid",
|
||||
"role": "Admin",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenStr, _ := token.SignedString([]byte(secret))
|
||||
|
||||
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be called for invalid subject")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionalAuthMissingToken(t *testing.T) {
|
||||
var gotClaims bool
|
||||
handler := OptionalAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, gotClaims = GetClaims(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
if gotClaims {
|
||||
t.Error("expected no claims for missing token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionalAuthValidToken(t *testing.T) {
|
||||
secret := "secret"
|
||||
userID, _ := uuid.NewV7()
|
||||
tokenStr := createTestToken(secret, userID, "Admin", nil)
|
||||
|
||||
var gotClaims Claims
|
||||
handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotClaims, _ = GetClaims(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
if gotClaims.UserID != userID {
|
||||
t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Security Headers Tests ---
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -141,3 +141,41 @@ func TestDatabaseConnectionAndPing(t *testing.T) {
|
|||
t.Errorf("Expected 1, got %d", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainMiddlewareOrder(t *testing.T) {
|
||||
order := make([]string, 0, 3)
|
||||
mw1 := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "mw1")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
mw2 := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "mw2")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
order = append(order, "handler")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
chained := chain(handler, mw1, mw2)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
chained.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
expected := []string{"mw1", "mw2", "handler"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("expected order %v, got %v", expected, order)
|
||||
}
|
||||
for i := range expected {
|
||||
if order[i] != expected[i] {
|
||||
t.Fatalf("expected order %v, got %v", expected, order)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/saveinmed/backend-go/internal/domain"
|
||||
)
|
||||
|
||||
|
|
@ -1337,3 +1338,124 @@ func TestRegisterAccount(t *testing.T) {
|
|||
t.Error("expected user to be linked to company")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenValid(t *testing.T) {
|
||||
svc, repo := newTestService()
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
ID: uuid.Must(uuid.NewV7()),
|
||||
Role: "admin",
|
||||
CompanyID: uuid.Must(uuid.NewV7()),
|
||||
}
|
||||
repo.users = append(repo.users, *user)
|
||||
|
||||
tokenStr, err := svc.signToken(jwt.MapClaims{
|
||||
"sub": user.ID.String(),
|
||||
}, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
token, expiresAt, err := svc.RefreshToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to refresh token: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Error("expected new token")
|
||||
}
|
||||
if expiresAt.Before(time.Now()) {
|
||||
t.Error("expected expiration in the future")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenInvalidScope(t *testing.T) {
|
||||
svc, _ := newTestService()
|
||||
|
||||
tokenStr, err := svc.signToken(jwt.MapClaims{
|
||||
"sub": uuid.Must(uuid.NewV7()).String(),
|
||||
"scope": "password_reset",
|
||||
}, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
if _, _, err := svc.RefreshToken(context.Background(), tokenStr); err == nil {
|
||||
t.Error("expected invalid scope error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePasswordResetTokenAndResetPassword(t *testing.T) {
|
||||
svc, repo := newTestService()
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
ID: uuid.Must(uuid.NewV7()),
|
||||
Email: "reset@example.com",
|
||||
}
|
||||
repo.users = append(repo.users, *user)
|
||||
|
||||
token, expiresAt, err := svc.CreatePasswordResetToken(ctx, user.Email)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create reset token: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Error("expected reset token")
|
||||
}
|
||||
if expiresAt.Before(time.Now()) {
|
||||
t.Error("expected expiration in the future")
|
||||
}
|
||||
|
||||
if err := svc.ResetPassword(ctx, token, "newpass123"); err != nil {
|
||||
t.Fatalf("failed to reset password: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetPasswordInvalidScope(t *testing.T) {
|
||||
svc, _ := newTestService()
|
||||
|
||||
tokenStr, err := svc.signToken(jwt.MapClaims{
|
||||
"sub": uuid.Must(uuid.NewV7()).String(),
|
||||
}, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
if err := svc.ResetPassword(context.Background(), tokenStr, "newpass"); err == nil {
|
||||
t.Error("expected invalid token scope error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyEmailMarksVerified(t *testing.T) {
|
||||
svc, repo := newTestService()
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
ID: uuid.Must(uuid.NewV7()),
|
||||
EmailVerified: false,
|
||||
}
|
||||
repo.users = append(repo.users, *user)
|
||||
|
||||
tokenStr, err := svc.signToken(jwt.MapClaims{
|
||||
"sub": user.ID.String(),
|
||||
"scope": "email_verify",
|
||||
}, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
updated, err := svc.VerifyEmail(ctx, tokenStr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to verify email: %v", err)
|
||||
}
|
||||
if !updated.EmailVerified {
|
||||
t.Error("expected email to be verified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenEmpty(t *testing.T) {
|
||||
svc, _ := newTestService()
|
||||
if _, err := svc.parseToken(" "); err == nil {
|
||||
t.Error("expected error for empty token")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue