saveinmed/backend/internal/http/handler/dto_test.go
2026-01-02 11:01:56 -03:00

236 lines
6.9 KiB
Go

package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofrs/uuid/v5"
"github.com/golang-jwt/jwt/v5"
"github.com/saveinmed/backend-go/internal/http/middleware"
)
// TestWriteJSONAndError exercises both response helpers. writeJSON must send
// the requested status, a Content-Type of exactly "application/json", and the
// JSON-encoded value. writeError must send the requested status with the error
// text stored under the "error" key of a JSON object.
func TestWriteJSONAndError(t *testing.T) {
	// Success path: writeJSON.
	okRec := httptest.NewRecorder()
	writeJSON(okRec, http.StatusCreated, map[string]string{"ok": "true"})

	if code := okRec.Code; code != http.StatusCreated {
		t.Fatalf("expected status 201, got %d", code)
	}
	contentType := okRec.Header().Get("Content-Type")
	if contentType != "application/json" {
		t.Fatalf("expected Content-Type application/json, got %q", contentType)
	}
	var okBody map[string]string
	if err := json.Unmarshal(okRec.Body.Bytes(), &okBody); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	if okBody["ok"] != "true" {
		t.Fatalf("expected payload ok=true, got %v", okBody)
	}

	// Failure path: writeError, recorded on a fresh recorder.
	errRec := httptest.NewRecorder()
	writeError(errRec, http.StatusBadRequest, errors.New("boom"))

	if code := errRec.Code; code != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", code)
	}
	var errBody map[string]string
	if err := json.Unmarshal(errRec.Body.Bytes(), &errBody); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if msg := errBody["error"]; msg != "boom" {
		t.Fatalf("expected error boom, got %v", errBody)
	}
}
// TestDecodeJSONUnknownField ensures decodeJSON returns an error when the
// request body contains a field the destination struct does not declare.
func TestDecodeJSONUnknownField(t *testing.T) {
	body := bytes.NewBufferString(`{"unknown":1}`)
	req := httptest.NewRequest(http.MethodPost, "/", body)

	var dst struct {
		Known string `json:"known"`
	}
	if err := decodeJSON(context.Background(), req, &dst); err == nil {
		t.Fatal("expected error for unknown field")
	}
}
// TestDecodeJSONCanceledContext ensures decodeJSON reports an error matching
// context.Canceled (via errors.Is, so wrapping is allowed) when the supplied
// context is already canceled, even though the body itself is valid.
func TestDecodeJSONCanceledContext(t *testing.T) {
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel() // cancel before decoding so the context is already done

	body := bytes.NewBufferString(`{"known":"ok"}`)
	req := httptest.NewRequest(http.MethodPost, "/", body)

	var dst struct {
		Known string `json:"known"`
	}
	if err := decodeJSON(canceledCtx, req, &dst); !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context canceled, got %v", err)
	}
}
// TestParseUUIDFromPath checks that parseUUIDFromPath pulls the UUID segment
// out of a /api/v1/companies/{id}/rating path. It also checks that the function
// returns an error when the path has no UUID segment.
func TestParseUUIDFromPath(t *testing.T) {
	want := uuid.Must(uuid.NewV7())
	path := "/api/v1/companies/" + want.String() + "/rating"

	got, err := parseUUIDFromPath(path)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case got != want:
		t.Fatalf("expected %s, got %s", want, got)
	}

	_, err = parseUUIDFromPath("/api/v1/companies/rating")
	if err == nil {
		t.Fatal("expected error for missing UUID")
	}
}
// TestSplitPath checks that splitPath returns the four non-empty segments of
// "/api/v1/companies/123", with "api" first and "123" last.
func TestSplitPath(t *testing.T) {
	segments := splitPath("/api/v1/companies/123")

	if n := len(segments); n != 4 {
		t.Fatalf("expected 4 parts, got %d", n)
	}
	first, last := segments[0], segments[3]
	if first != "api" || last != "123" {
		t.Fatalf("unexpected parts: %v", segments)
	}
}
// TestIsValidStatus checks one status the test expects to be accepted
// ("pending") and one it expects to be rejected ("cancelled").
func TestIsValidStatus(t *testing.T) {
	checks := []struct {
		status  string
		valid   bool
		failMsg string
	}{
		{"pending", true, "expected pending to be valid"},
		{"cancelled", false, "expected cancelled to be invalid"},
	}
	for _, c := range checks {
		if isValidStatus(c.status) != c.valid {
			t.Fatal(c.failMsg)
		}
	}
}
// TestParsePagination checks two cases:
//   - parsePagination honors valid page/page_size query parameters.
//   - Out-of-range values (page=-1, page_size=0) yield page 1 and size 20.
func TestParsePagination(t *testing.T) {
	cases := []struct {
		target             string
		wantPage, wantSize int
		failFmt            string
	}{
		{"/?page=2&page_size=30", 2, 30, "expected page 2 size 30, got %d %d"},
		{"/?page=-1&page_size=0", 1, 20, "expected defaults, got %d %d"},
	}
	for _, tc := range cases {
		req := httptest.NewRequest(http.MethodGet, tc.target, nil)
		page, size := parsePagination(req)
		if page != tc.wantPage || size != tc.wantSize {
			t.Fatalf(tc.failFmt, page, size)
		}
	}
}
// TestGetRequesterFromHeaders covers header-based requester resolution:
//   - With no headers, the test expects an "Admin" requester with no company.
//   - X-User-Role and X-Company-ID are honored when present.
//   - A malformed X-Company-ID produces an error.
func TestGetRequesterFromHeaders(t *testing.T) {
	newReq := func(headers map[string]string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		for name, value := range headers {
			r.Header.Set(name, value)
		}
		return r
	}

	// No identifying headers at all.
	fallback, err := getRequester(newReq(nil))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if fallback.Role != "Admin" || fallback.CompanyID != nil {
		t.Fatalf("unexpected requester: %+v", fallback)
	}

	// Explicit role and company.
	companyID := uuid.Must(uuid.NewV7())
	seller, err := getRequester(newReq(map[string]string{
		"X-User-Role":  "Seller",
		"X-Company-ID": companyID.String(),
	}))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if seller.Role != "Seller" || seller.CompanyID == nil || *seller.CompanyID != companyID {
		t.Fatalf("unexpected requester: %+v", seller)
	}

	// Company header that is not a UUID.
	if _, err := getRequester(newReq(map[string]string{"X-Company-ID": "invalid"})); err == nil {
		t.Fatal("expected error for invalid company id header")
	}
}
// TestGetRequesterFromClaims checks that getRequester, running behind
// middleware.RequireAuth, resolves the role and company ID from a valid HS256
// JWT.
//
// The requester is captured inside the wrapped handler and asserted after
// ServeHTTP returns. The test therefore also records whether the handler ran.
// Otherwise a middleware rejection would surface only as a confusing
// "unexpected requester" failure on a zero value.
func TestGetRequesterFromClaims(t *testing.T) {
	secret := []byte("secret")
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":        userID.String(),
		"role":       "Owner",
		"company_id": companyID.String(),
		"exp":        time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString(secret)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}

	var (
		got    requester
		called bool
	)
	handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		var reqErr error
		got, reqErr = getRequester(r)
		if reqErr != nil {
			t.Fatalf("unexpected error: %v", reqErr)
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if !called {
		t.Fatalf("handler not invoked: middleware responded %d: %s", rec.Code, rec.Body.String())
	}
	if got.Role != "Owner" || got.CompanyID == nil || *got.CompanyID != companyID {
		t.Fatalf("unexpected requester: %+v", got)
	}
}
// TestParseBearerToken checks that parseBearerToken:
//   - rejects a missing Authorization header, a non-Bearer scheme, and an
//     empty Bearer token;
//   - returns the raw token for a well-formed "Bearer <token>" header.
func TestParseBearerToken(t *testing.T) {
	// requestWith builds a request carrying the given Authorization value.
	// An empty value means the header is not set at all.
	requestWith := func(authorization string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		if authorization != "" {
			r.Header.Set("Authorization", authorization)
		}
		return r
	}

	rejected := []struct {
		authorization string
		failMsg       string
	}{
		{"", "expected error for missing header"},
		{"Token abc", "expected error for invalid header"},
		{"Bearer ", "expected error for empty token"},
	}
	for _, tc := range rejected {
		if _, err := parseBearerToken(requestWith(tc.authorization)); err == nil {
			t.Fatal(tc.failMsg)
		}
	}

	token, err := parseBearerToken(requestWith("Bearer token123"))
	if err != nil || token != "token123" {
		t.Fatalf("unexpected token result: %v %v", token, err)
	}
}
// TestGetUserFromContext checks that Handler.getUserFromContext, running
// behind middleware.RequireAuth with a valid HS256 JWT, yields the user ID,
// role and company ID encoded in the token's claims.
//
// All assertions on the user live inside the wrapped handler. The test
// therefore records whether that handler ran and fails if it did not. Without
// this, a token rejected by the middleware would skip every assertion and the
// test would silently pass.
func TestGetUserFromContext(t *testing.T) {
	secret := []byte("secret")
	userID := uuid.Must(uuid.NewV7())
	companyID := uuid.Must(uuid.NewV7())
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":        userID.String(),
		"role":       "Admin",
		"company_id": companyID.String(),
		"exp":        time.Now().Add(time.Hour).Unix(),
	})
	tokenStr, err := token.SignedString(secret)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}

	called := false
	handler := middleware.RequireAuth(secret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		h := &Handler{}
		user, userErr := h.getUserFromContext(r.Context())
		if userErr != nil {
			t.Fatalf("unexpected error: %v", userErr)
		}
		if user.ID != userID || user.Role != "Admin" || user.CompanyID != companyID {
			t.Fatalf("unexpected user: %+v", user)
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if !called {
		t.Fatalf("handler not invoked: middleware responded %d: %s", rec.Code, rec.Body.String())
	}
}