Refactor RBAC: Admin sees only their company data, Superadmin sees all
This commit is contained in:
parent
f9c9293a19
commit
fb98016afc
3 changed files with 91 additions and 11 deletions
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/rede5/gohorsejobs/backend/internal/api/middleware"
|
||||||
"github.com/rede5/gohorsejobs/backend/internal/dto"
|
"github.com/rede5/gohorsejobs/backend/internal/dto"
|
||||||
"github.com/rede5/gohorsejobs/backend/internal/services"
|
"github.com/rede5/gohorsejobs/backend/internal/services"
|
||||||
)
|
)
|
||||||
|
|
@ -160,10 +161,21 @@ func (h *AdminHandlers) UpdateCompanyStatus(w http.ResponseWriter, r *http.Reque
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) {
|
func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||||
status := r.URL.Query().Get("status")
|
status := r.URL.Query().Get("status")
|
||||||
|
|
||||||
|
// Extract role and companyID for scoping
|
||||||
|
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
|
||||||
|
isSuperadmin := false
|
||||||
|
for _, role := range roles {
|
||||||
|
if role == "SUPERADMIN" || role == "superadmin" {
|
||||||
|
isSuperadmin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
filter := dto.JobFilterQuery{
|
filter := dto.JobFilterQuery{
|
||||||
PaginationQuery: dto.PaginationQuery{
|
PaginationQuery: dto.PaginationQuery{
|
||||||
Page: page,
|
Page: page,
|
||||||
|
|
@ -175,6 +187,14 @@ func (h *AdminHandlers) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||||
filter.Status = &status
|
filter.Status = &status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If Admin (not Superadmin), scope to their company
|
||||||
|
if !isSuperadmin {
|
||||||
|
companyID, _ := ctx.Value(middleware.ContextTenantID).(string)
|
||||||
|
if companyID != "" {
|
||||||
|
filter.CompanyID = &companyID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
jobs, total, err := h.jobService.GetJobs(filter)
|
jobs, total, err := h.jobService.GetJobs(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
|
@ -300,7 +320,26 @@ func (h *AdminHandlers) UpdateTag(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AdminHandlers) ListCandidates(w http.ResponseWriter, r *http.Request) {
|
func (h *AdminHandlers) ListCandidates(w http.ResponseWriter, r *http.Request) {
|
||||||
candidates, stats, err := h.adminService.ListCandidates(r.Context())
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// Extract role for scoping
|
||||||
|
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
|
||||||
|
isSuperadmin := false
|
||||||
|
for _, role := range roles {
|
||||||
|
if role == "SUPERADMIN" || role == "superadmin" {
|
||||||
|
isSuperadmin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var companyID *string
|
||||||
|
if !isSuperadmin {
|
||||||
|
if cid, ok := ctx.Value(middleware.ContextTenantID).(string); ok && cid != "" {
|
||||||
|
companyID = &cid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates, stats, err := h.adminService.ListCandidates(ctx, companyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -291,14 +291,19 @@ func (h *CoreHandlers) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) {
|
func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
// Check if user is admin/superadmin (they can list all users)
|
// Check if user is admin/superadmin
|
||||||
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
|
roles := middleware.ExtractRoles(ctx.Value(middleware.ContextRoles))
|
||||||
isAdmin := false
|
isAdmin := false
|
||||||
|
isSuperadmin := false
|
||||||
for _, role := range roles {
|
for _, role := range roles {
|
||||||
if role == "ADMIN" || role == "SUPERADMIN" || role == "admin" || role == "superadmin" {
|
if role == "SUPERADMIN" || role == "superadmin" {
|
||||||
|
isSuperadmin = true
|
||||||
isAdmin = true
|
isAdmin = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if role == "ADMIN" || role == "admin" {
|
||||||
|
isAdmin = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||||
|
|
@ -313,9 +318,32 @@ func (h *CoreHandlers) ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
limit = 100
|
limit = 100
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isSuperadmin {
|
||||||
|
// Superadmin view: List all users using AdminService
|
||||||
|
users, total, err := h.adminService.ListUsers(ctx, page, limit, nil)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"data": users,
|
||||||
|
"pagination": map[string]interface{}{
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if isAdmin {
|
if isAdmin {
|
||||||
// Admin view: List all users using AdminService
|
// Admin view: List users from their company only
|
||||||
users, total, err := h.adminService.ListUsers(ctx, page, limit)
|
tenantID, _ := ctx.Value(middleware.ContextTenantID).(string)
|
||||||
|
users, total, err := h.adminService.ListUsers(ctx, page, limit, &tenantID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -88,12 +88,19 @@ func (s *AdminService) ListCompanies(ctx context.Context, verified *bool, page,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers returns all users with pagination (for admin view)
|
// ListUsers returns all users with pagination (for admin view)
|
||||||
func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.User, int, error) {
|
// If companyID is provided, filters users by that company.
|
||||||
|
func (s *AdminService) ListUsers(ctx context.Context, page, limit int, companyID *string) ([]dto.User, int, error) {
|
||||||
offset := (page - 1) * limit
|
offset := (page - 1) * limit
|
||||||
|
|
||||||
// Count Total
|
// Count Total
|
||||||
|
countQuery := `SELECT COUNT(*) FROM users`
|
||||||
|
var countArgs []interface{}
|
||||||
|
if companyID != nil && *companyID != "" {
|
||||||
|
countQuery += ` WHERE company_id = $1`
|
||||||
|
countArgs = append(countArgs, *companyID)
|
||||||
|
}
|
||||||
var total int
|
var total int
|
||||||
if err := s.DB.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&total); err != nil {
|
if err := s.DB.QueryRowContext(ctx, countQuery, countArgs...).Scan(&total); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,11 +108,17 @@ func (s *AdminService) ListUsers(ctx context.Context, page, limit int) ([]dto.Us
|
||||||
query := `
|
query := `
|
||||||
SELECT id, COALESCE(name, full_name, identifier, ''), email, role, COALESCE(status, 'active'), created_at
|
SELECT id, COALESCE(name, full_name, identifier, ''), email, role, COALESCE(status, 'active'), created_at
|
||||||
FROM users
|
FROM users
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT $1 OFFSET $2
|
|
||||||
`
|
`
|
||||||
|
var args []interface{}
|
||||||
|
if companyID != nil && *companyID != "" {
|
||||||
|
query += ` WHERE company_id = $1`
|
||||||
|
args = append(args, *companyID)
|
||||||
|
}
|
||||||
|
query += ` ORDER BY created_at DESC`
|
||||||
|
query += fmt.Sprintf(` LIMIT $%d OFFSET $%d`, len(args)+1, len(args)+2)
|
||||||
|
args = append(args, limit, offset)
|
||||||
|
|
||||||
rows, err := s.DB.QueryContext(ctx, query, limit, offset)
|
rows, err := s.DB.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
@ -310,7 +323,7 @@ func (s *AdminService) UpdateTag(ctx context.Context, id int, name *string, acti
|
||||||
return tag, nil
|
return tag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AdminService) ListCandidates(ctx context.Context) ([]dto.Candidate, dto.CandidateStats, error) {
|
func (s *AdminService) ListCandidates(ctx context.Context, companyID *string) ([]dto.Candidate, dto.CandidateStats, error) {
|
||||||
fmt.Println("[DEBUG] Starting ListCandidates")
|
fmt.Println("[DEBUG] Starting ListCandidates")
|
||||||
query := `
|
query := `
|
||||||
SELECT id, full_name, email, phone, city, state, title, experience, bio, skills, avatar_url, created_at
|
SELECT id, full_name, email, phone, city, state, title, experience, bio, skills, avatar_url, created_at
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue