121 lines
3.5 KiB
Go
121 lines
3.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gofrs/uuid/v5"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/saveinmed/backend-go/internal/domain"
|
|
)
|
|
|
|
// contextKey is an unexported key type for values this package stores in a
// request context; because it is a distinct type, its keys cannot collide
// with plain string keys or keys defined by other packages.
type contextKey string

// claimsKey is the context key under which the auth middleware stores Claims.
const claimsKey = contextKey("authClaims")
|
|
|
|
// Claims represents authenticated user context extracted from JWT.
|
|
type Claims struct {
|
|
UserID uuid.UUID
|
|
Role string
|
|
CompanyID *uuid.UUID
|
|
}
|
|
|
|
// RequireAuth validates a JWT bearer token and optionally enforces allowed roles.
|
|
func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, err := parseToken(r, secret)
|
|
if err != nil {
|
|
log.Printf("❌ [RequireAuth] Token parse error: %v", err)
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
log.Printf("🔍 [RequireAuth] User Role: %s, Allowed Roles: %v", claims.Role, allowedRoles)
|
|
|
|
if len(allowedRoles) > 0 && !isRoleAllowed(claims.Role, allowedRoles) {
|
|
log.Printf("❌ [RequireAuth] Role %s not in allowed roles %v", claims.Role, allowedRoles)
|
|
w.WriteHeader(http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
log.Printf("✅ [RequireAuth] Access granted for role: %s", claims.Role)
|
|
|
|
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
|
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// OptionalAuth attempts to validate a JWT token if present, but proceeds without context if missing or invalid.
|
|
func OptionalAuth(secret []byte) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, err := parseToken(r, secret)
|
|
if err == nil && claims != nil {
|
|
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
|
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
} else {
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func parseToken(r *http.Request, secret []byte) (*Claims, error) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
tokenStr := authHeader
|
|
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
|
tokenStr = strings.TrimSpace(authHeader[7:])
|
|
} else if authHeader == "" {
|
|
return nil, errors.New("missing authorization header")
|
|
}
|
|
jwtClaims := jwt.MapClaims{}
|
|
token, err := jwt.ParseWithClaims(tokenStr, jwtClaims, func(token *jwt.Token) (any, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, errors.New("unexpected signing method")
|
|
}
|
|
return secret, nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
return nil, err
|
|
}
|
|
|
|
sub, _ := jwtClaims["sub"].(string)
|
|
userID, err := uuid.FromString(sub)
|
|
if err != nil {
|
|
return nil, errors.New("invalid sub")
|
|
}
|
|
|
|
role, _ := jwtClaims["role"].(string)
|
|
|
|
var companyID *uuid.UUID
|
|
if cid, ok := jwtClaims["company_id"].(string); ok && cid != "" {
|
|
if parsed, err := uuid.FromString(cid); err == nil {
|
|
companyID = &parsed
|
|
}
|
|
}
|
|
|
|
return &Claims{UserID: userID, Role: domain.NormalizeRole(role), CompanyID: companyID}, nil
|
|
}
|
|
|
|
// GetClaims extracts JWT claims from the request context.
|
|
func GetClaims(ctx context.Context) (Claims, bool) {
|
|
claims, ok := ctx.Value(claimsKey).(Claims)
|
|
return claims, ok
|
|
}
|
|
|
|
func isRoleAllowed(role string, allowed []string) bool {
|
|
normalizedRole := domain.NormalizeRole(role)
|
|
for _, r := range allowed {
|
|
if domain.NormalizeRole(r) == normalizedRole {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|