// saveinmed/backend-old/internal/server/server_test.go
// 2026-01-16 10:51:52 -03:00
//
// 181 lines
// 4.7 KiB
// Go

package server
import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
"github.com/saveinmed/backend-go/internal/config"
)
// TestServerHealthCheck tests the /health endpoint without a database
func TestServerHealthCheck(t *testing.T) {
// Create a simple handler to test health endpoint
mux := http.NewServeMux()
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
req := httptest.NewRequest(http.MethodGet, "/health", nil)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, rec.Code)
}
if rec.Body.String() != "ok" {
t.Errorf("expected body 'ok', got '%s'", rec.Body.String())
}
}
// TestServerRootEndpoint tests the root / endpoint
func TestServerRootEndpoint(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := `{"message":"💊 SaveInMed API is running!","docs":"/docs/index.html","health":"/health","version":"1.0.0"}`
_, _ = w.Write([]byte(response))
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, rec.Code)
}
body := rec.Body.String()
if !strings.Contains(body, "SaveInMed API is running") {
t.Errorf("expected body to contain 'SaveInMed API is running', got '%s'", body)
}
}
// loadDotEnv loads KEY=VALUE pairs from ../../.env (relative to this package
// directory) into the environment for the duration of the calling test.
//
// Blank lines, lines starting with '#', and lines without '=' are ignored. An
// optional leading "export " and one pair of matching surrounding quotes ('"'
// or '\'') around the value are stripped. Variables already present in the
// environment are left untouched, so the real environment overrides the file.
// A missing or unreadable file is not an error. Values are applied with
// t.Setenv, which restores the previous state when the test finishes, so
// nothing leaks into other tests in the package.
func loadDotEnv(t *testing.T) {
	t.Helper()

	content, err := os.ReadFile("../../.env")
	if err != nil {
		return // no .env file: rely on the real environment
	}
	for _, raw := range strings.Split(string(content), "\n") {
		line := strings.TrimSpace(raw) // also drops a trailing '\r' from CRLF files
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimPrefix(line, "export ")

		key, value, found := strings.Cut(line, "=")
		key = strings.TrimSpace(key)
		if !found || key == "" {
			continue
		}
		value = strings.TrimSpace(value)
		if n := len(value); n >= 2 && (value[0] == '"' || value[0] == '\'') && value[n-1] == value[0] {
			value = value[1 : n-1]
		}

		if _, alreadySet := os.LookupEnv(key); alreadySet {
			continue
		}
		t.Setenv(key, value)
	}
}

// TestServerCreationWithDatabase builds a server from config.Load() against a
// real database and checks that New succeeds and returns a non-nil server.
// Skip this test by setting the SKIP_DB_TEST environment variable.
func TestServerCreationWithDatabase(t *testing.T) {
	if os.Getenv("SKIP_DB_TEST") != "" {
		t.Skip("Skipping database tests")
	}
	loadDotEnv(t)

	cfg := config.Load()
	srv, err := New(cfg)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	if srv == nil {
		t.Fatal("Server should not be nil")
	}
	// Release the server's database handle so it does not outlive the test.
	if srv.db != nil {
		srv.db.Close()
	}
}

// TestDatabaseConnectionAndPing connects directly to cfg.DatabaseURL through
// the "pgx" database/sql driver, pings it, and runs a trivial SELECT 1.
// Skip this test by setting the SKIP_DB_TEST environment variable.
func TestDatabaseConnectionAndPing(t *testing.T) {
	if os.Getenv("SKIP_DB_TEST") != "" {
		t.Skip("Skipping database tests")
	}
	loadDotEnv(t)

	cfg := config.Load()
	if cfg.DatabaseURL == "" {
		t.Fatal("database URL is empty; configure it in the environment or ../../.env")
	}

	// One deadline bounds connecting, pinging and querying together.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// sqlx.ConnectContext already verifies the connection with a ping; the
	// explicit PingContext below keeps that check visible in the test.
	db, err := sqlx.ConnectContext(ctx, "pgx", cfg.DatabaseURL)
	if err != nil {
		t.Fatalf("Failed to connect to database: %v", err)
	}
	defer db.Close()

	if err := db.PingContext(ctx); err != nil {
		t.Fatalf("Failed to ping database: %v", err)
	}

	var result int
	if err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result); err != nil {
		t.Fatalf("Failed to execute query: %v", err)
	}
	if result != 1 {
		t.Errorf("Expected 1, got %d", result)
	}
}
// TestChainMiddlewareOrder verifies that chain(handler, mw1, mw2) runs the
// middlewares in the order they are listed, mw1 first and then mw2, before
// reaching the final handler.
func TestChainMiddlewareOrder(t *testing.T) {
	calls := make([]string, 0, 3)

	// tagged builds a middleware that records its label, then delegates.
	tagged := func(label string) func(http.Handler) http.Handler {
		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				calls = append(calls, label)
				next.ServeHTTP(w, r)
			})
		}
	}
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, "handler")
		w.WriteHeader(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	chain(final, tagged("mw1"), tagged("mw2")).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}

	want := []string{"mw1", "mw2", "handler"}
	mismatch := len(calls) != len(want)
	for i := 0; !mismatch && i < len(want); i++ {
		mismatch = calls[i] != want[i]
	}
	if mismatch {
		t.Fatalf("expected order %v, got %v", want, calls)
	}
}