- Payment Methods: Added Pix/Credit/Debit selection in checkout, updated backend models and handlers.
- Shipping: Updated Checkout UI, added shipping_settings table and seed data.
- Swagger: Updated API docs, regenerated swagger.yaml.
- UUIDv7: Migrated seeder and backend tests to use uuid.NewV7().
421 lines
12 KiB
Go
421 lines
12 KiB
Go
package middleware
|
|
|
|
import (
|
|
"compress/gzip"
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid/v5"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// --- CORS Tests ---
|
|
|
|
func TestCORSWithConfigAllowAll(t *testing.T) {
|
|
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
|
|
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
|
|
t.Errorf("expected Access-Control-Allow-Origin '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestCORSWithConfigSpecificOrigins(t *testing.T) {
|
|
cfg := CORSConfig{AllowedOrigins: []string{"https://allowed.com", "https://another.com"}}
|
|
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Test allowed origin
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://allowed.com")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "https://allowed.com" {
|
|
t.Errorf("expected origin 'https://allowed.com', got '%s'", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
// Test blocked origin
|
|
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req2.Header.Set("Origin", "https://blocked.com")
|
|
rec2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec2, req2)
|
|
|
|
if rec2.Header().Get("Access-Control-Allow-Origin") != "" {
|
|
t.Errorf("expected empty Access-Control-Allow-Origin, got '%s'", rec2.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestCORSPreflight(t *testing.T) {
|
|
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
|
|
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusTeapot) // Should not reach here
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected status 200 for preflight, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestCORSHeaders(t *testing.T) {
|
|
cfg := CORSConfig{AllowedOrigins: []string{"*"}}
|
|
handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if !strings.Contains(rec.Header().Get("Access-Control-Allow-Methods"), "GET") {
|
|
t.Error("expected Access-Control-Allow-Methods to include GET")
|
|
}
|
|
if !strings.Contains(rec.Header().Get("Access-Control-Allow-Headers"), "Authorization") {
|
|
t.Error("expected Access-Control-Allow-Headers to include Authorization")
|
|
}
|
|
}
|
|
|
|
// --- Auth Tests ---
|
|
|
|
func createTestToken(secret string, userID uuid.UUID, role string, companyID *uuid.UUID) string {
|
|
claims := jwt.MapClaims{
|
|
"sub": userID.String(),
|
|
"role": role,
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
}
|
|
if companyID != nil {
|
|
claims["company_id"] = companyID.String()
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenStr, _ := token.SignedString([]byte(secret))
|
|
return tokenStr
|
|
}
|
|
|
|
func TestRequireAuthValidToken(t *testing.T) {
|
|
secret := "test-secret"
|
|
userID, _ := uuid.NewV7()
|
|
companyID, _ := uuid.NewV7()
|
|
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
|
|
|
var receivedClaims Claims
|
|
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
receivedClaims, _ = GetClaims(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", rec.Code)
|
|
}
|
|
if receivedClaims.UserID != userID {
|
|
t.Errorf("expected userID %s, got %s", userID, receivedClaims.UserID)
|
|
}
|
|
if receivedClaims.Role != "Admin" {
|
|
t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthMissingToken(t *testing.T) {
|
|
handler := RequireAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthInvalidToken(t *testing.T) {
|
|
handler := RequireAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer invalid-token")
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthWrongSecret(t *testing.T) {
|
|
userID, _ := uuid.NewV7()
|
|
tokenStr := createTestToken("correct-secret", userID, "User", nil)
|
|
|
|
handler := RequireAuth([]byte("wrong-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected status 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthRoleRestriction(t *testing.T) {
|
|
secret := "secret"
|
|
userID, _ := uuid.NewV7()
|
|
tokenStr := createTestToken(secret, userID, "User", nil)
|
|
|
|
handler := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Errorf("expected status 403, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthRoleAllowed(t *testing.T) {
|
|
secret := "secret"
|
|
userID, _ := uuid.NewV7()
|
|
tokenStr := createTestToken(secret, userID, "Admin", nil)
|
|
|
|
handler := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestGetClaimsFromContext(t *testing.T) {
|
|
claims := Claims{
|
|
UserID: uuid.Must(uuid.NewV7()),
|
|
Role: "Admin",
|
|
}
|
|
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
|
|
|
retrieved, ok := GetClaims(ctx)
|
|
if !ok {
|
|
t.Error("expected to retrieve claims from context")
|
|
}
|
|
if retrieved.UserID != claims.UserID {
|
|
t.Errorf("expected userID %s, got %s", claims.UserID, retrieved.UserID)
|
|
}
|
|
}
|
|
|
|
func TestGetClaimsNotInContext(t *testing.T) {
|
|
_, ok := GetClaims(context.Background())
|
|
if ok {
|
|
t.Error("expected claims to not be in context")
|
|
}
|
|
}
|
|
|
|
// --- Gzip Tests ---
|
|
|
|
func TestGzipCompression(t *testing.T) {
|
|
handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("Hello, World!"))
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Accept-Encoding", "gzip")
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Content-Encoding") != "gzip" {
|
|
t.Error("expected Content-Encoding 'gzip'")
|
|
}
|
|
|
|
// Decompress and verify
|
|
reader, err := gzip.NewReader(rec.Body)
|
|
if err != nil {
|
|
t.Fatalf("failed to create gzip reader: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
body, err := io.ReadAll(reader)
|
|
if err != nil {
|
|
t.Fatalf("failed to read gzip body: %v", err)
|
|
}
|
|
|
|
if string(body) != "Hello, World!" {
|
|
t.Errorf("expected 'Hello, World!', got '%s'", string(body))
|
|
}
|
|
}
|
|
|
|
func TestGzipNoCompression(t *testing.T) {
|
|
handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("Hello, World!"))
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
// No Accept-Encoding header
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Content-Encoding") == "gzip" {
|
|
t.Error("should not use gzip when not requested")
|
|
}
|
|
if rec.Body.String() != "Hello, World!" {
|
|
t.Errorf("expected 'Hello, World!', got '%s'", rec.Body.String())
|
|
}
|
|
}
|
|
|
|
// --- Logger Tests ---
|
|
|
|
func TestLoggerMiddleware(t *testing.T) {
|
|
handler := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test-path", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- CORS Legacy Wrapper Test ---
|
|
|
|
func TestCORSLegacyWrapper(t *testing.T) {
|
|
handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "https://example.com")
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
|
|
t.Errorf("expected '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
// --- Additional Auth Edge Case Tests ---
|
|
|
|
func TestRequireAuthExpiredToken(t *testing.T) {
|
|
secret := "test-secret"
|
|
userID, _ := uuid.NewV7()
|
|
|
|
// Create an expired token
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"sub": userID.String(),
|
|
"role": "Admin",
|
|
"exp": time.Now().Add(-time.Hour).Unix(), // expired 1 hour ago
|
|
})
|
|
tokenStr, _ := token.SignedString([]byte(secret))
|
|
|
|
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("handler should not be called for expired token")
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthWrongSigningMethod(t *testing.T) {
|
|
// Create a token with None signing method (should be rejected)
|
|
token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
|
|
"sub": "test-user-id",
|
|
"role": "Admin",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
tokenStr, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
|
|
|
|
handler := RequireAuth([]byte("test-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// --- Security Headers Tests ---
|
|
|
|
func TestSecurityHeaders(t *testing.T) {
|
|
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("X-Content-Type-Options") != "nosniff" {
|
|
t.Error("expected X-Content-Type-Options: nosniff")
|
|
}
|
|
if rec.Header().Get("X-Frame-Options") != "DENY" {
|
|
t.Error("expected X-Frame-Options: DENY")
|
|
}
|
|
if rec.Header().Get("X-XSS-Protection") != "1; mode=block" {
|
|
t.Error("expected X-XSS-Protection: 1; mode=block")
|
|
}
|
|
if rec.Header().Get("Content-Security-Policy") != "default-src 'none'" {
|
|
t.Error("expected Content-Security-Policy: default-src 'none'")
|
|
}
|
|
if rec.Header().Get("Cache-Control") != "no-store, max-age=0" {
|
|
t.Error("expected Cache-Control: no-store, max-age=0")
|
|
}
|
|
}
|