Refactor RBAC: Admin sees only their company data, Superadmin sees all

This commit is contained in:
Tiago Yamamoto 2025-12-26 10:05:57 -03:00
parent f9c9293a19
commit fb98016afc
3 changed files with 91 additions and 11 deletions

View file

@ -5,6 +5,7 @@ import (
"net/http"
"strconv"
"github.com/rede5/gohorsejobs/backend/internal/api/middleware"
"github.com/rede5/gohorsejobs/backend/internal/dto"
"github.com/rede5/gohorsejobs/backend/internal/services"
)
@ -160,10 +161,21 @@ func (h *AdminHandlers) UpdateCompanyStatus(w http.ResponseWriter, r *http.Reque
}
func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
status := r.URL.Query().Get("status")
// Extract role and companyID for scoping
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
isSuperadmin := false
for _, role := range roles {
if role == "SUPERADMIN" || role == "superadmin" {
isSuperadmin = true
break
}
}
filter := dto.JobFilterQuery{
PaginationQuery: dto.PaginationQuery{
Page: page,
@ -175,6 +187,14 @@ func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) {
filter.Status = &status
}
// If Admin (not Superadmin), scope to their company
if !isSuperadmin {
companyID, _ := ctx.Value(middleware.ContextTenantID).(string)
if companyID != "" {
filter.CompanyID = &companyID
}
}
jobs, total, err := h.jobService.GetJobs(filter)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -300,7 +320,26 @@ func (h *AdminHandlers) UpdateTag(w http.ResponseWriter, r *http.Request) {
}
func (h *AdminHandlers) ListCandidates(w http.ResponseWriter, r *http.Request) {
candidates, stats, err := h.adminService.ListCandidates(r.Context())
ctx := r.Context()
// Extract role for scoping
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
isSuperadmin := false
for _, role := range roles {
if role == "SUPERADMIN" || role == "superadmin" {
isSuperadmin = true
break
}
}
var companyID *string
if !isSuperadmin {
if cid, ok := ctx.Value(middleware.ContextTenantID).(string); ok && cid != "" {
companyID = &cid
}
}
candidates, stats, err := h.adminService.ListCandidates(ctx, companyID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return

View file

@ -291,14 +291,19 @@ func (h *CoreHandlers) CreateUser(w http.ResponseWriter, r *http.Request) {
func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Check if user is admin/superadmin (they can list all users)
// Check if user is admin/superadmin
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
isAdmin := false
isSuperadmin := false
for _, role := range roles {
if role == "ADMIN" || role == "SUPERADMIN" || role == "admin" || role == "superadmin" {
if role == "SUPERADMIN" || role == "superadmin" {
isSuperadmin = true
isAdmin = true
break
}
if role == "ADMIN" || role == "admin" {
isAdmin = true
}
}
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
@ -313,9 +318,32 @@ func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) {
limit = 100
}
if isSuperadmin {
// Superadmin view: List all users using AdminService
users, total, err := h.adminService.ListUsers(ctx, page, limit, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
response := map[string]interface{}{
"data": users,
"pagination": map[string]interface{}{
"page": page,
"limit": limit,
"total": total,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
return
}
if isAdmin {
// Admin view: List all users using AdminService
users, total, err := h.adminService.ListUsers(ctx, page, limit)
// Admin view: List users from their company only
tenantID, _ := ctx.Value(middleware.ContextTenantID).(string)
users, total, err := h.adminService.ListUsers(ctx, page, limit, &tenantID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return

View file

@ -88,12 +88,19 @@ func (s *AdminService) ListCompanies(ctx context.Context, verified *bool, page,
}
// ListUsers returns all users with pagination (for admin view)
func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.User, int, error) {
// If companyID is provided, filters users by that company.
func (s *AdminService) ListUsers(ctx context.Context, page, limit int, companyID *string) ([]dto.User, int, error) {
offset := (page - 1) * limit
// Count Total
countQuery := `SELECT COUNT(*) FROM users`
var countArgs []interface{}
if companyID != nil && *companyID != "" {
countQuery += ` WHERE company_id = $1`
countArgs = append(countArgs, *companyID)
}
var total int
if err := s.DB.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&total); err != nil {
if err := s.DB.QueryRowContext(ctx, countQuery, countArgs...).Scan(&total); err != nil {
return nil, 0, err
}
@ -101,11 +108,17 @@ func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.Us
query := `
SELECT id, COALESCE(name, full_name, identifier, ''), email, role, COALESCE(status, 'active'), created_at
FROM users
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
var args []interface{}
if companyID != nil && *companyID != "" {
query += ` WHERE company_id = $1`
args = append(args, *companyID)
}
query += ` ORDER BY created_at DESC`
query += fmt.Sprintf(` LIMIT $%d OFFSET $%d`, len(args)+1, len(args)+2)
args = append(args, limit, offset)
rows, err := s.DB.QueryContext(ctx, query, limit, offset)
rows, err := s.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, 0, err
}
@ -310,7 +323,7 @@ func (s *AdminService) UpdateTag(ctx context.Context, id int, name *string, acti
return tag, nil
}
func (s *AdminService) ListCandidates(ctx context.Context) ([]dto.Candidate, dto.CandidateStats, error) {
func (s *AdminService) ListCandidates(ctx context.Context, companyID *string) ([]dto.Candidate, dto.CandidateStats, error) {
fmt.Println("[DEBUG] Starting ListCandidates")
query := `
SELECT id, full_name, email, phone, city, state, title, experience, bio, skills, avatar_url, created_at