saveinmed/backend/internal/http/middleware/auth.go
2025-12-18 12:29:51 -03:00

86 lines
2.2 KiB
Go

package middleware
import (
"context"
"errors"
"net/http"
"strings"
"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
)
// contextKey is an unexported string type used for this package's context
// keys, so they cannot collide with keys defined by other packages.
type contextKey string

// claimsKey is the context key under which RequireAuth stores Claims.
const claimsKey = contextKey("authClaims")
// Claims is the authenticated-user context that RequireAuth derives from a
// validated JWT and stores in the request context.
type Claims struct {
	// UserID is parsed from the token's "sub" claim.
	UserID uuid.UUID
	// Role is the token's "role" claim; empty when the claim is absent or
	// not a string.
	Role string
	// CompanyID is parsed from the optional "company_id" claim; nil when the
	// claim is absent, empty, or not a valid UUID.
	CompanyID *uuid.UUID
}
// RequireAuth returns middleware that validates a JWT bearer token and
// optionally enforces allowed roles.
//
// secret is the HMAC key used to verify the token signature; tokens signed
// with any non-HMAC method are rejected. When allowedRoles is non-empty, the
// token's "role" claim must match one of them, ignoring letter case.
//
// Responses written by the middleware itself (all with an empty body):
//   - 401 Unauthorized with a "WWW-Authenticate: Bearer" challenge when the
//     Authorization header does not carry a bearer credential, when the token
//     fails parsing or validation, or when the "sub" claim is not a UUID.
//   - 403 Forbidden when allowedRoles is non-empty and the role is not in it.
//
// On success the extracted Claims are stored in the request context, where
// GetClaims can read them, and the next handler is invoked.
func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler {
	// scheme is the credential prefix expected in the Authorization header.
	const scheme = "bearer "

	// unauthorized rejects the request with 401 plus the Bearer challenge
	// that RFC 7235 requires on every 401 response.
	unauthorized := func(w http.ResponseWriter) {
		w.Header().Set("WWW-Authenticate", "Bearer")
		w.WriteHeader(http.StatusUnauthorized)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			// Fold only the scheme prefix instead of lower-casing the whole
			// client-supplied header value.
			if len(authHeader) < len(scheme) || !strings.EqualFold(authHeader[:len(scheme)], scheme) {
				unauthorized(w)
				return
			}
			tokenStr := strings.TrimSpace(authHeader[len(scheme):])

			claims := jwt.MapClaims{}
			token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (any, error) {
				// Only HMAC-signed tokens may be verified with the shared secret.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, errors.New("unexpected signing method")
				}
				return secret, nil
			})
			if err != nil || !token.Valid {
				unauthorized(w)
				return
			}

			// A missing or non-string "role" claim yields "", which only
			// passes when no roles are enforced or "" is explicitly allowed.
			role, _ := claims["role"].(string)
			if len(allowedRoles) > 0 && !isRoleAllowed(role, allowedRoles) {
				w.WriteHeader(http.StatusForbidden)
				return
			}

			// A missing or non-string "sub" claim yields "", which is left to
			// the UUID parse below to reject.
			sub, _ := claims["sub"].(string)
			userID, err := uuid.FromString(sub)
			if err != nil {
				unauthorized(w)
				return
			}

			// company_id is optional: an absent, empty, or malformed value
			// leaves CompanyID nil rather than rejecting the request.
			var companyID *uuid.UUID
			if cid, ok := claims["company_id"].(string); ok && cid != "" {
				if parsed, err := uuid.FromString(cid); err == nil {
					companyID = &parsed
				}
			}

			ctx := context.WithValue(r.Context(), claimsKey, Claims{UserID: userID, Role: role, CompanyID: companyID})
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// GetClaims returns the Claims stored in ctx under this package's context
// key. The boolean result is false, and the Claims value is zero, when ctx
// holds no Claims under that key.
func GetClaims(ctx context.Context) (Claims, bool) {
	if stored, found := ctx.Value(claimsKey).(Claims); found {
		return stored, true
	}
	return Claims{}, false
}
// isRoleAllowed reports whether role equals any entry of allowed, comparing
// case-insensitively. It returns false when allowed is empty.
func isRoleAllowed(role string, allowed []string) bool {
	for i := 0; i < len(allowed); i++ {
		if !strings.EqualFold(allowed[i], role) {
			continue
		}
		return true
	}
	return false
}