// saveinmed/backend/internal/repository/postgres/migrations.go
// 2025-12-20 10:32:54 -03:00
//
// 152 lines
// 3.4 KiB
// Go

package postgres
import (
"context"
"embed"
"fmt"
"sort"
"time"
)
// migrationFiles holds every .sql file from the migrations directory,
// embedded into the binary at build time. loadMigrations reads the scripts
// from here rather than from the filesystem.
//
//go:embed migrations/*.sql
var migrationFiles embed.FS
// migration is a single embedded SQL migration script.
type migration struct {
	version int // numeric filename prefix; ordering key and schema_migrations primary key

	name string // file name inside the migrations directory, e.g. "0001_init.sql"
	sql  string // full script contents, executed in one ExecContext call
}
// ApplyMigrations brings the database schema up to date by running every
// embedded migration whose version is not yet recorded in schema_migrations.
//
// It loads and sorts the embedded scripts, makes sure the bookkeeping table
// exists, reads the already-applied versions, and then applies the remaining
// migrations in ascending version order. The first error stops the run and
// is returned. Each migration commits in its own transaction, so migrations
// applied before the failure stay applied.
func (r *Repository) ApplyMigrations(ctx context.Context) error {
	all, loadErr := loadMigrations()
	if loadErr != nil {
		return loadErr
	}

	if tableErr := r.ensureMigrationsTable(ctx); tableErr != nil {
		return tableErr
	}

	done, fetchErr := r.fetchAppliedVersions(ctx)
	if fetchErr != nil {
		return fetchErr
	}

	// Select the migrations that still need to run; `all` is already sorted,
	// so `pending` keeps ascending version order.
	pending := make([]migration, 0, len(all))
	for _, mig := range all {
		if !done[mig.version] {
			pending = append(pending, mig)
		}
	}

	for _, mig := range pending {
		if applyErr := r.applyMigration(ctx, mig); applyErr != nil {
			return applyErr
		}
	}
	return nil
}
// loadMigrations reads every embedded file under migrations/, parses the
// numeric version prefix from each filename, and returns the migrations
// sorted by ascending version (ties broken by name for determinism).
//
// It returns an error when the directory cannot be read, when a filename has
// no parseable version prefix, when a file cannot be read, or when two files
// share the same version (e.g. "0003_a.sql" and "0003-b.sql"). Duplicates are
// rejected up front because the version is the schema_migrations primary key.
// Without this check both scripts would be treated as pending, and the second
// would execute and then fail when inserting its already-recorded version.
func loadMigrations() ([]migration, error) {
	entries, err := migrationFiles.ReadDir("migrations")
	if err != nil {
		return nil, fmt.Errorf("read migrations dir: %w", err)
	}

	migrations := make([]migration, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		version, err := parseMigrationVersion(name)
		if err != nil {
			return nil, err
		}
		// embed.FS paths always use forward slashes, so plain concatenation
		// is correct here regardless of the host OS.
		content, err := migrationFiles.ReadFile("migrations/" + name)
		if err != nil {
			return nil, fmt.Errorf("read migration %s: %w", name, err)
		}
		migrations = append(migrations, migration{
			version: version,
			name:    name,
			sql:     string(content),
		})
	}

	sort.Slice(migrations, func(i, j int) bool {
		if migrations[i].version != migrations[j].version {
			return migrations[i].version < migrations[j].version
		}
		// sort.Slice is not stable; tie-break on name so ordering (and the
		// duplicate error below) is deterministic.
		return migrations[i].name < migrations[j].name
	})

	// After sorting, equal versions are adjacent.
	for i := 1; i < len(migrations); i++ {
		prev, cur := migrations[i-1], migrations[i]
		if cur.version == prev.version {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s", cur.version, prev.name, cur.name)
		}
	}
	return migrations, nil
}
// parseMigrationVersion extracts the leading integer version from a migration
// filename such as "0001_init.sql" or "0002-add-users.sql". The number must be
// immediately followed by an underscore or a hyphen. Any other name yields an
// "invalid migration filename" error.
//
// NOTE(review): fmt's %d verb also accepts a leading sign, so a name like
// "-1_x.sql" parses as version -1. Confirm that is acceptable.
func parseMigrationVersion(name string) (int, error) {
	var version int
	// Try each accepted separator in turn; the first layout that matches wins.
	for _, layout := range [...]string{"%d_", "%d-"} {
		if _, scanErr := fmt.Sscanf(name, layout, &version); scanErr == nil {
			return version, nil
		}
	}
	return 0, fmt.Errorf("invalid migration filename: %s", name)
}
// ensureMigrationsTable creates the schema_migrations bookkeeping table when it
// is missing. The DDL uses CREATE TABLE IF NOT EXISTS, so repeated calls are
// harmless. Each row records a migration's version (primary key), its file
// name, and when it was applied.
func (r *Repository) ensureMigrationsTable(ctx context.Context) error {
	const ddl = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INT PRIMARY KEY,
name TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL
);`
	_, execErr := r.db.ExecContext(ctx, ddl)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("create migrations table: %w", execErr)
}
// fetchAppliedVersions returns the set of migration versions already recorded
// in schema_migrations. The set is a map keyed by version whose values are
// always true. Query, scan, and row-iteration failures are wrapped with context.
func (r *Repository) fetchAppliedVersions(ctx context.Context) (map[int]bool, error) {
	rows, queryErr := r.db.QueryxContext(ctx, `SELECT version FROM schema_migrations`)
	if queryErr != nil {
		return nil, fmt.Errorf("fetch applied migrations: %w", queryErr)
	}
	defer rows.Close()

	seen := map[int]bool{}
	for rows.Next() {
		var v int
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, fmt.Errorf("scan applied migration: %w", scanErr)
		}
		seen[v] = true
	}

	// rows.Next returning false can mean either exhaustion or an error.
	iterErr := rows.Err()
	if iterErr != nil {
		return nil, fmt.Errorf("read applied migrations: %w", iterErr)
	}
	return seen, nil
}
// applyMigration runs one migration script and records its version in
// schema_migrations within a single transaction. The schema change and its
// bookkeeping row are therefore committed together or not at all.
//
// Errors are wrapped with the migration version and the step that failed:
// begin, apply, record, or commit.
//
// NOTE(review): because every script runs inside a transaction, statements
// PostgreSQL refuses to execute in a transaction block (e.g. CREATE INDEX
// CONCURRENTLY) will fail here. Confirm no migration relies on them.
func (r *Repository) applyMigration(ctx context.Context, m migration) error {
	tx, err := r.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin migration %d: %w", m.version, err)
	}
	// Roll back on every exit that does not reach a successful Commit,
	// including panics. Assuming r.db is an sqlx/database/sql handle, Rollback
	// after a successful Commit is a no-op returning sql.ErrTxDone. Its error
	// is deliberately discarded: on failure paths the Exec/Commit error below
	// is the one worth reporting.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.sql); err != nil {
		return fmt.Errorf("apply migration %d: %w", m.version, err)
	}
	if _, err := tx.ExecContext(ctx,
		`INSERT INTO schema_migrations (version, name, applied_at) VALUES ($1, $2, $3)`,
		m.version, m.name, time.Now().UTC(),
	); err != nil {
		return fmt.Errorf("record migration %d: %w", m.version, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit migration %d: %w", m.version, err)
	}
	return nil
}