202 lines
6 KiB
Go
202 lines
6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/rede5/gohorsejobs/backend/internal/core/ports"
|
|
)
|
|
|
|
// contextKey is an unexported named type used for this package's request
// context keys, so they cannot collide with string keys set by other packages.
type contextKey string

// Context keys under which the auth guards store values copied from a
// validated token's claims.
const (
	// ContextUserID stores the token's "sub" claim.
	ContextUserID contextKey = "userID"

	// ContextTenantID stores the token's "tenant" claim.
	ContextTenantID contextKey = "tenantID"

	// ContextRoles stores the token's raw "roles" claim (see ExtractRoles).
	ContextRoles contextKey = "roles"
)
|
|
|
|
type Middleware struct {
|
|
authService ports.AuthService
|
|
}
|
|
|
|
func NewMiddleware(authService ports.AuthService) *Middleware {
|
|
return &Middleware{authService: authService}
|
|
}
|
|
|
|
// HeaderAuthGuard ensures valid JWT token is present.
|
|
func (m *Middleware) HeaderAuthGuard(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Printf("[AUTH DEBUG] === HeaderAuthGuard START ===\n")
|
|
fmt.Printf("[AUTH DEBUG] Method: %s, Path: %s\n", r.Method, r.URL.Path)
|
|
|
|
authHeader := r.Header.Get("Authorization")
|
|
var token string
|
|
|
|
fmt.Printf("[AUTH DEBUG] Authorization Header: '%s'\n", authHeader)
|
|
|
|
if authHeader != "" {
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) == 2 && parts[0] == "Bearer" {
|
|
token = parts[1]
|
|
fmt.Printf("[AUTH DEBUG] Token from Header (first 20 chars): '%s...'\n", token[:min(20, len(token))])
|
|
} else {
|
|
fmt.Printf("[AUTH DEBUG] Invalid header format: %d parts, first part: '%s'\n", len(parts), parts[0])
|
|
}
|
|
}
|
|
|
|
// Fallback to Cookie
|
|
if token == "" {
|
|
cookie, err := r.Cookie("jwt")
|
|
if err == nil {
|
|
token = cookie.Value
|
|
fmt.Printf("[AUTH DEBUG] Token from Cookie (first 20 chars): '%s...'\n", token[:min(20, len(token))])
|
|
} else {
|
|
fmt.Printf("[AUTH DEBUG] No jwt cookie found: %v\n", err)
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
fmt.Printf("[AUTH DEBUG] No token found - returning 401\n")
|
|
http.Error(w, "Missing Authorization Header or Cookie", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
fmt.Printf("[AUTH DEBUG] Validating token...\n")
|
|
claims, err := m.authService.ValidateToken(token)
|
|
if err != nil {
|
|
fmt.Printf("[AUTH DEBUG] Token validation FAILED: %v\n", err)
|
|
http.Error(w, "Invalid Token: "+err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
fmt.Printf("[AUTH DEBUG] Token VALID! Claims: sub=%v, tenant=%v, roles=%v\n", claims["sub"], claims["tenant"], claims["roles"])
|
|
|
|
// Inject into Context
|
|
ctx := context.WithValue(r.Context(), ContextUserID, claims["sub"])
|
|
ctx = context.WithValue(ctx, ContextTenantID, claims["tenant"])
|
|
ctx = context.WithValue(ctx, ContextRoles, claims["roles"])
|
|
|
|
fmt.Printf("[AUTH DEBUG] === HeaderAuthGuard SUCCESS ===\n")
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// OptionalHeaderAuthGuard checks for token but allows request if missing (Context will be empty)
|
|
func (m *Middleware) OptionalHeaderAuthGuard(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
var token string
|
|
|
|
if authHeader != "" {
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) == 2 && parts[0] == "Bearer" {
|
|
token = parts[1]
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
cookie, err := r.Cookie("jwt")
|
|
if err == nil {
|
|
token = cookie.Value
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
// Proceed without context
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
claims, err := m.authService.ValidateToken(token)
|
|
if err != nil {
|
|
http.Error(w, "Invalid Token: "+err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Inject into Context
|
|
ctx := context.WithValue(r.Context(), ContextUserID, claims["sub"])
|
|
ctx = context.WithValue(ctx, ContextTenantID, claims["tenant"])
|
|
ctx = context.WithValue(ctx, ContextRoles, claims["roles"])
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireRoles ensures the authenticated user has at least one of the required roles.
|
|
func (m *Middleware) RequireRoles(roles ...string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Printf("[RBAC DEBUG] === RequireRoles START for %s %s ===\n", r.Method, r.URL.Path)
|
|
fmt.Printf("[RBAC DEBUG] Required roles: %v\n", roles)
|
|
|
|
rawRoles := r.Context().Value(ContextRoles)
|
|
fmt.Printf("[RBAC DEBUG] Raw roles from context: %v (type: %T)\n", rawRoles, rawRoles)
|
|
|
|
roleValues := ExtractRoles(rawRoles)
|
|
fmt.Printf("[RBAC DEBUG] Extracted roles: %v\n", roleValues)
|
|
|
|
if len(roleValues) == 0 {
|
|
fmt.Printf("[RBAC DEBUG] FAILED: No roles found in context\n")
|
|
http.Error(w, "Roles not found", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if hasRole(roleValues, roles) {
|
|
fmt.Printf("[RBAC DEBUG] SUCCESS: User has required role\n")
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
fmt.Printf("[RBAC DEBUG] FAILED: User roles %v do not match required %v\n", roleValues, roles)
|
|
http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
|
|
})
|
|
}
|
|
}
|
|
|
|
// ExtractRoles converts the raw value stored under ContextRoles into role
// names.
//
// Conversion rules:
//   - A []string is returned unchanged; this is the same slice, not a copy.
//   - A []interface{} (presumably a JSON-decoded claim array; confirm against
//     the AuthService implementation) keeps only its string elements, in order.
//   - Any other value, including nil, yields an empty, non-nil slice.
func ExtractRoles(value interface{}) []string {
	if direct, ok := value.([]string); ok {
		return direct
	}

	items, ok := value.([]interface{})
	if !ok {
		return []string{}
	}

	names := make([]string, 0, len(items))
	for i := range items {
		// Non-string entries are skipped silently.
		if name, isString := items[i].(string); isString {
			names = append(names, name)
		}
	}
	return names
}
|
|
|
|
// hasRole reports whether userRoles and allowedRoles have at least one role
// name in common. Both sides are normalized with strings.ToLower before
// comparing. The result is false when either slice is empty.
func hasRole(userRoles []string, allowedRoles []string) bool {
	allowed := make(map[string]bool, len(allowedRoles))
	for _, a := range allowedRoles {
		allowed[strings.ToLower(a)] = true
	}

	for _, u := range userRoles {
		if allowed[strings.ToLower(u)] {
			return true
		}
	}
	return false
}
|
|
|
|
// TenantGuard ensures that the request is made by a user belonging to the prompt tenant
|
|
// Note: In this architecture, the token *defines* the tenant. So HeaderAuthGuard implicitly guards the tenant.
|
|
// This middleware is for extra checks if URL params conflict with Token tenant.
|
|
func (m *Middleware) TenantGuard(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tenantID := r.Context().Value(ContextTenantID)
|
|
if tenantID == nil || tenantID == "" {
|
|
http.Error(w, "Tenant Context Missing", http.StatusForbidden)
|
|
return
|
|
}
|
|
// Logic to compare with URL param if needed...
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|