Files

99 lines
2.5 KiB
Go

package db
import (
"context"
"database/sql"
"embed"
"fmt"
"os"
"path/filepath"
"sort"
"time"
_ "modernc.org/sqlite"
"github.com/benya/temporserv/internal/config"
)
// migrationFiles holds the SQL migration scripts compiled into the binary.
// Only files matching migrations/*.sql are embedded; Migrate reads them from
// the "migrations" directory of this filesystem.
//go:embed migrations/*.sql
var migrationFiles embed.FS
// Open creates the parent directory of cfg.DatabasePath (mode 0o755), opens
// the database with the "sqlite" driver, verifies it is reachable, and enables
// foreign-key enforcement.
//
// The pool is capped at one open connection, and that connection is recycled
// after 30 minutes. PRAGMA foreign_keys is a per-connection setting, so it is
// requested through the DSN. That way every connection the pool opens gets it,
// including replacements after recycling, and not only the first one.
//
// On success the caller owns the returned *sql.DB and must Close it. On any
// error after the handle is created, the handle is closed before returning.
func Open(ctx context.Context, cfg config.Config) (*sql.DB, error) {
	if err := os.MkdirAll(filepath.Dir(cfg.DatabasePath), 0o755); err != nil {
		return nil, fmt.Errorf("create database directory: %w", err)
	}

	// modernc.org/sqlite executes each _pragma query parameter on every new
	// connection. Join it with "&" when the path already carries a query
	// string, otherwise start one with "?".
	sep := "?"
	for i := 0; i < len(cfg.DatabasePath); i++ {
		if cfg.DatabasePath[i] == '?' {
			sep = "&"
			break
		}
	}
	dsn := cfg.DatabasePath + sep + "_pragma=foreign_keys(1)"

	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}
	db.SetMaxOpenConns(1)
	db.SetConnMaxLifetime(30 * time.Minute)

	if err := db.PingContext(ctx); err != nil {
		_ = db.Close() // the ping error is the one worth reporting
		return nil, fmt.Errorf("ping database: %w", err)
	}
	// Also set explicitly on the connection just pinged; this is harmless when
	// the DSN parameter has already applied it.
	if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON;"); err != nil {
		_ = db.Close() // the pragma error is the one worth reporting
		return nil, fmt.Errorf("enable foreign keys: %w", err)
	}
	return db, nil
}
// Migrate brings the schema up to date by running every embedded migration
// whose file name is not yet recorded in the schema_migrations table. It
// creates that table first if it does not exist.
//
// Migrations run in ascending file-name order; directory entries are skipped.
// Each script executes in its own transaction together with the INSERT that
// records it (name plus a UTC RFC 3339 timestamp), so a migration is committed
// only together with its ledger row. The first failure stops the run and is
// returned wrapped with the file name. Migrations committed before it stay
// applied.
func Migrate(ctx context.Context, database *sql.DB) error {
	const createLedger = `CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT NOT NULL)`
	if _, err := database.ExecContext(ctx, createLedger); err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	files, err := migrationFiles.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("read migrations: %w", err)
	}
	sort.Slice(files, func(a, b int) bool { return files[a].Name() < files[b].Name() })

	// runOne applies a single migration unless the ledger already lists it.
	runOne := func(name string) error {
		var seen int
		row := database.QueryRowContext(ctx, `SELECT COUNT(*) FROM schema_migrations WHERE name = ?`, name)
		if err := row.Scan(&seen); err != nil {
			return fmt.Errorf("check migration %s: %w", name, err)
		}
		if seen != 0 {
			return nil
		}

		script, err := migrationFiles.ReadFile("migrations/" + name)
		if err != nil {
			return fmt.Errorf("read migration %s: %w", name, err)
		}

		tx, err := database.BeginTx(ctx, nil)
		if err != nil {
			return fmt.Errorf("begin migration %s: %w", name, err)
		}
		// Undo the script on any early return. After a successful Commit,
		// Rollback only reports sql.ErrTxDone, so its result is discarded.
		defer func() { _ = tx.Rollback() }()

		if _, err := tx.ExecContext(ctx, string(script)); err != nil {
			return fmt.Errorf("apply migration %s: %w", name, err)
		}
		appliedAt := time.Now().UTC().Format(time.RFC3339)
		if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?)`, name, appliedAt); err != nil {
			return fmt.Errorf("record migration %s: %w", name, err)
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("commit migration %s: %w", name, err)
		}
		return nil
	}

	for _, f := range files {
		if f.IsDir() {
			continue
		}
		if err := runOne(f.Name()); err != nil {
			return err
		}
	}
	return nil
}