diff --git a/backend/internal/http/middleware/middleware_test.go b/backend/internal/http/middleware/middleware_test.go index 95ece34..0651d28 100644 --- a/backend/internal/http/middleware/middleware_test.go +++ b/backend/internal/http/middleware/middleware_test.go @@ -340,3 +340,54 @@ func TestCORSLegacyWrapper(t *testing.T) { t.Errorf("expected '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin")) } } + +// --- Additional Auth Edge Case Tests --- + +func TestRequireAuthExpiredToken(t *testing.T) { + secret := "test-secret" + userID, _ := uuid.NewV4() + + // Create an expired token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID.String(), + "role": "Admin", + "exp": time.Now().Add(-time.Hour).Unix(), // expired 1 hour ago + }) + tokenStr, _ := token.SignedString([]byte(secret)) + + handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called for expired token") + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) + } +} + +func TestRequireAuthWrongSigningMethod(t *testing.T) { + // Create a token with None signing method (should be rejected) + token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "sub": "test-user-id", + "role": "Admin", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenStr, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + + handler := RequireAuth([]byte("test-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called") + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) + } +}