Set company context from auth claims
This commit is contained in:
parent
fea6dde7ef
commit
33399171d7
2 changed files with 14 additions and 1 deletions
|
|
@ -37,6 +37,7 @@ func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
||||||
|
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -49,6 +50,7 @@ func OptionalAuth(secret []byte) func(http.Handler) http.Handler {
|
||||||
claims, err := parseToken(r, secret)
|
claims, err := parseToken(r, secret)
|
||||||
if err == nil && claims != nil {
|
if err == nil && claims != nil {
|
||||||
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
ctx := context.WithValue(r.Context(), claimsKey, *claims)
|
||||||
|
ctx = context.WithValue(ctx, "company_id", claims.CompanyID)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
} else {
|
} else {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
|
|
||||||
|
|
@ -118,8 +118,10 @@ func TestRequireAuthValidToken(t *testing.T) {
|
||||||
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
||||||
|
|
||||||
var receivedClaims Claims
|
var receivedClaims Claims
|
||||||
|
var receivedCompanyID *uuid.UUID
|
||||||
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
receivedClaims, _ = GetClaims(r.Context())
|
receivedClaims, _ = GetClaims(r.Context())
|
||||||
|
receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
@ -138,6 +140,9 @@ func TestRequireAuthValidToken(t *testing.T) {
|
||||||
if receivedClaims.Role != "Admin" {
|
if receivedClaims.Role != "Admin" {
|
||||||
t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role)
|
t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role)
|
||||||
}
|
}
|
||||||
|
if receivedCompanyID == nil || *receivedCompanyID != companyID {
|
||||||
|
t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequireAuthMissingToken(t *testing.T) {
|
func TestRequireAuthMissingToken(t *testing.T) {
|
||||||
|
|
@ -437,11 +442,14 @@ func TestOptionalAuthMissingToken(t *testing.T) {
|
||||||
func TestOptionalAuthValidToken(t *testing.T) {
|
func TestOptionalAuthValidToken(t *testing.T) {
|
||||||
secret := "secret"
|
secret := "secret"
|
||||||
userID, _ := uuid.NewV7()
|
userID, _ := uuid.NewV7()
|
||||||
tokenStr := createTestToken(secret, userID, "Admin", nil)
|
companyID, _ := uuid.NewV7()
|
||||||
|
tokenStr := createTestToken(secret, userID, "Admin", &companyID)
|
||||||
|
|
||||||
var gotClaims Claims
|
var gotClaims Claims
|
||||||
|
var receivedCompanyID *uuid.UUID
|
||||||
handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := OptionalAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotClaims, _ = GetClaims(r.Context())
|
gotClaims, _ = GetClaims(r.Context())
|
||||||
|
receivedCompanyID, _ = r.Context().Value("company_id").(*uuid.UUID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
@ -456,6 +464,9 @@ func TestOptionalAuthValidToken(t *testing.T) {
|
||||||
if gotClaims.UserID != userID {
|
if gotClaims.UserID != userID {
|
||||||
t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID)
|
t.Errorf("expected userID %s, got %s", userID, gotClaims.UserID)
|
||||||
}
|
}
|
||||||
|
if receivedCompanyID == nil || *receivedCompanyID != companyID {
|
||||||
|
t.Errorf("expected companyID %s, got %v", companyID, receivedCompanyID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Security Headers Tests ---
|
// --- Security Headers Tests ---
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue