package middleware import ( "context" "errors" "net/http" "strings" "github.com/gofrs/uuid/v5" "github.com/golang-jwt/jwt/v5" ) type contextKey string const claimsKey contextKey = "authClaims" // Claims represents authenticated user context extracted from JWT. type Claims struct { UserID uuid.UUID Role string CompanyID *uuid.UUID } // RequireAuth validates a JWT bearer token and optionally enforces allowed roles. func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, err := parseToken(r, secret) if err != nil { w.WriteHeader(http.StatusUnauthorized) return } if len(allowedRoles) > 0 && !isRoleAllowed(claims.Role, allowedRoles) { w.WriteHeader(http.StatusForbidden) return } ctx := context.WithValue(r.Context(), claimsKey, *claims) ctx = context.WithValue(ctx, "company_id", claims.CompanyID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // OptionalAuth attempts to validate a JWT token if present, but proceeds without context if missing or invalid. func OptionalAuth(secret []byte) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims, err := parseToken(r, secret) if err == nil && claims != nil { ctx := context.WithValue(r.Context(), claimsKey, *claims) ctx = context.WithValue(ctx, "company_id", claims.CompanyID) next.ServeHTTP(w, r.WithContext(ctx)) } else { next.ServeHTTP(w, r) } }) } } func parseToken(r *http.Request, secret []byte) (*Claims, error) { authHeader := r.Header.Get("Authorization") tokenStr := authHeader if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { tokenStr = strings.TrimSpace(authHeader[7:]) } else if authHeader == "" { return nil, errors.New("missing authorization header") } jwtClaims := jwt.MapClaims{} token, err := jwt.ParseWithClaims(tokenStr, jwtClaims, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, errors.New("unexpected signing method") } return secret, nil }) if err != nil || !token.Valid { return nil, err } sub, _ := jwtClaims["sub"].(string) userID, err := uuid.FromString(sub) if err != nil { return nil, errors.New("invalid sub") } role, _ := jwtClaims["role"].(string) var companyID *uuid.UUID if cid, ok := jwtClaims["company_id"].(string); ok && cid != "" { if parsed, err := uuid.FromString(cid); err == nil { companyID = &parsed } } return &Claims{UserID: userID, Role: role, CompanyID: companyID}, nil } // GetClaims extracts JWT claims from the request context. func GetClaims(ctx context.Context) (Claims, bool) { claims, ok := ctx.Value(claimsKey).(Claims) return claims, ok } func isRoleAllowed(role string, allowed []string) bool { for _, r := range allowed { if strings.EqualFold(r, role) { return true } } return false }