152 lines
3.4 KiB
Go
152 lines
3.4 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"sort"
|
|
"time"
|
|
)
|
|
|
|
// migrationFiles holds the SQL migration files compiled into the binary.
// Only files matching migrations/*.sql are embedded.
//
//go:embed migrations/*.sql
var migrationFiles embed.FS
|
|
|
|
// migration is a single embedded SQL migration.
type migration struct {
	// version is the numeric prefix parsed from the file name; migrations are
	// sorted by it before being applied.
	version int
	// name is the file name inside the embedded migrations directory, and sql
	// is that file's full contents.
	name, sql string
}
|
|
|
|
// ApplyMigrations ensures the database schema is up to date.
|
|
func (r *Repository) ApplyMigrations(ctx context.Context) error {
|
|
migrations, err := loadMigrations()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := r.ensureMigrationsTable(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
applied, err := r.fetchAppliedVersions(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, m := range migrations {
|
|
if applied[m.version] {
|
|
continue
|
|
}
|
|
|
|
if err := r.applyMigration(ctx, m); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func loadMigrations() ([]migration, error) {
|
|
entries, err := migrationFiles.ReadDir("migrations")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read migrations dir: %w", err)
|
|
}
|
|
|
|
migrations := make([]migration, 0, len(entries))
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
name := entry.Name()
|
|
version, err := parseMigrationVersion(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
content, err := migrationFiles.ReadFile("migrations/" + name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read migration %s: %w", name, err)
|
|
}
|
|
|
|
migrations = append(migrations, migration{
|
|
version: version,
|
|
name: name,
|
|
sql: string(content),
|
|
})
|
|
}
|
|
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].version < migrations[j].version
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// parseMigrationVersion extracts the numeric version prefix from a migration
// file name such as "0001_create_users.sql" or "0002-add-index.sql".
//
// The name must begin with one or more ASCII digits immediately followed by
// '_' or '-'. Signs, leading whitespace, a missing separator, and values that
// overflow int are rejected with an "invalid migration filename" error.
func parseMigrationVersion(name string) (int, error) {
	end := 0
	for end < len(name) && name[end] >= '0' && name[end] <= '9' {
		end++
	}
	if end == 0 || end == len(name) || (name[end] != '_' && name[end] != '-') {
		return 0, fmt.Errorf("invalid migration filename: %s", name)
	}

	// The prefix contains digits only, so this can fail only on int overflow.
	// %d forces base 10; %v (as used by fmt.Sscan) would read a zero-padded
	// prefix such as "0010" as octal.
	var version int
	if _, err := fmt.Sscanf(name[:end], "%d", &version); err != nil {
		return 0, fmt.Errorf("invalid migration filename: %s", name)
	}
	return version, nil
}
|
|
|
|
func (r *Repository) ensureMigrationsTable(ctx context.Context) error {
|
|
const schema = `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version INT PRIMARY KEY,
|
|
name TEXT NOT NULL,
|
|
applied_at TIMESTAMPTZ NOT NULL
|
|
);`
|
|
if _, err := r.db.ExecContext(ctx, schema); err != nil {
|
|
return fmt.Errorf("create migrations table: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *Repository) fetchAppliedVersions(ctx context.Context) (map[int]bool, error) {
|
|
rows, err := r.db.QueryxContext(ctx, `SELECT version FROM schema_migrations`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetch applied migrations: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
applied := make(map[int]bool)
|
|
for rows.Next() {
|
|
var version int
|
|
if err := rows.Scan(&version); err != nil {
|
|
return nil, fmt.Errorf("scan applied migration: %w", err)
|
|
}
|
|
applied[version] = true
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("read applied migrations: %w", err)
|
|
}
|
|
return applied, nil
|
|
}
|
|
|
|
func (r *Repository) applyMigration(ctx context.Context, m migration) error {
|
|
tx, err := r.db.BeginTxx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin migration %d: %w", m.version, err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, m.sql); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("apply migration %d: %w", m.version, err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations (version, name, applied_at) VALUES ($1, $2, $3)`, m.version, m.name, time.Now().UTC()); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("record migration %d: %w", m.version, err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit migration %d: %w", m.version, err)
|
|
}
|
|
|
|
return nil
|
|
}
|