Merge pull request #79 from rede5/codex/update-context-middleware-for-pharmacy-tenant
Inject company_id into request context from auth middleware
This commit is contained in:
commit
11bff5af03
2 changed files with 14 additions and 1 deletions
|
|
@ -37,6 +37,7 @@ func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.
|
|||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
||||
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
|
@ -49,6 +50,7 @@ func OptionalAuth(secret []byte) func(http.Handler) http.Handler {
|
|||
claims, err := parseToken(r, secret)
|
||||
if err == nil && claims != nil {
|
||||
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
||||
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
|
|
|
|||
|
|
@ -118,8 +118,10 @@ func TestRequireAuthValidToken(t *testing.T) {
|
|||
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
||||
|
||||
var receivedClaims Claims
|
||||
var receivedCompanyID *uuid.UUID
|
||||
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedClaims, _ = GetClaims(r.Context())
|
||||
receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
|
|
@ -138,6 +140,9 @@ func TestRequireAuthValidToken(t *testing.T) {
|
|||
if receivedClaims.Role != "Admin" {
|
||||
t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role)
|
||||
}
|
||||
if receivedCompanyID == nil || *receivedCompanyID != companyID {
|
||||
t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthMissingToken(t *testing.T) {
|
||||
|
|
@ -437,11 +442,14 @@ func TestOptionalAuthMissingToken(t *testing.T) {
|
|||
func TestOptionalAuthValidToken(t *testing.T) {
|
||||
secret := "secret"
|
||||
userID, _ := uuid.NewV7()
|
||||
tokenStr := createTestToken(secret, userID, "Admin", nil)
|
||||
companyID, _ := uuid.NewV7()
|
||||
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
||||
|
||||
var gotClaims Claims
|
||||
var receivedCompanyID *uuid.UUID
|
||||
handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotClaims, _ = GetClaims(r.Context())
|
||||
receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
|
|
@ -456,6 +464,9 @@ func TestOptionalAuthValidToken(t *testing.T) {
|
|||
if gotClaims.UserID != userID {
|
||||
t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID)
|
||||
}
|
||||
if receivedCompanyID == nil || *receivedCompanyID != companyID {
|
||||
t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Security Headers Tests ---
|
||||
|
|
|
|||
Loading…
Reference in a new issue