saveinmed/backend-old/internal/http/middleware/auth.go
2026-02-07 11:43:31 -03:00

112 lines
3 KiB
Go

package middleware
import (
"context"
"errors"
"net/http"
"strings"
"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
)
// contextKey is an unexported type for keys this package stores in a request
// context; a distinct named type keeps them from colliding with keys of the
// same string value defined by other packages.
type contextKey string

// claimsKey is the context key under which the auth middlewares store Claims.
const claimsKey = contextKey("authClaims")
// Claims is the authenticated-user context that parseToken derives from a
// verified JWT and that the auth middlewares place in the request context.
type Claims struct {
	// UserID is parsed from the token's "sub" claim.
	UserID uuid.UUID

	// Role is the token's "role" claim, or "" when it is absent or not a string.
	Role string

	// CompanyID is parsed from the "company_id" claim; it is nil when that
	// claim is missing, empty, or not a valid UUID.
	CompanyID *uuid.UUID
}
// RequireAuth returns middleware that requires a valid JWT in the
// Authorization header (formats accepted by parseToken) and, when
// allowedRoles is non-empty, a "role" claim matching one of them
// case-insensitively.
//
// The middleware writes a bodiless response and stops the chain with:
//   - 401 Unauthorized when the token is missing, malformed, or invalid;
//   - 403 Forbidden when the token is valid but its role is not allowed.
//
// On success it stores the Claims value in the request context (read it back
// with GetClaims), also stores the company ID pointer under the plain string
// key "company_id", and calls next.
func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			claims, err := parseToken(r, secret)
			// Guard against a nil result even without an error: dereferencing
			// claims below would otherwise panic inside the request handler.
			if err != nil || claims == nil {
				w.WriteHeader(http.StatusUnauthorized)
				return
			}
			if len(allowedRoles) > 0 && !isRoleAllowed(claims.Role, allowedRoles) {
				w.WriteHeader(http.StatusForbidden)
				return
			}
			ctx := context.WithValue(r.Context(), claimsKey, *claims)
			// NOTE(review): a plain string key is flagged by staticcheck
			// (SA1029); it is kept because other code presumably reads
			// ctx.Value("company_id") — verify callers before switching to a
			// typed key.
			ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// OptionalAuth returns middleware that attaches authentication context when a
// usable JWT is present but never rejects a request.
//
// When parseToken succeeds, the Claims value (readable via GetClaims) and the
// company ID pointer (under the plain string key "company_id") are stored in
// the request context before next runs. Otherwise next receives the original,
// unmodified request.
func OptionalAuth(secret []byte) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			c, parseErr := parseToken(r, secret)
			if parseErr != nil || c == nil {
				// No usable token: continue as an anonymous request.
				next.ServeHTTP(w, r)
				return
			}
			authCtx := context.WithValue(r.Context(), claimsKey, *c)
			authCtx = context.WithValue(authCtx, "company_id", c.CompanyID)
			next.ServeHTTP(w, r.WithContext(authCtx))
		})
	}
}
// parseToken extracts the JWT from r's Authorization header, verifies it,
// and converts its claims into a Claims value.
//
// The header may be "Bearer <token>" (scheme matched case-insensitively);
// any other non-empty value is treated as the bare token. Only HMAC signing
// methods are accepted, verified with secret; tokens using any other
// algorithm are rejected before the signature is checked.
//
// Claim mapping:
//   - "sub" must be a string holding a valid UUID, otherwise an error is returned;
//   - "role" is copied when it is a string, else left empty;
//   - "company_id" is parsed as a UUID when it is a non-empty string; an
//     unparsable value is ignored and CompanyID stays nil.
//
// A nil *Claims is always paired with a non-nil error.
func parseToken(r *http.Request, secret []byte) (*Claims, error) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return nil, errors.New("missing authorization header")
	}

	// Strip a case-insensitive "Bearer " scheme; otherwise the whole header
	// value is treated as the token.
	const bearerPrefix = "bearer "
	tokenStr := authHeader
	if len(authHeader) >= len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		tokenStr = strings.TrimSpace(authHeader[len(bearerPrefix):])
	}
	if tokenStr == "" {
		return nil, errors.New("empty bearer token")
	}

	jwtClaims := jwt.MapClaims{}
	token, err := jwt.ParseWithClaims(tokenStr, jwtClaims, func(t *jwt.Token) (any, error) {
		// Reject non-HMAC algorithms so an attacker cannot switch the token
		// to a method that would interpret secret differently.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("unexpected signing method")
		}
		return secret, nil
	})
	if err != nil {
		return nil, err
	}
	if !token.Valid {
		// Defensive: an invalid token is expected to surface through err, but
		// never let one through as (nil, nil), which callers would dereference.
		return nil, errors.New("invalid token")
	}

	// A missing or non-string "sub" yields "", which fails UUID parsing below.
	sub, _ := jwtClaims["sub"].(string)
	userID, err := uuid.FromString(sub)
	if err != nil {
		return nil, errors.New("invalid sub")
	}

	role, _ := jwtClaims["role"].(string)

	// company_id is optional and best-effort: an invalid value leaves it nil.
	var companyID *uuid.UUID
	if cid, ok := jwtClaims["company_id"].(string); ok && cid != "" {
		if parsed, parseErr := uuid.FromString(cid); parseErr == nil {
			companyID = &parsed
		}
	}

	return &Claims{UserID: userID, Role: role, CompanyID: companyID}, nil
}
// GetClaims returns the Claims value that RequireAuth or OptionalAuth stored
// in ctx. The boolean is false, and the returned Claims is the zero value,
// when ctx holds no Claims under this package's key (for example, an
// unauthenticated request that passed through OptionalAuth).
func GetClaims(ctx context.Context) (Claims, bool) {
	if c, found := ctx.Value(claimsKey).(Claims); found {
		return c, true
	}
	return Claims{}, false
}
// isRoleAllowed reports whether role equals any entry of allowed under
// case-insensitive comparison (strings.EqualFold). An empty or nil allowed
// list matches nothing.
func isRoleAllowed(role string, allowed []string) bool {
	for i := range allowed {
		if strings.EqualFold(allowed[i], role) {
			return true
		}
	}
	return false
}