122 lines
3.3 KiB
Go
122 lines
3.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/rede5/gohorsejobs/backend/internal/core/ports"
|
|
)
|
|
|
|
// contextKey is the unexported key type for values this package stores in a
// request context. A dedicated type (rather than a bare string) keeps these
// keys from colliding with context keys defined by other packages.
type contextKey string

// Context keys under which HeaderAuthGuard stores the validated token claims.
// The stored values are the raw claim values (type interface{}); callers must
// type-assert them.
const (
	// ContextUserID holds the token's "sub" claim.
	ContextUserID contextKey = "userID"
	// ContextTenantID holds the token's "tenant" claim.
	ContextTenantID contextKey = "tenantID"
	// ContextRoles holds the token's "roles" claim.
	ContextRoles contextKey = "roles"
)
|
|
|
|
type Middleware struct {
|
|
authService ports.AuthService
|
|
}
|
|
|
|
func NewMiddleware(authService ports.AuthService) *Middleware {
|
|
return &Middleware{authService: authService}
|
|
}
|
|
|
|
// HeaderAuthGuard authenticates a request from its "Authorization: Bearer <token>"
// header. The token is validated with the configured AuthService. On success,
// the "sub", "tenant" and "roles" claims are stored in the request context
// under ContextUserID, ContextTenantID and ContextRoles, and next is invoked.
// A missing header, a malformed header, or a rejected token answers
// 401 Unauthorized and stops the chain.
func (m *Middleware) HeaderAuthGuard(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Missing Authorization Header", http.StatusUnauthorized)
			return
		}

		token, ok := extractBearerToken(authHeader)
		if !ok {
			http.Error(w, "Invalid Header Format", http.StatusUnauthorized)
			return
		}

		claims, err := m.authService.ValidateToken(token)
		if err != nil {
			// NOTE(review): echoing err.Error() exposes validator details to the
			// client; consider logging it server-side and returning a generic message.
			http.Error(w, "Invalid Token: "+err.Error(), http.StatusUnauthorized)
			return
		}

		// Inject the authenticated identity into the request context.
		ctx := context.WithValue(r.Context(), ContextUserID, claims["sub"])
		ctx = context.WithValue(ctx, ContextTenantID, claims["tenant"])
		ctx = context.WithValue(ctx, ContextRoles, claims["roles"])

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// extractBearerToken returns the credential from an Authorization header value
// of the form "Bearer <token>". The scheme is matched case-insensitively, as
// RFC 7235 requires. Extra spaces before the token are tolerated. An empty
// token, or a token that itself contains whitespace, is rejected.
func extractBearerToken(header string) (string, bool) {
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(token)
	if token == "" || strings.ContainsAny(token, " \t") {
		return "", false
	}
	return token, true
}
|
|
|
|
// RequireRoles returns a middleware that admits a request only if the roles
// found in the request context under ContextRoles (normally set by
// HeaderAuthGuard) include at least one of roles, compared case-insensitively.
// If no roles can be read from the context, it answers 403 "Roles not found".
// If none of them match, it answers 403 "Forbidden: insufficient permissions".
// Calling it with no roles rejects every request.
func (m *Middleware) RequireRoles(roles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		guard := func(w http.ResponseWriter, r *http.Request) {
			userRoles := extractRoles(r.Context().Value(ContextRoles))
			switch {
			case len(userRoles) == 0:
				http.Error(w, "Roles not found", http.StatusForbidden)
			case !hasRole(userRoles, roles):
				http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
			default:
				next.ServeHTTP(w, r)
			}
		}
		return http.HandlerFunc(guard)
	}
}
|
|
|
|
// extractRoles converts a raw roles claim into a string slice.
// A []string is returned as-is, without copying. For a []interface{}, only the
// string elements are kept and the rest are skipped. Any other value, including
// nil, yields an empty (non-nil) slice.
func extractRoles(value interface{}) []string {
	if direct, ok := value.([]string); ok {
		return direct
	}

	generic, ok := value.([]interface{})
	if !ok {
		return []string{}
	}

	out := make([]string, 0, len(generic))
	for i := range generic {
		s, isString := generic[i].(string)
		if !isString {
			continue
		}
		out = append(out, s)
	}
	return out
}
|
|
|
|
// hasRole reports whether any entry of userRoles equals any entry of
// allowedRoles after both sides are lower-cased with strings.ToLower.
// It returns false if either slice is empty.
func hasRole(userRoles []string, allowedRoles []string) bool {
	allowed := make(map[string]struct{}, len(allowedRoles))
	for _, a := range allowedRoles {
		allowed[strings.ToLower(a)] = struct{}{}
	}

	for _, u := range userRoles {
		if _, hit := allowed[strings.ToLower(u)]; hit {
			return true
		}
	}
	return false
}
|
|
|
|
// TenantGuard rejects, with 403 "Tenant Context Missing", any request whose
// context has no tenant value under ContextTenantID (nil or the empty string).
// In this architecture the token defines the tenant, so HeaderAuthGuard
// already scopes the request to it. This guard is the place for extra checks,
// such as comparing URL parameters against the token's tenant, which are not
// implemented yet.
// NOTE(review): only nil and "" are rejected; a non-string tenant claim passes.
func (m *Middleware) TenantGuard(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if tenant := r.Context().Value(ContextTenantID); tenant != nil && tenant != "" {
			// Logic to compare with URL param if needed...
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Tenant Context Missing", http.StatusForbidden)
	}
	return http.HandlerFunc(guard)
}
|