package db import ( "context" "database/sql" "embed" "fmt" "os" "path/filepath" "sort" "time" _ "modernc.org/sqlite" "github.com/benya/temporserv/internal/config" ) //go:embed migrations/*.sql var migrationFiles embed.FS func Open(ctx context.Context, cfg config.Config) (*sql.DB, error) { if err := os.MkdirAll(filepath.Dir(cfg.DatabasePath), 0o755); err != nil { return nil, fmt.Errorf("create database directory: %w", err) } db, err := sql.Open("sqlite", cfg.DatabasePath) if err != nil { return nil, fmt.Errorf("open database: %w", err) } db.SetMaxOpenConns(1) db.SetConnMaxLifetime(30 * time.Minute) if err := db.PingContext(ctx); err != nil { return nil, fmt.Errorf("ping database: %w", err) } if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON;"); err != nil { return nil, fmt.Errorf("enable foreign keys: %w", err) } return db, nil } func Migrate(ctx context.Context, database *sql.DB) error { if _, err := database.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY, applied_at TEXT NOT NULL)`); err != nil { return fmt.Errorf("create schema_migrations: %w", err) } entries, err := migrationFiles.ReadDir("migrations") if err != nil { return fmt.Errorf("read migrations: %w", err) } sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) for _, entry := range entries { if entry.IsDir() { continue } var alreadyApplied int if err := database.QueryRowContext(ctx, `SELECT COUNT(*) FROM schema_migrations WHERE name = ?`, entry.Name()).Scan(&alreadyApplied); err != nil { return fmt.Errorf("check migration %s: %w", entry.Name(), err) } if alreadyApplied > 0 { continue } sqlBytes, err := migrationFiles.ReadFile("migrations/" + entry.Name()) if err != nil { return fmt.Errorf("read migration %s: %w", entry.Name(), err) } tx, err := database.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", entry.Name(), err) } if _, err := tx.ExecContext(ctx, string(sqlBytes)); err != nil { _ = tx.Rollback() return fmt.Errorf("apply migration %s: %w", entry.Name(), err) } if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations (name, applied_at) VALUES (?, ?)`, entry.Name(), time.Now().UTC().Format(time.RFC3339)); err != nil { _ = tx.Rollback() return fmt.Errorf("record migration %s: %w", entry.Name(), err) } if err := tx.Commit(); err != nil { return fmt.Errorf("commit migration %s: %w", entry.Name(), err) } } return nil }