491 lines
18 KiB
Python
491 lines
18 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
import json
|
|
import secrets
|
|
from uuid import uuid4
|
|
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi import Request
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.auth import repository as auth_repository
|
|
from app.auth.token_store import (
|
|
RefreshSession,
|
|
get_refresh_token_user_id,
|
|
list_refresh_sessions_for_user,
|
|
revoke_all_refresh_sessions_for_user,
|
|
revoke_refresh_token_jti,
|
|
store_refresh_token_jti,
|
|
)
|
|
from app.auth.schemas import (
|
|
EmailStatusResponse,
|
|
LoginRequest,
|
|
RefreshTokenRequest,
|
|
RegisterRequest,
|
|
RequestPasswordResetRequest,
|
|
ResendVerificationRequest,
|
|
ResetPasswordRequest,
|
|
TokenResponse,
|
|
VerifyEmailRequest,
|
|
)
|
|
from app.config.settings import settings
|
|
from app.database.session import get_db
|
|
from app.email.service import EmailDeliveryError, EmailService, get_email_service
|
|
from app.users.models import User
|
|
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
|
|
from app.utils.totp import build_otpauth_uri, generate_totp_secret, verify_totp_code
|
|
from app.utils.security import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
generate_random_token,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
|
|
# Bearer-token dependency consumed by get_current_user via Depends(oauth2_scheme).
# tokenUrl is the login route under settings.api_v1_prefix.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
|
|
|
|
|
|
def _refresh_ttl_seconds() -> int:
    """Return the refresh-session lifetime in seconds.

    Converts ``settings.refresh_token_expire_days`` to seconds; used as the
    TTL whenever a refresh-token jti is stored.
    """
    days = settings.refresh_token_expire_days
    return days * 24 * 60 * 60
|
|
|
|
|
|
async def register_user(
    db: AsyncSession,
    payload: RegisterRequest,
    email_service: EmailService,
) -> None:
    """Create a new account and send it an email-verification token.

    Any previous verification tokens for the new user are cleared before a
    fresh one (valid for ``settings.email_verification_token_expire_hours``)
    is stored. The session is committed only after the email was sent.

    Raises:
        HTTPException 409: the email or the username is already in use.
        HTTPException 503: the verification email could not be delivered; the
            session is rolled back first.
    """
    # NOTE(review): if create_user commits internally, the rollback below will
    # not undo the user row on email failure — confirm against the repository.
    if await get_user_by_email(db, payload.email):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
    if await get_user_by_username(db, payload.username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")

    new_user = await create_user(
        db,
        email=payload.email,
        name=payload.name,
        username=payload.username,
        password_hash=hash_password(payload.password),
    )

    token = generate_random_token()
    token_expires_at = datetime.now(timezone.utc) + timedelta(
        hours=settings.email_verification_token_expire_hours
    )
    await auth_repository.delete_email_verification_tokens_for_user(db, new_user.id)
    await auth_repository.create_email_verification_token(db, new_user.id, token, token_expires_at)

    try:
        await email_service.send_verification_email(payload.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to send verification email",
        ) from exc
    else:
        await db.commit()
|
|
|
|
|
|
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
    """Mark the owner of an email-verification token as verified.

    A naive ``expires_at`` from the store is interpreted as UTC. On success
    all of the user's verification tokens are deleted and the session is
    committed.

    Raises:
        HTTPException 400: the token is unknown or has expired.
        HTTPException 404: the user the token belongs to no longer exists.
    """
    token_row = await auth_repository.get_email_verification_token(db, payload.token)
    if not token_row:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")

    deadline = token_row.expires_at
    if not deadline.tzinfo:
        deadline = deadline.replace(tzinfo=timezone.utc)
    if deadline < datetime.now(timezone.utc):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")

    owner = await get_user_by_id(db, token_row.user_id)
    if not owner:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    owner.email_verified = True
    # Tokens are single-use: drop every outstanding one for this user.
    await auth_repository.delete_email_verification_tokens_for_user(db, owner.id)
    await db.commit()
|
|
|
|
|
|
async def resend_verification_email(
    db: AsyncSession,
    payload: ResendVerificationRequest,
    email_service: EmailService,
) -> None:
    """Issue and email a fresh verification token for an unverified account.

    Unknown or already-verified addresses return silently without touching
    the database. Otherwise prior tokens are replaced by a new one and the
    session is committed once the email has been sent.

    Raises:
        HTTPException 503: the email could not be delivered; the session is
            rolled back first.
    """
    account = await get_user_by_email(db, payload.email)
    needs_verification = bool(account) and not account.email_verified
    if not needs_verification:
        return

    token = generate_random_token()
    token_expires_at = datetime.now(timezone.utc) + timedelta(
        hours=settings.email_verification_token_expire_hours
    )
    await auth_repository.delete_email_verification_tokens_for_user(db, account.id)
    await auth_repository.create_email_verification_token(db, account.id, token, token_expires_at)

    try:
        await email_service.send_verification_email(account.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to send verification email",
        ) from exc
    else:
        await db.commit()
|
|
|
|
|
|
async def _verify_login_second_factor(db: AsyncSession, user: User, payload: LoginRequest) -> None:
    """Enforce the second factor for a 2FA-enabled *user* during login.

    A supplied recovery code takes precedence and is consumed when valid;
    otherwise a TOTP code is required and checked against ``user.twofa_secret``.

    Raises:
        HTTPException 401: invalid recovery code, missing TOTP code, or
            invalid TOTP code (or no stored secret).
    """
    if payload.recovery_code:
        accepted = await _consume_twofa_recovery_code(db, user=user, recovery_code=payload.recovery_code)
        if not accepted:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid recovery code")
        return

    if not payload.otp_code:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2FA code required")
    secret = user.twofa_secret
    if not secret or not verify_totp_code(secret, payload.otp_code):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code")


async def login_user(
    db: AsyncSession,
    payload: LoginRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Authenticate by email and password (plus 2FA when enabled) and open a session.

    On success a new refresh-token jti is stored with the refresh TTL and the
    caller's IP / user agent, and an access/refresh token pair is returned.

    Raises:
        HTTPException 401: unknown email or wrong password ("Invalid
            credentials"), or a failed second factor.
        HTTPException 403: the email address has not been verified.
    """
    user = await get_user_by_email(db, payload.email)
    # Short-circuit keeps verify_password from running when the email is unknown.
    credentials_ok = bool(user) and verify_password(payload.password, user.password_hash)
    if not credentials_ok:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")

    if not user.email_verified:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")

    if user.twofa_enabled:
        await _verify_login_second_factor(db, user, payload)

    session_jti = str(uuid4())
    subject = str(user.id)
    refresh_token = create_refresh_token(subject, jti=session_jti)
    await store_refresh_token_jti(
        user_id=user.id,
        jti=session_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    return TokenResponse(access_token=create_access_token(subject), refresh_token=refresh_token)
|
|
|
|
|
|
async def refresh_tokens(
    db: AsyncSession,
    payload: RefreshTokenRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Rotate a refresh token into a new access/refresh pair.

    The presented token must decode, be of type ``refresh``, carry a numeric
    ``sub`` and a ``jti``, and that jti must still be stored for the same user,
    who must still exist. The old jti is revoked and a new one is stored with
    the caller's IP / user agent.

    Raises:
        HTTPException 401: any of the checks above fails ("Invalid refresh token").
    """
    # NOTE(review): the lookup and the revoke are separate store calls, so two
    # concurrent refreshes with the same token may both succeed — confirm
    # whether token_store offers an atomic consume.
    invalid = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid refresh token",
    )
    try:
        claims = decode_token(payload.refresh_token)
    except ValueError as exc:
        raise invalid from exc

    subject = claims.get("sub")
    old_jti = claims.get("jti")
    if claims.get("type") != "refresh" or not subject or not str(subject).isdigit() or not old_jti:
        raise invalid

    owner_id = int(subject)
    # A missing (None) or foreign owner both fail this comparison.
    if await get_refresh_token_user_id(jti=old_jti) != owner_id:
        raise invalid
    if not await get_user_by_id(db, owner_id):
        raise invalid

    await revoke_refresh_token_jti(jti=old_jti)
    rotated_jti = str(uuid4())
    await store_refresh_token_jti(
        user_id=owner_id,
        jti=rotated_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    return TokenResponse(
        access_token=create_access_token(str(subject)),
        refresh_token=create_refresh_token(str(subject), jti=rotated_jti),
    )
|
|
|
|
|
|
async def list_user_sessions(user_id: int) -> list[RefreshSession]:
    """Return the refresh-token sessions the token store lists for *user_id*."""
    sessions = await list_refresh_sessions_for_user(user_id=user_id)
    return sessions
|
|
|
|
|
|
async def revoke_user_session(*, user_id: int, jti: str) -> None:
    """Revoke the refresh session *jti* if it is stored for *user_id*.

    Unknown jtis and jtis owned by another user are ignored silently, so a
    caller cannot revoke (or probe) someone else's session.
    """
    owner = await get_refresh_token_user_id(jti=jti)
    if owner is not None and owner == user_id:
        await revoke_refresh_token_jti(jti=jti)
|
|
|
|
|
|
async def revoke_all_user_sessions(db: AsyncSession, *, user_id: int) -> None:
    """Sign *user_id* out everywhere.

    Revokes all stored refresh sessions, then stamps
    ``access_revoked_before`` with the current UTC time (committed) so access
    tokens issued up to now are rejected by the current-user dependencies.
    If the user row no longer exists only the refresh sessions are revoked.
    """
    await revoke_all_refresh_sessions_for_user(user_id=user_id)
    account = await get_user_by_id(db, user_id)
    if not account:
        return
    account.access_revoked_before = datetime.now(timezone.utc)
    await db.commit()
|
|
|
|
|
|
def _token_issued_at(payload: dict) -> datetime | None:
|
|
raw_iat = payload.get("iat")
|
|
if isinstance(raw_iat, datetime):
|
|
return raw_iat if raw_iat.tzinfo else raw_iat.replace(tzinfo=timezone.utc)
|
|
if isinstance(raw_iat, (int, float)):
|
|
try:
|
|
return datetime.fromtimestamp(float(raw_iat), tz=timezone.utc)
|
|
except Exception:
|
|
return None
|
|
if isinstance(raw_iat, str):
|
|
try:
|
|
parsed = datetime.fromisoformat(raw_iat.replace("Z", "+00:00"))
|
|
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
|
|
def get_request_metadata(request: Request) -> tuple[str | None, str | None]:
    """Return ``(client_ip, user_agent)`` for *request*.

    Either element is ``None`` when the request has no client address or no
    ``user-agent`` header.
    """
    client = request.client
    ip_address = client.host if client else None
    return ip_address, request.headers.get("user-agent")
|
|
|
|
|
|
def get_access_session_info(token: str) -> tuple[str, datetime] | None:
    """Return ``(jti, issued_at)`` for a decodable access token, else ``None``.

    ``None`` is returned when decoding raises ``ValueError``, the token type
    is not ``access``, or the ``jti`` claim is not a non-empty string.
    ``issued_at`` falls back to the current UTC time when the token has no
    usable ``iat`` claim.
    """
    try:
        claims = decode_token(token)
    except ValueError:
        return None

    token_id = claims.get("jti")
    if claims.get("type") != "access" or not isinstance(token_id, str) or not token_id:
        return None
    return token_id, _token_issued_at(claims) or datetime.now(timezone.utc)
|
|
|
|
|
|
async def setup_twofa(db: AsyncSession, user: User) -> tuple[str, str]:
    """Provision (or reuse) a TOTP secret for *user* and return it with its otpauth URI.

    A secret left over from an unfinished setup is reused, so repeated calls
    before ``enable_twofa`` produce the same secret. The secret is persisted
    (commit + refresh) but 2FA is not switched on here.

    Returns:
        ``(secret, otpauth_url)`` with the URI labelled by the user's email and
        ``settings.app_name`` as issuer.

    Raises:
        HTTPException 400: 2FA is already enabled with a stored secret.
    """
    if user.twofa_enabled and user.twofa_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled")

    totp_secret = user.twofa_secret or generate_totp_secret()
    user.twofa_secret = totp_secret
    await db.commit()
    await db.refresh(user)

    uri = build_otpauth_uri(secret=totp_secret, account_name=user.email, issuer=settings.app_name)
    return totp_secret, uri
|
|
|
|
|
|
async def enable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn on 2FA for *user* once they prove possession of the TOTP secret.

    Raises:
        HTTPException 400: no secret was provisioned (``setup_twofa`` not run)
            or *code* does not verify against it.
    """
    secret = user.twofa_secret
    if not secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not initialized")
    if not verify_totp_code(secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")

    user.twofa_enabled = True
    await db.commit()
    await db.refresh(user)
|
|
|
|
|
|
async def disable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn off 2FA for *user* and wipe the secret and recovery-code hashes.

    A valid TOTP *code* is only demanded when 2FA is actually active (enabled
    with a stored secret); otherwise leftover state is cleared unconditionally.

    Raises:
        HTTPException 400: 2FA is active and *code* does not verify.
    """
    is_active = user.twofa_enabled and user.twofa_secret
    if is_active and not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")

    user.twofa_enabled = False
    user.twofa_secret = None
    user.twofa_recovery_codes_hashes = None
    await db.commit()
    await db.refresh(user)
|
|
|
|
|
|
def _normalize_recovery_code(code: str) -> str:
|
|
return "".join(ch for ch in code.upper() if ch.isalnum())
|
|
|
|
|
|
def _generate_recovery_codes(count: int = 8) -> list[str]:
|
|
alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
|
codes: list[str] = []
|
|
for _ in range(max(1, count)):
|
|
raw = "".join(secrets.choice(alphabet) for _ in range(10))
|
|
codes.append(f"{raw[:5]}-{raw[5:]}")
|
|
return codes
|
|
|
|
|
|
def _load_twofa_recovery_hashes(user: User) -> list[str]:
|
|
raw = user.twofa_recovery_codes_hashes
|
|
if not raw:
|
|
return []
|
|
try:
|
|
parsed = json.loads(raw)
|
|
except Exception:
|
|
return []
|
|
if not isinstance(parsed, list):
|
|
return []
|
|
hashes: list[str] = []
|
|
for item in parsed:
|
|
if isinstance(item, str) and item:
|
|
hashes.append(item)
|
|
return hashes
|
|
|
|
|
|
async def regenerate_twofa_recovery_codes(db: AsyncSession, user: User, *, code: str) -> list[str]:
    """Replace *user*'s recovery codes with a fresh set and return them in plain text.

    Only hashes of the normalized codes are persisted (as a JSON array); the
    plain codes are returned once for display. Requires a valid TOTP *code*.

    Raises:
        HTTPException 400: 2FA is not active, or *code* does not verify.
    """
    if not (user.twofa_enabled and user.twofa_secret):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled")
    if not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")

    plain_codes = _generate_recovery_codes()
    hashed = [hash_password(_normalize_recovery_code(plain)) for plain in plain_codes]
    user.twofa_recovery_codes_hashes = json.dumps(hashed, ensure_ascii=True)
    await db.commit()
    await db.refresh(user)
    return plain_codes
|
|
|
|
|
|
def get_twofa_recovery_codes_remaining(user: User) -> int:
    """Return how many unused recovery-code hashes are stored for *user*."""
    remaining = _load_twofa_recovery_hashes(user)
    return len(remaining)
|
|
|
|
|
|
async def _consume_twofa_recovery_code(db: AsyncSession, *, user: User, recovery_code: str) -> bool:
    """Check *recovery_code* against *user*'s stored hashes and burn it on a match.

    Returns:
        ``True`` after removing the first matching hash and committing (the
        column becomes ``None`` once no codes remain); ``False`` without
        touching the database when the code is blank or matches nothing.
    """
    # NOTE(review): read-modify-write on the user row; two concurrent logins
    # with the same code might both succeed — confirm locking/isolation.
    candidate = _normalize_recovery_code(recovery_code)
    if not candidate:
        return False

    stored = _load_twofa_recovery_hashes(user)
    match = next((i for i, code_hash in enumerate(stored) if verify_password(candidate, code_hash)), None)
    if match is None:
        return False

    remaining = stored[:match] + stored[match + 1:]
    user.twofa_recovery_codes_hashes = json.dumps(remaining, ensure_ascii=True) if remaining else None
    await db.commit()
    return True
|
|
|
|
|
|
async def request_password_reset(
    db: AsyncSession,
    payload: RequestPasswordResetRequest,
    email_service: EmailService,
) -> None:
    """Email a password-reset token to the account registered under ``payload.email``.

    Unknown addresses return silently (presumably so the endpoint does not
    confirm which emails exist). Previous reset tokens are replaced by one
    valid for ``settings.password_reset_token_expire_hours``; the session is
    committed after the email has been sent.

    Raises:
        HTTPException 503: the email could not be delivered; the session is
            rolled back first.
    """
    account = await get_user_by_email(db, payload.email)
    if not account:
        return

    token = generate_random_token()
    token_expires_at = datetime.now(timezone.utc) + timedelta(
        hours=settings.password_reset_token_expire_hours
    )
    await auth_repository.delete_password_reset_tokens_for_user(db, account.id)
    await auth_repository.create_password_reset_token(db, account.id, token, token_expires_at)

    try:
        await email_service.send_password_reset_email(account.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to send reset email",
        ) from exc
    else:
        await db.commit()
|
|
|
|
|
|
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
    """Set a new password from a password-reset token and sign the user out everywhere.

    A naive ``expires_at`` from the store is interpreted as UTC. On success:
      * all refresh sessions of the user are revoked;
      * the password hash is replaced;
      * ``access_revoked_before`` is set to now, so access tokens issued
        before the reset are rejected by the current-user dependencies;
      * every reset token of the user is deleted and the session committed.

    Raises:
        HTTPException 400: the token is unknown or has expired.
        HTTPException 404: the user the token belongs to no longer exists.
    """
    record = await auth_repository.get_password_reset_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")

    now = datetime.now(timezone.utc)
    expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
    if expires_at < now:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")

    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    # A reset usually means the old password is lost or compromised. Whoever
    # holds existing sessions must not keep access, so revoke them the same
    # way revoke_all_user_sessions does. Refresh sessions are revoked before
    # the DB commit. If that call fails, the reset token is still unused and
    # the request can be retried.
    await revoke_all_refresh_sessions_for_user(user_id=user.id)
    user.password_hash = hash_password(payload.new_password)
    # NOTE(review): if access-token iat has whole-second resolution, a login in
    # the same second as this reset is also rejected (same as revoke-all).
    user.access_revoked_before = datetime.now(timezone.utc)
    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    await db.commit()
|
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
    """FastAPI dependency resolving the bearer access token to its ``User``.

    Responds 401 ("Could not validate credentials", ``WWW-Authenticate:
    Bearer``) when the token fails to decode, is not of type ``access``, has
    a missing/non-numeric ``sub``, names a user that does not exist, or was
    issued at or before ``user.access_revoked_before`` (naive values treated
    as UTC). Tokens without a usable ``iat`` skip the revocation check.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = decode_token(token)
    except ValueError as exc:
        raise unauthorized from exc

    subject = claims.get("sub")
    if claims.get("type") != "access" or not subject or not str(subject).isdigit():
        raise unauthorized

    user = await get_user_by_id(db, int(subject))
    if not user:
        raise unauthorized

    issued_at = _token_issued_at(claims)
    cutoff = user.access_revoked_before
    if cutoff is None or issued_at is None:
        return user
    if not cutoff.tzinfo:
        cutoff = cutoff.replace(tzinfo=timezone.utc)
    if issued_at <= cutoff:
        raise unauthorized
    return user
|
|
|
|
|
|
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
    """Resolve an access token presented on a WebSocket connection to its ``User``.

    Same checks as ``get_current_user`` but each failure raises a 401 with a
    distinct detail: "Invalid token" (decode error), "Invalid token type",
    "Invalid token payload" (bad ``sub``), "User not found", or
    "Token was revoked" (issued at or before ``access_revoked_before``).
    Tokens without a usable ``iat`` skip the revocation check.
    """

    def _reject(detail: str) -> HTTPException:
        return HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)

    try:
        claims = decode_token(token)
    except ValueError as exc:
        raise _reject("Invalid token") from exc

    if claims.get("type") != "access":
        raise _reject("Invalid token type")

    subject = claims.get("sub")
    if not subject or not str(subject).isdigit():
        raise _reject("Invalid token payload")

    user = await get_user_by_id(db, int(subject))
    if not user:
        raise _reject("User not found")

    issued_at = _token_issued_at(claims)
    cutoff = user.access_revoked_before
    if cutoff is not None and issued_at is not None:
        if not cutoff.tzinfo:
            cutoff = cutoff.replace(tzinfo=timezone.utc)
        if issued_at <= cutoff:
            raise _reject("Token was revoked")
    return user
|
|
|
|
|
|
def get_email_sender() -> EmailService:
    """Dependency hook returning the ``EmailService`` from ``get_email_service``."""
    sender = get_email_service()
    return sender
|
|
|
|
|
|
async def get_email_status(db: AsyncSession, email: str) -> EmailStatusResponse:
    """Report whether *email* is registered and, if so, its verification and 2FA flags.

    NOTE(review): this discloses account existence and 2FA status for any
    address — confirm the exposing route is intended to be public and is
    rate-limited.
    """
    user = await get_user_by_email(db, email)
    if user:
        return EmailStatusResponse(
            email=email,
            registered=True,
            email_verified=user.email_verified,
            twofa_enabled=user.twofa_enabled,
        )
    return EmailStatusResponse(email=email, registered=False)
|