package middleware

import (
	"context"
	"errors"
	"net/http"
	"strings"

	"github.com/gofrs/uuid/v5"
	"github.com/golang-jwt/jwt/v5"
)

// contextKey is an unexported key type so values this package stores in a
// context.Context cannot collide with keys defined by other packages.
type contextKey string

// claimsKey is the context key under which RequireAuth stores Claims.
const claimsKey contextKey = "authClaims"

// Claims represents authenticated user context extracted from JWT.
//
// UserID comes from the "sub" claim, Role from the "role" claim (empty when
// absent or not a string), and CompanyID from the "company_id" claim (nil when
// the claim is absent, null, or the empty string).
type Claims struct {
	UserID    uuid.UUID
	Role      string
	CompanyID *uuid.UUID
}

// RequireAuth validates a JWT bearer token and optionally enforces allowed roles.
//
// The returned middleware responds 401 Unauthorized, with a
// "WWW-Authenticate: Bearer" challenge, in these cases:
//   - the Authorization header does not carry a Bearer token;
//   - the token fails verification against secret using an HMAC signing method;
//   - the token's claims are malformed.
//
// Authentication is checked before authorization. Only a fully valid token can
// receive 403 Forbidden, which is sent when allowedRoles is non-empty and the
// token's role matches none of them (case-insensitively). On success the parsed
// Claims are stored in the request context, where GetClaims can read them.
//
// An empty secret makes every request fail with 401 instead of accepting
// signatures that anyone could forge.
//
// NOTE(review): jwt/v5 only enforces "exp" when the claim is present, so tokens
// without an expiry are accepted. Consider jwt.WithExpirationRequired if every
// issuer sets it.
func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tokenStr, ok := extractBearerToken(r.Header.Get("Authorization"))
			if !ok {
				writeUnauthorized(w)
				return
			}

			// Establish identity fully before looking at the role, so a bad token
			// is always 401, never 403.
			claims, err := parseAuthClaims(tokenStr, secret)
			if err != nil {
				writeUnauthorized(w)
				return
			}

			if len(allowedRoles) > 0 && !isRoleAllowed(claims.Role, allowedRoles) {
				w.WriteHeader(http.StatusForbidden)
				return
			}

			ctx := context.WithValue(r.Context(), claimsKey, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// extractBearerToken returns the credentials that follow a case-insensitive
// "Bearer " scheme prefix in an Authorization header value. It reports false
// when the scheme is missing or the remaining token is blank.
func extractBearerToken(header string) (string, bool) {
	const prefix = "bearer "
	// EqualFold on just the prefix avoids lower-casing (and copying) the whole header.
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", false
	}
	token := strings.TrimSpace(header[len(prefix):])
	return token, token != ""
}

// parseAuthClaims verifies tokenStr as an HMAC-signed JWT under secret and
// converts its claims into Claims. Any non-nil error means the caller must
// treat the request as unauthenticated; the returned Claims is then zero.
func parseAuthClaims(tokenStr string, secret []byte) (Claims, error) {
	// An empty HMAC key lets anyone mint "valid" signatures, so fail closed.
	if len(secret) == 0 {
		return Claims{}, errors.New("empty signing secret")
	}

	mc := jwt.MapClaims{}
	token, err := jwt.ParseWithClaims(tokenStr, mc, func(t *jwt.Token) (any, error) {
		// Only the HMAC family is accepted; this blocks alg-confusion attacks.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("unexpected signing method")
		}
		return secret, nil
	})
	if err != nil {
		return Claims{}, err
	}
	if !token.Valid {
		return Claims{}, errors.New("invalid token")
	}

	sub, _ := mc["sub"].(string) // a missing/non-string sub yields "", rejected below
	userID, err := uuid.FromString(sub)
	if err != nil {
		return Claims{}, err
	}

	role, _ := mc["role"].(string) // a missing/non-string role means "no role"

	// company_id is optional. If it is present it must be a well-formed UUID
	// string; silently dropping a malformed value would make the token look
	// company-less instead of invalid.
	var companyID *uuid.UUID
	if raw, present := mc["company_id"]; present && raw != nil {
		cid, ok := raw.(string)
		if !ok {
			return Claims{}, errors.New("company_id claim is not a string")
		}
		if cid != "" {
			parsed, err := uuid.FromString(cid)
			if err != nil {
				return Claims{}, err
			}
			companyID = &parsed
		}
	}

	return Claims{UserID: userID, Role: role, CompanyID: companyID}, nil
}

// writeUnauthorized sends 401 with the Bearer challenge that RFC 7235 requires
// on every 401 response (scheme per RFC 6750).
func writeUnauthorized(w http.ResponseWriter) {
	w.Header().Set("WWW-Authenticate", "Bearer")
	w.WriteHeader(http.StatusUnauthorized)
}

// GetClaims returns the Claims that RequireAuth stored in ctx and reports
// whether any were present.
func GetClaims(ctx context.Context) (Claims, bool) { claims, ok := ctx.Value(claimsKey).(Claims) return claims, ok } func isRoleAllowed(role string, allowed []string) bool { for _, r := range allowed { if strings.EqualFold(r, role) { return true } } return false }