saveinmed/backend-old/internal/http/middleware/middleware_test.go
2026-02-07 11:43:31 -03:00

498 lines
14 KiB
Go

package middleware
import (
"compress/gzip"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
)
// --- CORS Tests ---
func TestCORSWithConfigAllowAll(t *testing.T) {
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://example.com")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Errorf("expected Access-Control-Allow-Origin '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestCORSWithConfigSpecificOrigins verifies origin filtering with an explicit
// allow-list: a listed origin is echoed back, an unlisted one receives no
// Access-Control-Allow-Origin header.
func TestCORSWithConfigSpecificOrigins(t *testing.T) {
	cfg := CORSConfig{AllowedOrigins: []string{"https://allowed.com", "https://another.com"}}
	handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	cases := []struct {
		origin string
		want   string
		format string
	}{
		// Allowed origin: echoed back verbatim.
		{"https://allowed.com", "https://allowed.com", "expected origin 'https://allowed.com', got '%s'"},
		// Blocked origin: header must stay empty.
		{"https://blocked.com", "", "expected empty Access-Control-Allow-Origin, got '%s'"},
	}
	for _, tc := range cases {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Set("Origin", tc.origin)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		if got := rec.Header().Get("Access-Control-Allow-Origin"); got != tc.want {
			t.Errorf(tc.format, got)
		}
	}
}
func TestCORSPreflight(t *testing.T) {
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot) // Should not reach here
}))
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("expected status 200 for preflight, got %d", rec.Code)
}
}
func TestCORSHeaders(t *testing.T) {
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if !strings.Contains(rec.Header().Get("Access-Control-Allow-Methods"), "GET") {
t.Error("expected Access-Control-Allow-Methods to include GET")
}
if !strings.Contains(rec.Header().Get("Access-Control-Allow-Headers"), "Authorization") {
t.Error("expected Access-Control-Allow-Headers to include Authorization")
}
}
// --- Auth Tests ---
// createTestToken returns an HS256-signed JWT for userID, carrying the given
// role and an "exp" claim one hour from now. When companyID is non-nil it is
// added as the "company_id" claim.
//
// It panics if signing fails. Returning an empty string instead would let
// negative tests (e.g. wrong-secret cases that expect 401) pass for the wrong
// reason, because an empty bearer token is rejected too.
func createTestToken(secret string, userID uuid.UUID, role string, companyID *uuid.UUID) string {
	claims := jwt.MapClaims{
		"sub":  userID.String(),
		"role": role,
		"exp":  time.Now().Add(time.Hour).Unix(),
	}
	if companyID != nil {
		claims["company_id"] = companyID.String()
	}
	tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		panic("createTestToken: signing token: " + err.Error())
	}
	return tokenStr
}
// TestRequireAuthValidToken verifies that a correctly signed token is let
// through and that the downstream handler can read the parsed claims (user ID
// and role) and the company ID from the request context.
func TestRequireAuthValidToken(t *testing.T) {
	secret := "test-secret"
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	tokenStr := createTestToken(secret, userID, "Admin", &companyID)

	var (
		receivedClaims    Claims
		claimsOK          bool
		receivedCompanyID *uuid.UUID
	)
	handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		receivedClaims, claimsOK = GetClaims(r.Context())
		// NOTE(review): the test assumes the middleware stores the company ID
		// as *uuid.UUID under the plain string key "company_id"; keep in sync.
		receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// The remaining checks are meaningless if the request was rejected.
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", rec.Code)
	}
	// ResponseRecorder defaults to 200, so a missing claims value is also how
	// we detect that the handler was never reached.
	if !claimsOK {
		t.Fatal("expected claims in request context")
	}
	if receivedClaims.UserID != userID {
		t.Errorf("expected userID %s, got %s", userID, receivedClaims.UserID)
	}
	if receivedClaims.Role != "Admin" {
		t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role)
	}
	if receivedCompanyID == nil || *receivedCompanyID != companyID {
		t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
	}
}
// TestRequireAuthMissingToken checks that a request with no Authorization
// header is rejected with 401.
func TestRequireAuthMissingToken(t *testing.T) {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	protected := RequireAuth([]byte("secret"))(next)

	anonymous := httptest.NewRequest(http.MethodGet, "/", nil)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, anonymous)

	if code := recorder.Code; code != http.StatusUnauthorized {
		t.Errorf("expected status 401, got %d", code)
	}
}
// TestRequireAuthInvalidToken checks that a bearer value that is not a JWT at
// all is rejected with 401.
func TestRequireAuthInvalidToken(t *testing.T) {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	protected := RequireAuth([]byte("secret"))(next)

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("Authorization", "Bearer invalid-token")
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, r)

	if code := recorder.Code; code != http.StatusUnauthorized {
		t.Errorf("expected status 401, got %d", code)
	}
}
// TestRequireAuthWrongSecret checks that a well-formed token signed with a
// different key than the one the middleware verifies with is rejected (401).
func TestRequireAuthWrongSecret(t *testing.T) {
	// NewV7 error ignored: the subject's value is irrelevant to a wrong-key check.
	uid, _ := uuid.NewV7()
	bearer := "Bearer " + createTestToken("correct-secret", uid, "User", nil)

	protected := RequireAuth([]byte("wrong-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("Authorization", bearer)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, r)

	if code := recorder.Code; code != http.StatusUnauthorized {
		t.Errorf("expected status 401, got %d", code)
	}
}
// TestRequireAuthRoleRestriction checks that a valid token whose role ("User")
// is not in the allowed set ("Admin") is refused with 403 Forbidden.
func TestRequireAuthRoleRestriction(t *testing.T) {
	const secret = "secret"
	// NewV7 error ignored: only the role matters for this check.
	uid, _ := uuid.NewV7()
	bearer := "Bearer " + createTestToken(secret, uid, "User", nil)

	adminOnly := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("Authorization", bearer)
	recorder := httptest.NewRecorder()
	adminOnly.ServeHTTP(recorder, r)

	if code := recorder.Code; code != http.StatusForbidden {
		t.Errorf("expected status 403, got %d", code)
	}
}
// TestRequireAuthRoleAllowed checks that a valid token whose role is in the
// allowed set reaches the protected handler and gets 200.
//
// httptest.ResponseRecorder reports 200 even when nothing is written, so the
// status alone cannot prove the handler ran; reached records that explicitly.
func TestRequireAuthRoleAllowed(t *testing.T) {
	secret := "secret"
	userID := uuid.Must(uuid.NewV7())
	tokenStr := createTestToken(secret, userID, "Admin", nil)

	reached := false
	handler := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}
	if !reached {
		t.Error("expected the protected handler to be called for an allowed role")
	}
}
// TestGetClaimsFromContext verifies that GetClaims returns exactly the Claims
// value stored in a context under claimsKey.
func TestGetClaimsFromContext(t *testing.T) {
	want := Claims{
		UserID: uuid.Must(uuid.NewV7()),
		Role:   "Admin",
	}
	ctx := context.WithValue(context.Background(), claimsKey, want)

	got, ok := GetClaims(ctx)
	// Without claims the field comparisons below would only add noise.
	if !ok {
		t.Fatal("expected to retrieve claims from context")
	}
	if got.UserID != want.UserID {
		t.Errorf("expected userID %s, got %s", want.UserID, got.UserID)
	}
	if got.Role != want.Role {
		t.Errorf("expected role '%s', got '%s'", want.Role, got.Role)
	}
}
// TestGetClaimsNotInContext checks that GetClaims reports absence (ok == false)
// for a context that never had claims attached.
func TestGetClaimsNotInContext(t *testing.T) {
	if _, found := GetClaims(context.Background()); found {
		t.Error("expected claims to not be in context")
	}
}
// --- Gzip Tests ---
// TestGzipCompression verifies that when the client sends
// "Accept-Encoding: gzip", the Gzip middleware sets Content-Encoding: gzip and
// the body decompresses back to exactly what the handler wrote.
func TestGzipCompression(t *testing.T) {
	const payload = "Hello, World!"
	handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, err := w.Write([]byte(payload)); err != nil {
			t.Errorf("writing response body: %v", err)
		}
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Accept-Encoding", "gzip")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Without gzip encoding the body cannot be decompressed, so stop here
	// rather than report a secondary gzip-reader failure.
	if enc := rec.Header().Get("Content-Encoding"); enc != "gzip" {
		t.Fatalf("expected Content-Encoding 'gzip', got %q", enc)
	}

	reader, err := gzip.NewReader(rec.Body)
	if err != nil {
		t.Fatalf("failed to create gzip reader: %v", err)
	}
	defer reader.Close()

	body, err := io.ReadAll(reader)
	if err != nil {
		t.Fatalf("failed to read gzip body: %v", err)
	}
	if string(body) != payload {
		t.Errorf("expected '%s', got '%s'", payload, string(body))
	}
}
// TestGzipNoCompression checks that without an Accept-Encoding header the Gzip
// middleware leaves the response uncompressed and the body untouched.
func TestGzipNoCompression(t *testing.T) {
	handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("Hello, World!")) // recorder writes cannot fail
	}))

	// Deliberately no Accept-Encoding header: the client did not opt in.
	plain := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, plain)

	if rec.Header().Get("Content-Encoding") == "gzip" {
		t.Error("should not use gzip when not requested")
	}
	if body := rec.Body.String(); body != "Hello, World!" {
		t.Errorf("expected 'Hello, World!', got '%s'", body)
	}
}
// --- Logger Tests ---
// TestLoggerMiddleware checks that a request served through Logger ends with
// the 200 status the wrapped handler writes.
func TestLoggerMiddleware(t *testing.T) {
	logged := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodGet, "/test-path", nil)
	recorder := httptest.NewRecorder()
	logged.ServeHTTP(recorder, r)

	if code := recorder.Code; code != http.StatusOK {
		t.Errorf("expected status 200, got %d", code)
	}
}
// --- CORS Legacy Wrapper Test ---
// TestCORSLegacyWrapper checks that the zero-config CORS wrapper answers a
// cross-origin request with Access-Control-Allow-Origin "*".
func TestCORSLegacyWrapper(t *testing.T) {
	wrapped := CORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.Header.Set("Origin", "https://example.com")
	w := httptest.NewRecorder()
	wrapped.ServeHTTP(w, r)

	if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
		t.Errorf("expected '*', got '%s'", got)
	}
}
// --- Additional Auth Edge Case Tests ---
// TestRequireAuthExpiredToken verifies that a correctly signed token whose
// "exp" lies in the past is rejected with 401 and never reaches the handler.
func TestRequireAuthExpiredToken(t *testing.T) {
	secret := "test-secret"
	userID := uuid.Must(uuid.NewV7())
	// Everything but the expiry is valid, so expiry is the only ground for rejection.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":  userID.String(),
		"role": "Admin",
		"exp":  time.Now().Add(-time.Hour).Unix(), // expired 1 hour ago
	})
	tokenStr, err := token.SignedString([]byte(secret))
	if err != nil {
		// An empty token would also be rejected with 401 and mask a broken
		// expiry check, so a signing failure must abort the test.
		t.Fatalf("signing expired token: %v", err)
	}

	handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be called for expired token")
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", rec.Code)
	}
}
// TestRequireAuthWrongSigningMethod verifies that an unsigned ("alg": "none")
// token is rejected with 401, guarding against algorithm-substitution attacks.
func TestRequireAuthWrongSigningMethod(t *testing.T) {
	token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
		"sub":  "test-user-id",
		"role": "Admin",
		"exp":  time.Now().Add(time.Hour).Unix(),
	})
	// The "none" method only signs when given this explicit opt-in key.
	tokenStr, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
	if err != nil {
		// An empty token would also yield 401 and make this test pass vacuously.
		t.Fatalf("building alg=none token: %v", err)
	}

	handler := RequireAuth([]byte("test-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be called")
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", rec.Code)
	}
}
// TestRequireAuthInvalidSubject verifies that a correctly signed, unexpired
// token whose "sub" is not a UUID is rejected with 401.
func TestRequireAuthInvalidSubject(t *testing.T) {
	secret := "secret"
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":  "not-a-uuid",
		"role": "Admin",
		"exp":  time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString([]byte(secret))
	if err != nil {
		// An empty token would also yield 401 and hide a missing subject check.
		t.Fatalf("signing token: %v", err)
	}

	handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be called for invalid subject")
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", rec.Code)
	}
}
func TestOptionalAuthMissingToken(t *testing.T) {
var gotClaims bool
handler := OptionalAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, gotClaims = GetClaims(r.Context())
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("expected 200, got %d", rec.Code)
}
if gotClaims {
t.Error("expected no claims for missing token")
}
}
// TestOptionalAuthValidToken verifies that when a valid token is supplied,
// OptionalAuth exposes the parsed claims and company ID to the handler just
// like RequireAuth would.
func TestOptionalAuthValidToken(t *testing.T) {
	secret := "secret"
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	tokenStr := createTestToken(secret, userID, "Admin", &companyID)

	var (
		gotClaims         Claims
		claimsOK          bool
		receivedCompanyID *uuid.UUID
	)
	handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotClaims, claimsOK = GetClaims(r.Context())
		// NOTE(review): the test assumes the company ID is stored as *uuid.UUID
		// under the plain string key "company_id"; keep in sync with the middleware.
		receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	// A valid token must produce claims; without them the checks below are noise.
	if !claimsOK {
		t.Fatal("expected claims in request context for a valid token")
	}
	if gotClaims.UserID != userID {
		t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID)
	}
	if receivedCompanyID == nil || *receivedCompanyID != companyID {
		t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
	}
}
// --- Security Headers Tests ---
// TestSecurityHeaders verifies that SecurityHeaders sets each hardening header
// to its exact expected value. Failures report the value actually received.
func TestSecurityHeaders(t *testing.T) {
	handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	want := []struct{ name, value string }{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Content-Security-Policy", "default-src 'none'"},
		{"Cache-Control", "no-store, max-age=0"},
	}
	for _, h := range want {
		if got := rec.Header().Get(h.name); got != h.value {
			t.Errorf("expected %s: %s, got %q", h.name, h.value, got)
		}
	}
}