128 lines
3.6 KiB
Go
128 lines
3.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"time"
|
|
|
|
"photum-backend/internal/config"
|
|
"photum-backend/internal/db/generated"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
type Service struct {
|
|
queries *generated.Queries
|
|
cfg *config.Config
|
|
}
|
|
|
|
func NewService(queries *generated.Queries, cfg *config.Config) *Service {
|
|
return &Service{queries: queries, cfg: cfg}
|
|
}
|
|
|
|
func (s *Service) Register(ctx context.Context, email, password string) (*generated.Usuario, error) {
|
|
hash, err := HashPassword(password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err := s.queries.CreateUsuario(ctx, generated.CreateUsuarioParams{
|
|
Email: email,
|
|
SenhaHash: hash,
|
|
Role: "profissional",
|
|
})
|
|
return &user, err
|
|
}
|
|
|
|
func (s *Service) Login(ctx context.Context, email, password, userAgent, ip string) (string, string, time.Time, *generated.Usuario, error) {
|
|
user, err := s.queries.GetUsuarioByEmail(ctx, email)
|
|
if err != nil {
|
|
return "", "", time.Time{}, nil, errors.New("invalid credentials")
|
|
}
|
|
|
|
if !CheckPasswordHash(password, user.SenhaHash) {
|
|
return "", "", time.Time{}, nil, errors.New("invalid credentials")
|
|
}
|
|
|
|
// Convert pgtype.UUID to uuid.UUID
|
|
userUUID := uuid.UUID(user.ID.Bytes)
|
|
|
|
accessToken, accessExp, err := GenerateAccessToken(userUUID, user.Role, s.cfg.JwtAccessSecret, s.cfg.JwtAccessTTLMinutes)
|
|
if err != nil {
|
|
return "", "", time.Time{}, nil, err
|
|
}
|
|
|
|
refreshToken, _, err := s.createRefreshToken(ctx, user.ID, userAgent, ip)
|
|
if err != nil {
|
|
return "", "", time.Time{}, nil, err
|
|
}
|
|
|
|
// Return access token, refresh token (raw), access expiration, user
|
|
return accessToken, refreshToken, accessExp, &user, nil
|
|
}
|
|
|
|
func (s *Service) Refresh(ctx context.Context, refreshTokenRaw string) (string, time.Time, error) {
|
|
// Hash the raw token to find it in DB
|
|
hash := sha256.Sum256([]byte(refreshTokenRaw))
|
|
hashString := hex.EncodeToString(hash[:])
|
|
|
|
storedToken, err := s.queries.GetRefreshToken(ctx, hashString)
|
|
if err != nil {
|
|
return "", time.Time{}, errors.New("invalid refresh token")
|
|
}
|
|
|
|
if storedToken.Revogado {
|
|
return "", time.Time{}, errors.New("token revoked")
|
|
}
|
|
|
|
if time.Now().After(storedToken.ExpiraEm.Time) {
|
|
return "", time.Time{}, errors.New("token expired")
|
|
}
|
|
|
|
// Get user to check if active and get role
|
|
user, err := s.queries.GetUsuarioByID(ctx, storedToken.UsuarioID)
|
|
if err != nil {
|
|
return "", time.Time{}, errors.New("user not found")
|
|
}
|
|
|
|
// Convert pgtype.UUID to uuid.UUID
|
|
userUUID := uuid.UUID(user.ID.Bytes)
|
|
|
|
// Generate new access token
|
|
return GenerateAccessToken(userUUID, user.Role, s.cfg.JwtAccessSecret, s.cfg.JwtAccessTTLMinutes)
|
|
}
|
|
|
|
func (s *Service) Logout(ctx context.Context, refreshTokenRaw string) error {
|
|
hash := sha256.Sum256([]byte(refreshTokenRaw))
|
|
hashString := hex.EncodeToString(hash[:])
|
|
return s.queries.RevokeRefreshToken(ctx, hashString)
|
|
}
|
|
|
|
func (s *Service) createRefreshToken(ctx context.Context, userID pgtype.UUID, userAgent, ip string) (string, time.Time, error) {
|
|
// Generate random token
|
|
randomToken := uuid.New().String() // Simple UUID as refresh token
|
|
|
|
hash := sha256.Sum256([]byte(randomToken))
|
|
hashString := hex.EncodeToString(hash[:])
|
|
|
|
expiraEm := time.Now().Add(time.Duration(s.cfg.JwtRefreshTTLDays) * 24 * time.Hour)
|
|
|
|
// pgtype.Timestamptz conversion
|
|
pgExpiraEm := pgtype.Timestamptz{
|
|
Time: expiraEm,
|
|
Valid: true,
|
|
}
|
|
|
|
_, err := s.queries.CreateRefreshToken(ctx, generated.CreateRefreshTokenParams{
|
|
UsuarioID: userID,
|
|
TokenHash: hashString,
|
|
UserAgent: pgtype.Text{String: userAgent, Valid: userAgent != ""},
|
|
Ip: pgtype.Text{String: ip, Valid: ip != ""},
|
|
ExpiraEm: pgExpiraEm,
|
|
})
|
|
|
|
return randomToken, expiraEm, err
|
|
}
|