325 lines
8.2 KiB
Go
325 lines
8.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/rede5/gohorsejobs/backend/internal/utils"
|
|
)
|
|
|
|
func TestLoggingMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := LoggingMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_Success(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := r.Context().Value(UserKey).(*utils.Claims)
|
|
if !ok {
|
|
t.Error("Claims not found in context")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if claims.UserID != 1 {
|
|
t.Errorf("Expected userID 1, got %d", claims.UserID)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := AuthMiddleware(handler)
|
|
|
|
token, _ := utils.GenerateJWT(1, "test-user", "user")
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_isAllowed(t *testing.T) {
|
|
limiter := NewRateLimiter(3, time.Minute)
|
|
|
|
// First 3 requests should be allowed
|
|
for i := 0; i < 3; i++ {
|
|
if !limiter.isAllowed("192.168.1.1") {
|
|
t.Errorf("Request %d should be allowed", i+1)
|
|
}
|
|
}
|
|
|
|
// 4th request should be denied
|
|
if limiter.isAllowed("192.168.1.1") {
|
|
t.Error("Request 4 should be denied")
|
|
}
|
|
|
|
// Different IP should still be allowed
|
|
if !limiter.isAllowed("192.168.1.2") {
|
|
t.Error("Different IP should be allowed")
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := RateLimitMiddleware(2, time.Minute)(handler)
|
|
|
|
// Create test requests
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.100:12345"
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if i < 2 {
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Request %d: expected status 200, got %d", i+1, rr.Code)
|
|
}
|
|
} else {
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("Request %d: expected status 429, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSecurityHeadersMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := SecurityHeadersMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
expectedHeaders := map[string]string{
|
|
"X-Frame-Options": "DENY",
|
|
"X-Content-Type-Options": "nosniff",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}
|
|
|
|
for header, expected := range expectedHeaders {
|
|
actual := rr.Header().Get(header)
|
|
if actual != expected {
|
|
t.Errorf("Header %s: expected %q, got %q", header, expected, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_NoAuthHeader(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := AuthMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := AuthMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "InvalidFormat")
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InvalidToken(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := AuthMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRole_NoClaims(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := RequireRole("admin")(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status 401, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
os.Setenv("CORS_ORIGINS", "http://allowed.com,http://another.com")
|
|
defer os.Unsetenv("CORS_ORIGINS")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := CORSMiddleware(handler)
|
|
|
|
// Test allowed origin
|
|
req := httptest.NewRequest("OPTIONS", "/test", nil)
|
|
req.Header.Set("Origin", "http://allowed.com")
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Header().Get("Access-Control-Allow-Origin") != "http://allowed.com" {
|
|
t.Errorf("Expected allow origin http://allowed.com, got %s", rr.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
// Test disallowed origin
|
|
req = httptest.NewRequest("OPTIONS", "/test", nil)
|
|
req.Header.Set("Origin", "http://hacker.com")
|
|
rr = httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
if rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
|
t.Errorf("Expected empty allow origin, got %s", rr.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
}
|
|
|
|
func TestSanitizeMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Read body to verify sanitization
|
|
var body map[string]interface{}
|
|
json.NewDecoder(r.Body).Decode(&body)
|
|
w.Header().Set("X-Sanitized-Name", body["name"].(string))
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := SanitizeMiddleware(handler)
|
|
|
|
jsonBody := `{"name": "<script>alert('xss')</script>"}`
|
|
req := httptest.NewRequest("POST", "/test", strings.NewReader(jsonBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
expected := "<script>alert('xss')</script>"
|
|
if rr.Header().Get("X-Sanitized-Name") != expected {
|
|
t.Errorf("Expected sanitized name %s, got %s", expected, rr.Header().Get("X-Sanitized-Name"))
|
|
}
|
|
}
|
|
|
|
func TestRequireRole_Success(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := RequireRole("admin")(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
// Inject claims into context manually to simulate authenticated user
|
|
claims := &utils.Claims{
|
|
UserID: 1,
|
|
Role: "admin",
|
|
}
|
|
ctx := context.WithValue(req.Context(), UserKey, claims)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req.WithContext(ctx))
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRole_Forbidden(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := RequireRole("admin")(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
claims := &utils.Claims{
|
|
UserID: 1,
|
|
Role: "user", // Wrong role
|
|
}
|
|
ctx := context.WithValue(req.Context(), UserKey, claims)
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req.WithContext(ctx))
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("Expected status 403, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestSanitizeMiddleware_InvalidJSON(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
mw := SanitizeMiddleware(handler)
|
|
|
|
jsonBody := `{"name": "broken json`
|
|
req := httptest.NewRequest("POST", "/test", strings.NewReader(jsonBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
rr := httptest.NewRecorder()
|
|
|
|
mw.ServeHTTP(rr, req)
|
|
|
|
// Should pass through if JSON invalid (or handle gracefully)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200 (pass through), got %d", rr.Code)
|
|
}
|
|
}
|