Implement JWT auth and company verification

This commit is contained in:
Tiago Yamamoto 2025-12-18 12:29:51 -03:00
parent b4fd89f4a8
commit e57445847b
9 changed files with 381 additions and 33 deletions

View file

@ -4,6 +4,7 @@ go 1.24.3
require ( require (
github.com/gofrs/uuid/v5 v5.4.0 github.com/gofrs/uuid/v5 v5.4.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/jackc/pgx/v5 v5.7.6 github.com/jackc/pgx/v5 v5.7.6
github.com/jmoiron/sqlx v1.4.0 github.com/jmoiron/sqlx v1.4.0
github.com/json-iterator/go v1.1.12 github.com/json-iterator/go v1.1.12

View file

@ -20,6 +20,8 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/gofrs/uuid/v5 v5.4.0 h1:EfbpCTjqMuGyq5ZJwxqzn3Cbr2d0rUZU7v5ycAk/e/0= github.com/gofrs/uuid/v5 v5.4.0 h1:EfbpCTjqMuGyq5ZJwxqzn3Cbr2d0rUZU7v5ycAk/e/0=
github.com/gofrs/uuid/v5 v5.4.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/gofrs/uuid/v5 v5.4.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=

View file

@ -17,6 +17,8 @@ type Config struct {
ConnMaxIdle time.Duration ConnMaxIdle time.Duration
MercadoPagoBaseURL string MercadoPagoBaseURL string
MarketplaceCommission float64 MarketplaceCommission float64
JWTSecret string
JWTExpiresIn time.Duration
} }
// Load reads configuration from environment variables and applies sane defaults // Load reads configuration from environment variables and applies sane defaults
@ -31,6 +33,8 @@ func Load() Config {
ConnMaxIdle: getEnvDuration("DB_CONN_MAX_IDLE", 5*time.Minute), ConnMaxIdle: getEnvDuration("DB_CONN_MAX_IDLE", 5*time.Minute),
MercadoPagoBaseURL: getEnv("MERCADOPAGO_BASE_URL", "https://api.mercadopago.com"), MercadoPagoBaseURL: getEnv("MERCADOPAGO_BASE_URL", "https://api.mercadopago.com"),
MarketplaceCommission: getEnvFloat("MARKETPLACE_COMMISSION", 2.5), MarketplaceCommission: getEnvFloat("MARKETPLACE_COMMISSION", 2.5),
JWTSecret: getEnv("JWT_SECRET", "dev-secret"),
JWTExpiresIn: getEnvDuration("JWT_EXPIRES_IN", 24*time.Hour),
} }
return cfg return cfg

View file

@ -8,13 +8,14 @@ import (
// Company represents a B2B actor in the marketplace. // Company represents a B2B actor in the marketplace.
type Company struct { type Company struct {
ID uuid.UUID `db:"id" json:"id"` ID uuid.UUID `db:"id" json:"id"`
Role string `db:"role" json:"role"` // pharmacy, distributor, admin Role string `db:"role" json:"role"` // pharmacy, distributor, admin
CNPJ string `db:"cnpj" json:"cnpj"` CNPJ string `db:"cnpj" json:"cnpj"`
CorporateName string `db:"corporate_name" json:"corporate_name"` CorporateName string `db:"corporate_name" json:"corporate_name"`
SanitaryLicense string `db:"sanitary_license" json:"sanitary_license"` LicenseNumber string `db:"license_number" json:"license_number"`
CreatedAt time.Time `db:"created_at" json:"created_at"` IsVerified bool `db:"is_verified" json:"is_verified"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"` CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
} }
// User represents an authenticated actor inside a company. // User represents an authenticated actor inside a company.

View file

@ -13,6 +13,7 @@ import (
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
"github.com/saveinmed/backend-go/internal/domain" "github.com/saveinmed/backend-go/internal/domain"
"github.com/saveinmed/backend-go/internal/http/middleware"
"github.com/saveinmed/backend-go/internal/usecase" "github.com/saveinmed/backend-go/internal/usecase"
) )
@ -26,6 +27,68 @@ func New(svc *usecase.Service) *Handler {
return &Handler{svc: svc} return &Handler{svc: svc}
} }
// Register handles sign-up creating a company when requested.
func (h *Handler) Register(w http.ResponseWriter, r *http.Request) {
var req registerAuthRequest
if err := decodeJSON(r.Context(), r, &req); err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
var company *domain.Company
if req.Company != nil {
company = &domain.Company{
ID: req.Company.ID,
Role: req.Company.Role,
CNPJ: req.Company.CNPJ,
CorporateName: req.Company.CorporateName,
LicenseNumber: req.Company.LicenseNumber,
}
}
user := &domain.User{
CompanyID: req.CompanyID,
Role: req.Role,
Name: req.Name,
Email: req.Email,
}
if user.CompanyID == uuid.Nil && company == nil {
writeError(w, http.StatusBadRequest, errors.New("company_id or company payload is required"))
return
}
if err := h.svc.RegisterAccount(r.Context(), company, user, req.Password); err != nil {
writeError(w, http.StatusInternalServerError, err)
return
}
token, exp, err := h.svc.Authenticate(r.Context(), user.Email, req.Password)
if err != nil {
writeError(w, http.StatusInternalServerError, err)
return
}
writeJSON(w, http.StatusCreated, authResponse{Token: token, ExpiresAt: exp})
}
// Login validates credentials and emits a JWT token.
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := decodeJSON(r.Context(), r, &req); err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
token, exp, err := h.svc.Authenticate(r.Context(), req.Email, req.Password)
if err != nil {
writeError(w, http.StatusUnauthorized, err)
return
}
writeJSON(w, http.StatusOK, authResponse{Token: token, ExpiresAt: exp})
}
// CreateCompany godoc // CreateCompany godoc
// @Summary Registro de empresas // @Summary Registro de empresas
// @Description Cadastra farmácia, distribuidora ou administrador com CNPJ e licença sanitária. // @Description Cadastra farmácia, distribuidora ou administrador com CNPJ e licença sanitária.
@ -43,10 +106,10 @@ func (h *Handler) CreateCompany(w http.ResponseWriter, r *http.Request) {
} }
company := &domain.Company{ company := &domain.Company{
Role: req.Role, Role: req.Role,
CNPJ: req.CNPJ, CNPJ: req.CNPJ,
CorporateName: req.CorporateName, CorporateName: req.CorporateName,
SanitaryLicense: req.SanitaryLicense, LicenseNumber: req.LicenseNumber,
} }
if err := h.svc.RegisterCompany(r.Context(), company); err != nil { if err := h.svc.RegisterCompany(r.Context(), company); err != nil {
@ -72,6 +135,45 @@ func (h *Handler) ListCompanies(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, companies) writeJSON(w, http.StatusOK, companies)
} }
// VerifyCompany toggles the verification flag for a company (admin only).
func (h *Handler) VerifyCompany(w http.ResponseWriter, r *http.Request) {
if !strings.HasSuffix(r.URL.Path, "/verify") {
http.NotFound(w, r)
return
}
id, err := parseUUIDFromPath(r.URL.Path)
if err != nil {
writeError(w, http.StatusBadRequest, err)
return
}
company, err := h.svc.VerifyCompany(r.Context(), id)
if err != nil {
writeError(w, http.StatusInternalServerError, err)
return
}
writeJSON(w, http.StatusOK, company)
}
// GetMyCompany returns the company linked to the authenticated user.
func (h *Handler) GetMyCompany(w http.ResponseWriter, r *http.Request) {
claims, ok := middleware.GetClaims(r.Context())
if !ok || claims.CompanyID == nil {
writeError(w, http.StatusBadRequest, errors.New("missing company context"))
return
}
company, err := h.svc.GetCompany(r.Context(), *claims.CompanyID)
if err != nil {
writeError(w, http.StatusNotFound, err)
return
}
writeJSON(w, http.StatusOK, company)
}
// CreateProduct godoc // CreateProduct godoc
// @Summary Cadastro de produto com rastreabilidade de lote // @Summary Cadastro de produto com rastreabilidade de lote
// @Tags Produtos // @Tags Produtos
@ -450,6 +552,33 @@ type createUserRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
type registerAuthRequest struct {
CompanyID *uuid.UUID `json:"company_id,omitempty"`
Company *registerCompanyTarget `json:"company,omitempty"`
Role string `json:"role"`
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
type registerCompanyTarget struct {
ID uuid.UUID `json:"id,omitempty"`
Role string `json:"role"`
CNPJ string `json:"cnpj"`
CorporateName string `json:"corporate_name"`
LicenseNumber string `json:"license_number"`
}
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type authResponse struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
}
type updateUserRequest struct { type updateUserRequest struct {
CompanyID *uuid.UUID `json:"company_id,omitempty"` CompanyID *uuid.UUID `json:"company_id,omitempty"`
Role *string `json:"role,omitempty"` Role *string `json:"role,omitempty"`
@ -483,6 +612,9 @@ func parsePagination(r *http.Request) (int, int) {
} }
func getRequester(r *http.Request) (requester, error) { func getRequester(r *http.Request) (requester, error) {
if claims, ok := middleware.GetClaims(r.Context()); ok {
return requester{Role: claims.Role, CompanyID: claims.CompanyID}, nil
}
role := r.Header.Get("X-User-Role") role := r.Header.Get("X-User-Role")
if role == "" { if role == "" {
role = "Admin" role = "Admin"
@ -501,10 +633,10 @@ func getRequester(r *http.Request) (requester, error) {
} }
type registerCompanyRequest struct { type registerCompanyRequest struct {
Role string `json:"role"` Role string `json:"role"`
CNPJ string `json:"cnpj"` CNPJ string `json:"cnpj"`
CorporateName string `json:"corporate_name"` CorporateName string `json:"corporate_name"`
SanitaryLicense string `json:"sanitary_license"` LicenseNumber string `json:"license_number"`
} }
type registerProductRequest struct { type registerProductRequest struct {

View file

@ -0,0 +1,86 @@
package middleware
import (
"context"
"errors"
"net/http"
"strings"
"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
)
type contextKey string
const claimsKey contextKey = "authClaims"
// Claims represents authenticated user context extracted from JWT.
type Claims struct {
UserID uuid.UUID
Role string
CompanyID *uuid.UUID
}
// RequireAuth validates a JWT bearer token and optionally enforces allowed roles.
func RequireAuth(secret []byte, allowedRoles ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
tokenStr := strings.TrimSpace(authHeader[7:])
claims := jwt.MapClaims{}
token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("unexpected signing method")
}
return secret, nil
})
if err != nil || !token.Valid {
w.WriteHeader(http.StatusUnauthorized)
return
}
role, _ := claims["role"].(string)
if len(allowedRoles) > 0 && !isRoleAllowed(role, allowedRoles) {
w.WriteHeader(http.StatusForbidden)
return
}
sub, _ := claims["sub"].(string)
userID, err := uuid.FromString(sub)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
var companyID *uuid.UUID
if cid, ok := claims["company_id"].(string); ok && cid != "" {
if parsed, err := uuid.FromString(cid); err == nil {
companyID = &parsed
}
}
ctx := context.WithValue(r.Context(), claimsKey, Claims{UserID: userID, Role: role, CompanyID: companyID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// GetClaims extracts JWT claims from the request context.
func GetClaims(ctx context.Context) (Claims, bool) {
claims, ok := ctx.Value(claimsKey).(Claims)
return claims, ok
}
func isRoleAllowed(role string, allowed []string) bool {
for _, r := range allowed {
if strings.EqualFold(r, role) {
return true
}
}
return false
}

View file

@ -28,8 +28,8 @@ func (r *Repository) CreateCompany(ctx context.Context, company *domain.Company)
company.CreatedAt = now company.CreatedAt = now
company.UpdatedAt = now company.UpdatedAt = now
query := `INSERT INTO companies (id, role, cnpj, corporate_name, sanitary_license, created_at, updated_at) query := `INSERT INTO companies (id, role, cnpj, corporate_name, license_number, is_verified, created_at, updated_at)
VALUES (:id, :role, :cnpj, :corporate_name, :sanitary_license, :created_at, :updated_at)` VALUES (:id, :role, :cnpj, :corporate_name, :license_number, :is_verified, :created_at, :updated_at)`
_, err := r.db.NamedExecContext(ctx, query, company) _, err := r.db.NamedExecContext(ctx, query, company)
return err return err
@ -37,13 +37,43 @@ VALUES (:id, :role, :cnpj, :corporate_name, :sanitary_license, :created_at, :upd
func (r *Repository) ListCompanies(ctx context.Context) ([]domain.Company, error) { func (r *Repository) ListCompanies(ctx context.Context) ([]domain.Company, error) {
var companies []domain.Company var companies []domain.Company
query := `SELECT id, role, cnpj, corporate_name, sanitary_license, created_at, updated_at FROM companies ORDER BY created_at DESC` query := `SELECT id, role, cnpj, corporate_name, license_number, is_verified, created_at, updated_at FROM companies ORDER BY created_at DESC`
if err := r.db.SelectContext(ctx, &companies, query); err != nil { if err := r.db.SelectContext(ctx, &companies, query); err != nil {
return nil, err return nil, err
} }
return companies, nil return companies, nil
} }
func (r *Repository) GetCompany(ctx context.Context, id uuid.UUID) (*domain.Company, error) {
var company domain.Company
query := `SELECT id, role, cnpj, corporate_name, license_number, is_verified, created_at, updated_at FROM companies WHERE id = $1`
if err := r.db.GetContext(ctx, &company, query, id); err != nil {
return nil, err
}
return &company, nil
}
func (r *Repository) UpdateCompany(ctx context.Context, company *domain.Company) error {
company.UpdatedAt = time.Now().UTC()
query := `UPDATE companies
SET role = :role, cnpj = :cnpj, corporate_name = :corporate_name, license_number = :license_number, is_verified = :is_verified, updated_at = :updated_at
WHERE id = :id`
res, err := r.db.NamedExecContext(ctx, query, company)
if err != nil {
return err
}
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("company not found")
}
return nil
}
func (r *Repository) CreateProduct(ctx context.Context, product *domain.Product) error { func (r *Repository) CreateProduct(ctx context.Context, product *domain.Product) error {
now := time.Now().UTC() now := time.Now().UTC()
product.CreatedAt = now product.CreatedAt = now
@ -181,6 +211,15 @@ func (r *Repository) GetUser(ctx context.Context, id uuid.UUID) (*domain.User, e
return &user, nil return &user, nil
} }
func (r *Repository) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
var user domain.User
query := `SELECT id, company_id, role, name, email, password_hash, created_at, updated_at FROM users WHERE email = $1`
if err := r.db.GetContext(ctx, &user, query, email); err != nil {
return nil, err
}
return &user, nil
}
func (r *Repository) UpdateUser(ctx context.Context, user *domain.User) error { func (r *Repository) UpdateUser(ctx context.Context, user *domain.User) error {
user.UpdatedAt = time.Now().UTC() user.UpdatedAt = time.Now().UTC()
@ -225,7 +264,8 @@ CREATE TABLE IF NOT EXISTS companies (
role TEXT NOT NULL, role TEXT NOT NULL,
cnpj TEXT NOT NULL UNIQUE, cnpj TEXT NOT NULL UNIQUE,
corporate_name TEXT NOT NULL, corporate_name TEXT NOT NULL,
sanitary_license TEXT NOT NULL, license_number TEXT NOT NULL,
is_verified BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL, created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL updated_at TIMESTAMPTZ NOT NULL
); );

View file

@ -37,7 +37,7 @@ func New(cfg config.Config) (*Server, error) {
repo := postgres.New(db) repo := postgres.New(db)
gateway := payments.NewMercadoPagoGateway(cfg.MercadoPagoBaseURL, cfg.MarketplaceCommission) gateway := payments.NewMercadoPagoGateway(cfg.MercadoPagoBaseURL, cfg.MarketplaceCommission)
svc := usecase.NewService(repo, gateway) svc := usecase.NewService(repo, gateway, cfg.JWTSecret, cfg.JWTExpiresIn)
h := handler.New(svc) h := handler.New(svc)
mux := http.NewServeMux() mux := http.NewServeMux()
@ -59,22 +59,30 @@ func New(cfg config.Config) (*Server, error) {
_, _ = w.Write([]byte("ok")) _, _ = w.Write([]byte("ok"))
}) })
auth := middleware.RequireAuth([]byte(cfg.JWTSecret))
adminOnly := middleware.RequireAuth([]byte(cfg.JWTSecret), "Admin")
mux.Handle("POST /api/companies", chain(http.HandlerFunc(h.CreateCompany), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/companies", chain(http.HandlerFunc(h.CreateCompany), middleware.Logger, middleware.Gzip))
mux.Handle("GET /api/companies", chain(http.HandlerFunc(h.ListCompanies), middleware.Logger, middleware.Gzip)) mux.Handle("GET /api/companies", chain(http.HandlerFunc(h.ListCompanies), middleware.Logger, middleware.Gzip))
mux.Handle("PATCH /api/v1/companies/", chain(http.HandlerFunc(h.VerifyCompany), middleware.Logger, middleware.Gzip, adminOnly))
mux.Handle("GET /api/v1/companies/me", chain(http.HandlerFunc(h.GetMyCompany), middleware.Logger, middleware.Gzip, auth))
mux.Handle("POST /api/products", chain(http.HandlerFunc(h.CreateProduct), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/products", chain(http.HandlerFunc(h.CreateProduct), middleware.Logger, middleware.Gzip))
mux.Handle("GET /api/products", chain(http.HandlerFunc(h.ListProducts), middleware.Logger, middleware.Gzip)) mux.Handle("GET /api/products", chain(http.HandlerFunc(h.ListProducts), middleware.Logger, middleware.Gzip))
mux.Handle("POST /api/orders", chain(http.HandlerFunc(h.CreateOrder), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/orders", chain(http.HandlerFunc(h.CreateOrder), middleware.Logger, middleware.Gzip, auth))
mux.Handle("GET /api/orders/", chain(http.HandlerFunc(h.GetOrder), middleware.Logger, middleware.Gzip)) mux.Handle("GET /api/orders/", chain(http.HandlerFunc(h.GetOrder), middleware.Logger, middleware.Gzip, auth))
mux.Handle("PATCH /api/orders/", chain(http.HandlerFunc(h.UpdateOrderStatus), middleware.Logger, middleware.Gzip)) mux.Handle("PATCH /api/orders/", chain(http.HandlerFunc(h.UpdateOrderStatus), middleware.Logger, middleware.Gzip, auth))
mux.Handle("POST /api/orders/", chain(http.HandlerFunc(h.CreatePaymentPreference), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/orders/", chain(http.HandlerFunc(h.CreatePaymentPreference), middleware.Logger, middleware.Gzip, auth))
mux.Handle("POST /api/v1/users", chain(http.HandlerFunc(h.CreateUser), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/v1/auth/register", chain(http.HandlerFunc(h.Register), middleware.Logger, middleware.Gzip))
mux.Handle("GET /api/v1/users", chain(http.HandlerFunc(h.ListUsers), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/v1/auth/login", chain(http.HandlerFunc(h.Login), middleware.Logger, middleware.Gzip))
mux.Handle("GET /api/v1/users/", chain(http.HandlerFunc(h.GetUser), middleware.Logger, middleware.Gzip))
mux.Handle("PUT /api/v1/users/", chain(http.HandlerFunc(h.UpdateUser), middleware.Logger, middleware.Gzip)) mux.Handle("POST /api/v1/users", chain(http.HandlerFunc(h.CreateUser), middleware.Logger, middleware.Gzip, auth))
mux.Handle("DELETE /api/v1/users/", chain(http.HandlerFunc(h.DeleteUser), middleware.Logger, middleware.Gzip)) mux.Handle("GET /api/v1/users", chain(http.HandlerFunc(h.ListUsers), middleware.Logger, middleware.Gzip, auth))
mux.Handle("GET /api/v1/users/", chain(http.HandlerFunc(h.GetUser), middleware.Logger, middleware.Gzip, auth))
mux.Handle("PUT /api/v1/users/", chain(http.HandlerFunc(h.UpdateUser), middleware.Logger, middleware.Gzip, auth))
mux.Handle("DELETE /api/v1/users/", chain(http.HandlerFunc(h.DeleteUser), middleware.Logger, middleware.Gzip, auth))
mux.Handle("GET /swagger/", httpSwagger.Handler(httpSwagger.URL("/swagger/doc.json"))) mux.Handle("GET /swagger/", httpSwagger.Handler(httpSwagger.URL("/swagger/doc.json")))

View file

@ -2,7 +2,10 @@ package usecase
import ( import (
"context" "context"
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
@ -14,6 +17,8 @@ import (
type Repository interface { type Repository interface {
CreateCompany(ctx context.Context, company *domain.Company) error CreateCompany(ctx context.Context, company *domain.Company) error
ListCompanies(ctx context.Context) ([]domain.Company, error) ListCompanies(ctx context.Context) ([]domain.Company, error)
GetCompany(ctx context.Context, id uuid.UUID) (*domain.Company, error)
UpdateCompany(ctx context.Context, company *domain.Company) error
CreateProduct(ctx context.Context, product *domain.Product) error CreateProduct(ctx context.Context, product *domain.Product) error
ListProducts(ctx context.Context) ([]domain.Product, error) ListProducts(ctx context.Context) ([]domain.Product, error)
@ -25,6 +30,7 @@ type Repository interface {
CreateUser(ctx context.Context, user *domain.User) error CreateUser(ctx context.Context, user *domain.User) error
ListUsers(ctx context.Context, filter domain.UserFilter) ([]domain.User, int64, error) ListUsers(ctx context.Context, filter domain.UserFilter) ([]domain.User, int64, error)
GetUser(ctx context.Context, id uuid.UUID) (*domain.User, error) GetUser(ctx context.Context, id uuid.UUID) (*domain.User, error)
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
UpdateUser(ctx context.Context, user *domain.User) error UpdateUser(ctx context.Context, user *domain.User) error
DeleteUser(ctx context.Context, id uuid.UUID) error DeleteUser(ctx context.Context, id uuid.UUID) error
} }
@ -35,13 +41,15 @@ type PaymentGateway interface {
} }
type Service struct { type Service struct {
repo Repository repo Repository
pay PaymentGateway pay PaymentGateway
jwtSecret []byte
tokenTTL time.Duration
} }
// NewService wires use cases together. // NewService wires use cases together.
func NewService(repo Repository, pay PaymentGateway) *Service { func NewService(repo Repository, pay PaymentGateway, jwtSecret string, tokenTTL time.Duration) *Service {
return &Service{repo: repo, pay: pay} return &Service{repo: repo, pay: pay, jwtSecret: []byte(jwtSecret), tokenTTL: tokenTTL}
} }
func (s *Service) RegisterCompany(ctx context.Context, company *domain.Company) error { func (s *Service) RegisterCompany(ctx context.Context, company *domain.Company) error {
@ -53,6 +61,10 @@ func (s *Service) ListCompanies(ctx context.Context) ([]domain.Company, error) {
return s.repo.ListCompanies(ctx) return s.repo.ListCompanies(ctx)
} }
func (s *Service) GetCompany(ctx context.Context, id uuid.UUID) (*domain.Company, error) {
return s.repo.GetCompany(ctx, id)
}
func (s *Service) RegisterProduct(ctx context.Context, product *domain.Product) error { func (s *Service) RegisterProduct(ctx context.Context, product *domain.Product) error {
product.ID = uuid.Must(uuid.NewV7()) product.ID = uuid.Must(uuid.NewV7())
return s.repo.CreateProduct(ctx, product) return s.repo.CreateProduct(ctx, product)
@ -134,3 +146,65 @@ func (s *Service) UpdateUser(ctx context.Context, user *domain.User, newPassword
func (s *Service) DeleteUser(ctx context.Context, id uuid.UUID) error { func (s *Service) DeleteUser(ctx context.Context, id uuid.UUID) error {
return s.repo.DeleteUser(ctx, id) return s.repo.DeleteUser(ctx, id)
} }
// RegisterAccount creates a company when needed and persists a user bound to it.
func (s *Service) RegisterAccount(ctx context.Context, company *domain.Company, user *domain.User, password string) error {
if company != nil {
if company.ID == uuid.Nil {
company.ID = uuid.Must(uuid.NewV7())
company.IsVerified = false
if err := s.repo.CreateCompany(ctx, company); err != nil {
return err
}
} else {
if _, err := s.repo.GetCompany(ctx, company.ID); err != nil {
return err
}
}
user.CompanyID = company.ID
}
return s.CreateUser(ctx, user, password)
}
// Authenticate validates credentials and emits a signed JWT.
func (s *Service) Authenticate(ctx context.Context, email, password string) (string, time.Time, error) {
user, err := s.repo.GetUserByEmail(ctx, email)
if err != nil {
return "", time.Time{}, err
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", time.Time{}, errors.New("invalid credentials")
}
expiresAt := time.Now().Add(s.tokenTTL)
claims := jwt.MapClaims{
"sub": user.ID.String(),
"role": user.Role,
"company_id": user.CompanyID.String(),
"exp": expiresAt.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(s.jwtSecret)
if err != nil {
return "", time.Time{}, err
}
return signed, expiresAt, nil
}
// VerifyCompany marks a company as verified.
func (s *Service) VerifyCompany(ctx context.Context, id uuid.UUID) (*domain.Company, error) {
company, err := s.repo.GetCompany(ctx, id)
if err != nil {
return nil, err
}
company.IsVerified = true
if err := s.repo.UpdateCompany(ctx, company); err != nil {
return nil, err
}
return company, nil
}