341 lines
8.0 KiB
Go
341 lines
8.0 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/md5"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)
|
|
|
|
// ErrInvalidCredentials is returned by Login when the username is unknown or
// the password does not match the stored hash.
var ErrInvalidCredentials = errors.New("invalid credentials")

// ErrUnauthorized is returned when a session token or Subsonic credential
// does not authenticate a user.
var ErrUnauthorized = errors.New("unauthorized")
|
|
|
|
// User is the account information exposed to callers. Password material
// (bcrypt hash, encrypted Subsonic secret) is not part of this type.
type User struct {
	ID       string `json:"id"`
	Username string `json:"username"`
	IsAdmin  bool   `json:"isAdmin"`
	// CreatedAt is the zero time if the stored value could not be parsed.
	CreatedAt time.Time `json:"createdAt"`
	// LastLoginAt is the zero time when no login has been recorded (NULL)
	// or the stored value could not be parsed.
	LastLoginAt time.Time `json:"lastLoginAt"`
}
|
|
|
|
type Session struct {
|
|
Token string `json:"token"`
|
|
User User `json:"user"`
|
|
}
|
|
|
|
// Service implements password login, bearer-session lookup and Subsonic
// authentication on top of a SQL database. Construct it with NewService.
type Service struct {
	db *sql.DB
	// encryptionKey is a passphrase; deriveKey hashes it with SHA-256 into
	// the AES-256-GCM key protecting stored Subsonic secrets.
	encryptionKey string
}
|
|
|
|
func NewService(db *sql.DB, encryptionKey string) *Service {
|
|
return &Service{db: db, encryptionKey: encryptionKey}
|
|
}
|
|
|
|
func (s *Service) Login(ctx context.Context, username, password string) (Session, error) {
|
|
if err := s.cleanupExpiredSessions(ctx); err != nil {
|
|
return Session{}, fmt.Errorf("cleanup expired sessions: %w", err)
|
|
}
|
|
|
|
user, passwordHash, _, err := s.findUserByUsername(ctx, username)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Session{}, ErrInvalidCredentials
|
|
}
|
|
return Session{}, fmt.Errorf("find user: %w", err)
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
|
return Session{}, ErrInvalidCredentials
|
|
}
|
|
|
|
if err := s.storeSubsonicSecret(ctx, user.ID, password); err != nil {
|
|
return Session{}, fmt.Errorf("store subsonic secret: %w", err)
|
|
}
|
|
|
|
token, err := newToken()
|
|
if err != nil {
|
|
return Session{}, fmt.Errorf("generate token: %w", err)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
expiresAt := now.Add(30 * 24 * time.Hour)
|
|
|
|
if _, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO sessions (token, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)`,
|
|
token,
|
|
user.ID,
|
|
now.Format(time.RFC3339),
|
|
expiresAt.Format(time.RFC3339),
|
|
); err != nil {
|
|
return Session{}, fmt.Errorf("insert session: %w", err)
|
|
}
|
|
|
|
if _, err := s.db.ExecContext(
|
|
ctx,
|
|
`UPDATE users SET last_login_at = ? WHERE id = ?`,
|
|
now.Format(time.RFC3339),
|
|
user.ID,
|
|
); err == nil {
|
|
user.LastLoginAt = now
|
|
}
|
|
|
|
return Session{
|
|
Token: token,
|
|
User: user,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) CurrentUser(ctx context.Context, authorizationHeader string) (User, error) {
|
|
token := strings.TrimSpace(strings.TrimPrefix(authorizationHeader, "Bearer "))
|
|
return s.CurrentUserByToken(ctx, token)
|
|
}
|
|
|
|
func (s *Service) CurrentUserByToken(ctx context.Context, token string) (User, error) {
|
|
if token == "" {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
|
|
_ = s.cleanupExpiredSessions(ctx)
|
|
|
|
user, err := s.findUserByToken(ctx, token)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
return User{}, fmt.Errorf("find user by token: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Service) Logout(ctx context.Context, token string) error {
|
|
if strings.TrimSpace(token) == "" {
|
|
return nil
|
|
}
|
|
if _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, strings.TrimSpace(token)); err != nil {
|
|
return fmt.Errorf("delete session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) CurrentUserBySubsonicAuth(ctx context.Context, username, password, token, salt string) (User, error) {
|
|
if username == "" {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
|
|
user, passwordHash, encryptedSecret, err := s.findUserByUsername(ctx, username)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
return User{}, fmt.Errorf("find user by username: %w", err)
|
|
}
|
|
|
|
if token != "" || salt != "" {
|
|
if token == "" || salt == "" || encryptedSecret == "" {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
secret, err := decryptSecret(encryptedSecret, s.encryptionKey)
|
|
if err != nil {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
if hexMD5(secret+salt) != strings.ToLower(token) {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
if strings.HasPrefix(password, "enc:") {
|
|
decoded, err := hex.DecodeString(strings.TrimPrefix(password, "enc:"))
|
|
if err != nil {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
password = string(decoded)
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
|
|
return User{}, ErrUnauthorized
|
|
}
|
|
|
|
if err := s.storeSubsonicSecret(ctx, user.ID, password); err != nil {
|
|
return User{}, fmt.Errorf("store subsonic secret: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Service) findUserByUsername(ctx context.Context, username string) (User, string, string, error) {
|
|
var user User
|
|
var passwordHash string
|
|
var subsonicSecret sql.NullString
|
|
var createdAt string
|
|
var lastLoginAt sql.NullString
|
|
var isAdmin int
|
|
|
|
err := s.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, username, password_hash, COALESCE(subsonic_auth_secret, ''), is_admin, created_at, last_login_at
|
|
FROM users
|
|
WHERE username = ?`,
|
|
username,
|
|
).Scan(
|
|
&user.ID,
|
|
&user.Username,
|
|
&passwordHash,
|
|
&subsonicSecret,
|
|
&isAdmin,
|
|
&createdAt,
|
|
&lastLoginAt,
|
|
)
|
|
if err != nil {
|
|
return User{}, "", "", err
|
|
}
|
|
|
|
user.IsAdmin = isAdmin == 1
|
|
user.CreatedAt = parseTime(createdAt)
|
|
if lastLoginAt.Valid {
|
|
user.LastLoginAt = parseTime(lastLoginAt.String)
|
|
}
|
|
|
|
return user, passwordHash, subsonicSecret.String, nil
|
|
}
|
|
|
|
func (s *Service) findUserByToken(ctx context.Context, token string) (User, error) {
|
|
var user User
|
|
var createdAt string
|
|
var lastLoginAt sql.NullString
|
|
var isAdmin int
|
|
|
|
err := s.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT u.id, u.username, u.is_admin, u.created_at, u.last_login_at
|
|
FROM users u
|
|
JOIN sessions s ON s.user_id = u.id
|
|
WHERE s.token = ? AND s.expires_at > ?`,
|
|
token,
|
|
time.Now().UTC().Format(time.RFC3339),
|
|
).Scan(
|
|
&user.ID,
|
|
&user.Username,
|
|
&isAdmin,
|
|
&createdAt,
|
|
&lastLoginAt,
|
|
)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
|
|
user.IsAdmin = isAdmin == 1
|
|
user.CreatedAt = parseTime(createdAt)
|
|
if lastLoginAt.Valid {
|
|
user.LastLoginAt = parseTime(lastLoginAt.String)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// parseTime converts a timestamp read from the database into a time.Time,
// returning the zero time when raw matches none of the accepted layouts.
//
// Accepted layouts, tried in order:
//   - RFC 3339, the format this package writes. time.Parse also accepts
//     fractional seconds here even though the layout does not spell them.
//   - "2006-01-02 15:04:05", interpreted as UTC. This is the form produced by
//     SQLite's CURRENT_TIMESTAMP / datetime(). It covers rows populated by a
//     column default or external tooling, which were previously reported
//     silently as the zero time.
func parseTime(raw string) time.Time {
	layouts := [...]string{
		time.RFC3339,
		"2006-01-02 15:04:05",
	}
	for _, layout := range layouts {
		if parsed, err := time.Parse(layout, raw); err == nil {
			return parsed
		}
	}
	return time.Time{}
}
|
|
|
|
// newToken returns an opaque session token made of 32 bytes from crypto/rand,
// encoded as 64 lowercase hex characters.
func newToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
func (s *Service) storeSubsonicSecret(ctx context.Context, userID, password string) error {
|
|
if userID == "" || password == "" {
|
|
return nil
|
|
}
|
|
encrypted, err := encryptSecret(password, s.encryptionKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.ExecContext(ctx, `UPDATE users SET subsonic_auth_secret = ? WHERE id = ?`, encrypted, userID)
|
|
return err
|
|
}
|
|
|
|
func (s *Service) cleanupExpiredSessions(ctx context.Context) error {
|
|
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at <= ?`, time.Now().UTC().Format(time.RFC3339))
|
|
return err
|
|
}
|
|
|
|
func EncryptSubsonicSecret(value, key string) (string, error) {
|
|
return encryptSecret(value, key)
|
|
}
|
|
|
|
func encryptSecret(value, key string) (string, error) {
|
|
block, err := aes.NewCipher(deriveKey(key))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return "", err
|
|
}
|
|
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
|
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
|
}
|
|
|
|
func decryptSecret(value, key string) (string, error) {
|
|
block, err := aes.NewCipher(deriveKey(key))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
raw, err := base64.StdEncoding.DecodeString(value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(raw) < gcm.NonceSize() {
|
|
return "", ErrUnauthorized
|
|
}
|
|
nonce := raw[:gcm.NonceSize()]
|
|
ciphertext := raw[gcm.NonceSize():]
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
// deriveKey maps a passphrase to a 32-byte AES-256 key by hashing it with
// SHA-256.
//
// NOTE(review): there is no salt or stretching, so key should be a
// high-entropy secret rather than a human-chosen password.
func deriveKey(key string) []byte {
	digest := sha256.Sum256([]byte(key))
	return digest[:]
}
|
|
|
|
// hexMD5 returns the MD5 digest of value as 32 lowercase hex characters. This
// is the form compared against the client's token in Subsonic token
// authentication.
func hexMD5(value string) string {
	digest := md5.Sum([]byte(value))
	return hex.EncodeToString(digest[:])
}
|