package postgres

import (
	"context"
	"embed"
	"fmt"
	"sort"
	"time"
)

// migrationFiles holds the SQL migration scripts embedded at build time.
//
//go:embed migrations/*.sql
var migrationFiles embed.FS

// migration is a single embedded SQL script together with the version number
// parsed from its file name.
type migration struct {
	version int    // numeric prefix of the file name; used for ordering and as the primary key
	name    string // file name inside the embedded migrations directory
	sql     string // full contents of the file, executed as one statement batch
}

// ApplyMigrations ensures the database schema is up to date.
//
// It loads the embedded migrations, creates the schema_migrations bookkeeping
// table if it does not exist, reads the set of versions already recorded
// there, and then applies every migration whose version is not recorded, in
// ascending version order. Each migration runs in its own transaction, so a
// failure stops the run and leaves earlier migrations committed.
//
// NOTE(review): nothing here serialises concurrent callers; if several
// processes can start at once, confirm that is handled elsewhere.
func (r *Repository) ApplyMigrations(ctx context.Context) error {
	migrations, err := loadMigrations()
	if err != nil {
		return err
	}

	if err := r.ensureMigrationsTable(ctx); err != nil {
		return err
	}

	applied, err := r.fetchAppliedVersions(ctx)
	if err != nil {
		return err
	}

	for _, m := range migrations {
		if applied[m.version] {
			continue
		}
		if err := r.applyMigration(ctx, m); err != nil {
			return err
		}
	}
	return nil
}

// loadMigrations reads every regular file in the embedded migrations
// directory and returns the migrations sorted by ascending version.
//
// It returns an error if the directory or a file cannot be read, if a file
// name does not start with a version number, or if two files share the same
// version. Duplicate versions are rejected up front because the version is the
// primary key of schema_migrations: without this check the second file would
// either be skipped silently (version already recorded) or fail with a
// primary-key violation after the first file had already been committed.
func loadMigrations() ([]migration, error) {
	entries, err := migrationFiles.ReadDir("migrations")
	if err != nil {
		return nil, fmt.Errorf("read migrations dir: %w", err)
	}

	migrations := make([]migration, 0, len(entries))
	seen := make(map[int]string, len(entries)) // version -> file name that claimed it
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		name := entry.Name()
		version, err := parseMigrationVersion(name)
		if err != nil {
			return nil, err
		}
		if prev, ok := seen[version]; ok {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s", version, prev, name)
		}
		seen[version] = name

		// embed.FS paths always use forward slashes, so plain concatenation
		// is correct here on every platform.
		content, err := migrationFiles.ReadFile("migrations/" + name)
		if err != nil {
			return nil, fmt.Errorf("read migration %s: %w", name, err)
		}

		migrations = append(migrations, migration{
			version: version,
			name:    name,
			sql:     string(content),
		})
	}

	// Versions are unique at this point, so the ordering is fully determined.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].version < migrations[j].version
	})
	return migrations, nil
}

// parseMigrationVersion extracts the leading integer from a migration file
// name. The integer must be followed by either '_' or '-', for example
// "001_init.sql" or "002-add-users.sql". Any other shape yields an error.
func parseMigrationVersion(name string) (int, error) {
	var version int
	if _, err := fmt.Sscanf(name, "%d_", &version); err == nil {
		return version, nil
	}
	if _, err := fmt.Sscanf(name, "%d-", &version); err == nil {
		return version, nil
	}
	return 0, fmt.Errorf("invalid migration filename: %s", name)
}

// ensureMigrationsTable creates the schema_migrations bookkeeping table if it
// does not already exist.
func (r *Repository) ensureMigrationsTable(ctx context.Context) error {
	const schema = `
CREATE TABLE IF NOT EXISTS schema_migrations (
	version INT PRIMARY KEY,
	name TEXT NOT NULL,
	applied_at TIMESTAMPTZ NOT NULL
);`
	if _, err := r.db.ExecContext(ctx, schema); err != nil {
		return fmt.Errorf("create migrations table: %w", err)
	}
	return nil
}

// fetchAppliedVersions returns the set of versions currently recorded in
// schema_migrations, keyed by version number.
func (r *Repository) fetchAppliedVersions(ctx context.Context) (map[int]bool, error) {
	rows, err := r.db.QueryxContext(ctx, `SELECT version FROM schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("fetch applied migrations: %w", err)
	}
	defer rows.Close()

	applied := make(map[int]bool)
	for rows.Next() {
		var version int
		if err := rows.Scan(&version); err != nil {
			return nil, fmt.Errorf("scan applied migration: %w", err)
		}
		applied[version] = true
	}
	// rows.Next returning false can mean an iteration error rather than the
	// end of the result set, so check explicitly.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read applied migrations: %w", err)
	}
	return applied, nil
}

// applyMigration executes one migration script and records it in
// schema_migrations within a single transaction, so the script and its
// bookkeeping row are committed or rolled back together.
func (r *Repository) applyMigration(ctx context.Context, m migration) error {
	tx, err := r.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin migration %d: %w", m.version, err)
	}
	// Roll back on every early return (and on panic). Once Commit has been
	// called the transaction is finished and this call has nothing to undo;
	// its error is intentionally discarded.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.sql); err != nil {
		return fmt.Errorf("apply migration %d: %w", m.version, err)
	}

	if _, err := tx.ExecContext(ctx,
		`INSERT INTO schema_migrations (version, name, applied_at) VALUES ($1, $2, $3)`,
		m.version, m.name, time.Now().UTC(),
	); err != nil {
		return fmt.Errorf("record migration %d: %w", m.version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit migration %d: %w", m.version, err)
	}
	return nil
}