99 lines
2.5 KiB
Go
99 lines
2.5 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
"github.com/benya/temporserv/internal/config"
|
|
)
|
|
|
|
// migrationFiles is the embedded file system holding the files matched by
// migrations/*.sql. Migrate lists and reads its "migrations" directory.
//
//go:embed migrations/*.sql
|
|
var migrationFiles embed.FS
|
|
|
|
func Open(ctx context.Context, cfg config.Config) (*sql.DB, error) {
|
|
if err := os.MkdirAll(filepath.Dir(cfg.DatabasePath), 0o755); err != nil {
|
|
return nil, fmt.Errorf("create database directory: %w", err)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", cfg.DatabasePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open database: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(1)
|
|
db.SetConnMaxLifetime(30 * time.Minute)
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return nil, fmt.Errorf("ping database: %w", err)
|
|
}
|
|
|
|
if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON;"); err != nil {
|
|
return nil, fmt.Errorf("enable foreign keys: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func Migrate(ctx context.Context, database *sql.DB) error {
|
|
if _, err := database.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT NOT NULL)`); err != nil {
|
|
return fmt.Errorf("create schema_migrations: %w", err)
|
|
}
|
|
|
|
entries, err := migrationFiles.ReadDir("migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("read migrations: %w", err)
|
|
}
|
|
|
|
sort.Slice(entries, func(i, j int) bool {
|
|
return entries[i].Name() < entries[j].Name()
|
|
})
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
var alreadyApplied int
|
|
if err := database.QueryRowContext(ctx, `SELECT COUNT(*) FROM schema_migrations WHERE name = ?`, entry.Name()).Scan(&alreadyApplied); err != nil {
|
|
return fmt.Errorf("check migration %s: %w", entry.Name(), err)
|
|
}
|
|
if alreadyApplied > 0 {
|
|
continue
|
|
}
|
|
|
|
sqlBytes, err := migrationFiles.ReadFile("migrations/" + entry.Name())
|
|
if err != nil {
|
|
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
|
|
}
|
|
|
|
tx, err := database.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin migration %s: %w", entry.Name(), err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, string(sqlBytes)); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("apply migration %s: %w", entry.Name(), err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?)`, entry.Name(), time.Now().UTC().Format(time.RFC3339)); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("record migration %s: %w", entry.Name(), err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit migration %s: %w", entry.Name(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|