From fb98016afc13fb17496b9d84474ce110d4173e7e Mon Sep 17 00:00:00 2001 From: Tiago Yamamoto Date: Fri, 26 Dec 2025 10:05:57 -0300 Subject: [PATCH] Refactor RBAC: Admin sees only their company data, Superadmin sees all --- .../internal/api/handlers/admin_handlers.go | 41 ++++++++++++++++++- .../internal/api/handlers/core_handlers.go | 36 ++++++++++++++-- backend/internal/services/admin_service.go | 25 ++++++++--- 3 files changed, 91 insertions(+), 11 deletions(-) diff --git a/backend/internal/api/handlers/admin_handlers.go b/backend/internal/api/handlers/admin_handlers.go index 207cd17..ea23491 100644 --- a/backend/internal/api/handlers/admin_handlers.go +++ b/backend/internal/api/handlers/admin_handlers.go @@ -5,6 +5,7 @@ import ( "net/http" "strconv" + "github.com/rede5/gohorsejobs/backend/internal/api/middleware" "github.com/rede5/gohorsejobs/backend/internal/dto" "github.com/rede5/gohorsejobs/backend/internal/services" ) @@ -160,10 +161,21 @@ func (h *AdminHandlers) UpdateCompanyStatus(w http.ResponseWriter, r *http.Reque } func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() page, _ := strconv.Atoi(r.URL.Query().Get("page")) limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) status := r.URL.Query().Get("status") + // Extract role and companyID for scoping + roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles)) + isSuperadmin := false + for _, role := range roles { + if role == "SUPERADMIN" || role == "superadmin" { + isSuperadmin = true + break + } + } + filter := dto.JobFilterQuery{ PaginationQuery: dto.PaginationQuery{ Page: page, @@ -175,6 +187,14 @@ func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) { filter.Status = &status } + // If Admin (not Superadmin), scope to their company + if !isSuperadmin { + companyID, _ := ctx.Value(middleware.ContextTenantID).(string) + if companyID != "" { + filter.CompanyID = &companyID + } + } + jobs, total, err := h.jobService.GetJobs(filter) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -300,7 +320,26 @@ func (h *AdminHandlers) UpdateTag(w http.ResponseWriter, r *http.Request) { } func (h *AdminHandlers) ListCandidates(w http.ResponseWriter, r *http.Request) { - candidates, stats, err := h.adminService.ListCandidates(r.Context()) + ctx := r.Context() + + // Extract role for scoping + roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles)) + isSuperadmin := false + for _, role := range roles { + if role == "SUPERADMIN" || role == "superadmin" { + isSuperadmin = true + break + } + } + + var companyID *string + if !isSuperadmin { + if cid, ok := ctx.Value(middleware.ContextTenantID).(string); ok && cid != "" { + companyID = &cid + } + } + + candidates, stats, err := h.adminService.ListCandidates(ctx, companyID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/backend/internal/api/handlers/core_handlers.go b/backend/internal/api/handlers/core_handlers.go index 1e4bfeb..f0657f7 100644 --- a/backend/internal/api/handlers/core_handlers.go +++ b/backend/internal/api/handlers/core_handlers.go @@ -291,14 +291,19 @@ func (h *CoreHandlers) CreateUser(w http.ResponseWriter, r *http.Request) { func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - // Check if user is admin/superadmin (they can list all users) + // Check if user is admin/superadmin roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles)) isAdmin := false + isSuperadmin := false for _, role := range roles { - if role == "ADMIN" || role == "SUPERADMIN" || role == "admin" || role == "superadmin" { + if role == "SUPERADMIN" || role == "superadmin" { + isSuperadmin = true isAdmin = true break } + if role == "ADMIN" || role == "admin" { + isAdmin = true + } } page, _ := strconv.Atoi(r.URL.Query().Get("page")) @@ -313,9 +318,32 @@ func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) { limit = 100 } + if isSuperadmin { + // Superadmin view: List all users using AdminService + users, total, err := h.adminService.ListUsers(ctx, page, limit, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + response := map[string]interface{}{ + "data": users, + "pagination": map[string]interface{}{ + "page": page, + "limit": limit, + "total": total, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + if isAdmin { - // Admin view: List all users using AdminService - users, total, err := h.adminService.ListUsers(ctx, page, limit) + // Admin view: List users from their company only + tenantID, _ := ctx.Value(middleware.ContextTenantID).(string) + users, total, err := h.adminService.ListUsers(ctx, page, limit, &tenantID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/backend/internal/services/admin_service.go b/backend/internal/services/admin_service.go index 4102006..b5eefe7 100644 --- a/backend/internal/services/admin_service.go +++ b/backend/internal/services/admin_service.go @@ -88,12 +88,19 @@ func (s *AdminService) ListCompanies(ctx context.Context, verified *bool, page, } // ListUsers returns all users with pagination (for admin view) -func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.User, int, error) { +// If companyID is provided, filters users by that company. +func (s *AdminService) ListUsers(ctx context.Context, page, limit int, companyID *string) ([]dto.User, int, error) { offset := (page - 1) * limit // Count Total + countQuery := `SELECT COUNT(*) FROM users` + var countArgs []interface{} + if companyID != nil && *companyID != "" { + countQuery += ` WHERE company_id = $1` + countArgs = append(countArgs, *companyID) + } var total int - if err := s.DB.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&total); err != nil { + if err := s.DB.QueryRowContext(ctx, countQuery, countArgs...).Scan(&total); err != nil { return nil, 0, err } @@ -101,11 +108,17 @@ func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.Us query := ` SELECT id, COALESCE(name, full_name, identifier, ''), email, role, COALESCE(status, 'active'), created_at FROM users - ORDER BY created_at DESC - LIMIT $1 OFFSET $2 ` + var args []interface{} + if companyID != nil && *companyID != "" { + query += ` WHERE company_id = $1` + args = append(args, *companyID) + } + query += ` ORDER BY created_at DESC` + query += fmt.Sprintf(` LIMIT $%d OFFSET $%d`, len(args)+1, len(args)+2) + args = append(args, limit, offset) - rows, err := s.DB.QueryContext(ctx, query, limit, offset) + rows, err := s.DB.QueryContext(ctx, query, args...) if err != nil { return nil, 0, err } @@ -310,7 +323,7 @@ func (s *AdminService) UpdateTag(ctx context.Context, id int, name *string, acti return tag, nil } -func (s *AdminService) ListCandidates(ctx context.Context) ([]dto.Candidate, dto.CandidateStats, error) { +func (s *AdminService) ListCandidates(ctx context.Context, companyID *string) ([]dto.Candidate, dto.CandidateStats, error) { fmt.Println("[DEBUG] Starting ListCandidates") query := ` SELECT id, full_name, email, phone, city, state, title, experience, bio, skills, avatar_url, created_at