From 33399171d76dc1d58eacf6091c5850f29bcb1ba3 Mon Sep 17 00:00:00 2001 From: Tiago Yamamoto Date: Sat, 7 Feb 2026 11:43:31 -0300 Subject: [PATCH] Set company context from auth claims --- backend-old/internal/http/middleware/auth.go | 2 ++ .../internal/http/middleware/middleware_test.go | 13 ++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/backend-old/internal/http/middleware/auth.go b/backend-old/internal/http/middleware/auth.go index 2fc2aae..ad1c61f 100644 --- a/backend-old/internal/http/middleware/auth.go +++ b/backend-old/internal/http/middleware/auth.go @@ -37,6 +37,7 @@ func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http. } ctx := context.WithValue(r.Context(), claimsKey, *claims) + ctx = context.WithValue(ctx, "company_id", claims.CompanyID) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -49,6 +50,7 @@ func OptionalAuth(secret []byte) func(http.Handler) http.Handler { claims, err := parseToken(r, secret) if err == nil && claims != nil { ctx := context.WithValue(r.Context(), claimsKey, *claims) + ctx = context.WithValue(ctx, "company_id", claims.CompanyID) next.ServeHTTP(w, r.WithContext(ctx)) } else { next.ServeHTTP(w, r) diff --git a/backend-old/internal/http/middleware/middleware_test.go b/backend-old/internal/http/middleware/middleware_test.go index 8fb9ea0..3e22c40 100644 --- a/backend-old/internal/http/middleware/middleware_test.go +++ b/backend-old/internal/http/middleware/middleware_test.go @@ -118,8 +118,10 @@ func TestRequireAuthValidToken(t *testing.T) { tokenStr := createTestToken(secret, userID, "Admin", &companyID) var receivedClaims Claims + var receivedCompanyID *uuid.UUID handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { receivedClaims, _ = GetClaims(r.Context()) + receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID) w.WriteHeader(http.StatusOK) })) @@ -138,6 +140,9 @@ func TestRequireAuthValidToken(t *testing.T) { if receivedClaims.Role != "Admin" { t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role) } + if receivedCompanyID == nil || *receivedCompanyID != companyID { + t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID) + } } func TestRequireAuthMissingToken(t *testing.T) { @@ -437,11 +442,14 @@ func TestOptionalAuthMissingToken(t *testing.T) { func TestOptionalAuthValidToken(t *testing.T) { secret := "secret" userID, _ := uuid.NewV7() - tokenStr := createTestToken(secret, userID, "Admin", nil) + companyID, _ := uuid.NewV7() + tokenStr := createTestToken(secret, userID, "Admin", &companyID) var gotClaims Claims + var receivedCompanyID *uuid.UUID handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotClaims, _ = GetClaims(r.Context()) + receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID) w.WriteHeader(http.StatusOK) })) @@ -456,6 +464,9 @@ func TestOptionalAuthValidToken(t *testing.T) { if gotClaims.UserID != userID { t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID) } + if receivedCompanyID == nil || *receivedCompanyID != companyID { + t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID) + } } // --- Security Headers Tests ---