"""Authentication service layer.

Covers registration, email verification, login, refresh-token rotation,
session revocation, TOTP two-factor auth, password reset, and the FastAPI
dependencies that resolve the current user from an access token.
"""

from datetime import datetime, timedelta, timezone
from uuid import uuid4

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import repository as auth_repository
from app.auth.token_store import (
    RefreshSession,
    get_refresh_token_user_id,
    list_refresh_sessions_for_user,
    revoke_all_refresh_sessions_for_user,
    revoke_refresh_token_jti,
    store_refresh_token_jti,
)
from app.auth.schemas import (
    EmailStatusResponse,
    LoginRequest,
    RefreshTokenRequest,
    RegisterRequest,
    RequestPasswordResetRequest,
    ResendVerificationRequest,
    ResetPasswordRequest,
    TokenResponse,
    VerifyEmailRequest,
)
from app.config.settings import settings
from app.database.session import get_db
from app.email.service import EmailDeliveryError, EmailService, get_email_service
from app.users.models import User
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
from app.utils.totp import build_otpauth_uri, generate_totp_secret, verify_totp_code
from app.utils.security import (
    create_access_token,
    create_refresh_token,
    decode_token,
    generate_random_token,
    hash_password,
    verify_password,
)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")


def _refresh_ttl_seconds() -> int:
    """Return the refresh-token lifetime in seconds, from ``settings.refresh_token_expire_days``."""
    return settings.refresh_token_expire_days * 24 * 60 * 60


def _as_utc(value: datetime) -> datetime:
    """Return *value* as an aware datetime, treating a naive value as UTC."""
    return value if value.tzinfo else value.replace(tzinfo=timezone.utc)


async def _issue_verification_token(
    db: AsyncSession,
    *,
    user_id: int,
    email: str,
    email_service: EmailService,
) -> None:
    """Replace the user's email-verification tokens with a fresh one, email it, and commit.

    Raises:
        HTTPException: 503 if the email could not be delivered; the pending
            transaction is rolled back before raising.
    """
    verification_token = generate_random_token()
    expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
    # Only the newest link stays valid.
    await auth_repository.delete_email_verification_tokens_for_user(db, user_id)
    await auth_repository.create_email_verification_token(db, user_id, verification_token, expires_at)
    try:
        await email_service.send_verification_email(email, verification_token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to send verification email",
        ) from exc
    await db.commit()


async def register_user(
    db: AsyncSession,
    payload: RegisterRequest,
    email_service: EmailService,
) -> None:
    """Create an unverified account and send its email-verification link.

    Raises:
        HTTPException: 409 if the email or username is already in use;
            503 if the verification email could not be sent.
    """
    if await get_user_by_email(db, payload.email):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
    if await get_user_by_username(db, payload.username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")

    user = await create_user(
        db,
        email=payload.email,
        name=payload.name,
        username=payload.username,
        password_hash=hash_password(payload.password),
    )
    # On delivery failure the helper rolls back, which also discards the new
    # user — assuming create_user does not commit on its own (confirm).
    await _issue_verification_token(db, user_id=user.id, email=payload.email, email_service=email_service)


async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
    """Mark the token's owner as email-verified and consume the user's verification tokens.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if its user no longer exists.
    """
    record = await auth_repository.get_email_verification_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")
    if _as_utc(record.expires_at) < datetime.now(timezone.utc):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")

    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    user.email_verified = True
    await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
    await db.commit()


async def resend_verification_email(
    db: AsyncSession,
    payload: ResendVerificationRequest,
    email_service: EmailService,
) -> None:
    """Send a new verification link to an unverified account.

    Unknown or already-verified emails return silently without sending anything.

    Raises:
        HTTPException: 503 if the verification email could not be sent.
    """
    user = await get_user_by_email(db, payload.email)
    if not user or user.email_verified:
        return
    await _issue_verification_token(db, user_id=user.id, email=user.email, email_service=email_service)


async def login_user(
    db: AsyncSession,
    payload: LoginRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Authenticate by email/password (plus TOTP when 2FA is enabled) and issue a token pair.

    The refresh token's jti is recorded in the token store together with the
    client's IP address and user agent.

    Raises:
        HTTPException: 401 for bad credentials, a missing 2FA code or a wrong
            2FA code; 403 if the email is not verified.
    """
    user = await get_user_by_email(db, payload.email)
    # NOTE(review): the unknown-email path short-circuits before verify_password,
    # so response timing may reveal whether an email is registered; consider
    # verifying against a dummy hash on that path.
    if not user or not verify_password(payload.password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not user.email_verified:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
    if user.twofa_enabled:
        if not payload.otp_code:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2FA code required")
        if not user.twofa_secret or not verify_totp_code(user.twofa_secret, payload.otp_code):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code")

    refresh_jti = str(uuid4())
    refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
    await store_refresh_token_jti(
        user_id=user.id,
        jti=refresh_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    return TokenResponse(
        access_token=create_access_token(str(user.id)),
        refresh_token=refresh_token,
    )


async def refresh_tokens(
    db: AsyncSession,
    payload: RefreshTokenRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Rotate a refresh token: revoke the presented jti and issue a fresh token pair.

    Raises:
        HTTPException: 401 if the token is malformed, not a refresh token, no
            longer active in the token store, or belongs to a missing user.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid refresh token",
    )
    try:
        token_payload = decode_token(payload.refresh_token)
    except ValueError as exc:
        raise credentials_error from exc
    if token_payload.get("type") != "refresh":
        raise credentials_error

    user_id = token_payload.get("sub")
    refresh_jti = token_payload.get("jti")
    if not user_id or not str(user_id).isdigit() or not refresh_jti:
        raise credentials_error

    # The jti must still be active in the store and bound to the same user.
    active_user_id = await get_refresh_token_user_id(jti=refresh_jti)
    if active_user_id is None or active_user_id != int(user_id):
        raise credentials_error
    user = await get_user_by_id(db, int(user_id))
    if not user:
        raise credentials_error

    # NOTE(review): the lookup above and this revoke are separate calls, so two
    # concurrent requests with the same refresh token may both succeed; an
    # atomic consume in the token store would close that window.
    await revoke_refresh_token_jti(jti=refresh_jti)
    new_jti = str(uuid4())
    await store_refresh_token_jti(
        user_id=int(user_id),
        jti=new_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    return TokenResponse(
        access_token=create_access_token(str(user_id)),
        refresh_token=create_refresh_token(str(user_id), jti=new_jti),
    )


async def list_user_sessions(user_id: int) -> list[RefreshSession]:
    """Return the user's refresh sessions as reported by the token store."""
    return await list_refresh_sessions_for_user(user_id=user_id)


async def revoke_user_session(*, user_id: int, jti: str) -> None:
    """Revoke one refresh session; silently ignore jtis that are inactive or owned by another user."""
    active_user_id = await get_refresh_token_user_id(jti=jti)
    if active_user_id is None or active_user_id != user_id:
        return
    await revoke_refresh_token_jti(jti=jti)


async def revoke_all_user_sessions(db: AsyncSession, *, user_id: int) -> None:
    """Revoke every refresh session and invalidate all access tokens issued up to now."""
    await revoke_all_refresh_sessions_for_user(user_id=user_id)
    user = await get_user_by_id(db, user_id)
    if user:
        user.access_revoked_before = datetime.now(timezone.utc)
        await db.commit()


def _token_issued_at(payload: dict) -> datetime | None:
    """Parse the ``iat`` claim into an aware UTC datetime.

    Accepts a datetime, a numeric POSIX timestamp or an ISO-8601 string (a
    trailing ``Z`` is accepted). Returns None when the claim is missing,
    of another type, or cannot be parsed.
    """
    raw_iat = payload.get("iat")
    if isinstance(raw_iat, datetime):
        return _as_utc(raw_iat)
    # bool is an int subclass but is never a meaningful timestamp.
    if isinstance(raw_iat, (int, float)) and not isinstance(raw_iat, bool):
        try:
            return datetime.fromtimestamp(float(raw_iat), tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None
    if isinstance(raw_iat, str):
        try:
            return _as_utc(datetime.fromisoformat(raw_iat.replace("Z", "+00:00")))
        except ValueError:
            return None
    return None


def _is_access_token_revoked(user: User, payload: dict) -> bool:
    """Return True if the token was issued at or before ``user.access_revoked_before``.

    Fails closed: when a revocation cut-off is set but the token's ``iat`` is
    missing or unparseable, the token cannot be shown to post-date the
    revocation and is treated as revoked.
    """
    if user.access_revoked_before is None:
        return False
    issued_at = _token_issued_at(payload)
    if issued_at is None:
        return True
    # NOTE(review): if create_access_token truncates iat to whole seconds, tokens
    # issued within the same second as a revocation are also rejected — confirm
    # that is acceptable.
    return issued_at <= _as_utc(user.access_revoked_before)


def get_request_metadata(request: Request) -> tuple[str | None, str | None]:
    """Return ``(client_ip, user_agent)`` for the request; either may be None."""
    ip_address = request.client.host if request.client else None
    user_agent = request.headers.get("user-agent")
    return ip_address, user_agent


def get_access_session_info(token: str) -> tuple[str, datetime] | None:
    """Return ``(jti, issued_at)`` for a valid access token, else None.

    When ``iat`` is absent or unparseable, the current time is used as issued_at.
    """
    try:
        payload = decode_token(token)
    except ValueError:
        return None
    if payload.get("type") != "access":
        return None
    jti = payload.get("jti")
    if not isinstance(jti, str) or not jti:
        return None
    issued_at = _token_issued_at(payload) or datetime.now(timezone.utc)
    return jti, issued_at


async def setup_twofa(db: AsyncSession, user: User) -> tuple[str, str]:
    """Return ``(secret, otpauth_url)``, generating and persisting a new TOTP secret unless 2FA is already enabled."""
    if user.twofa_enabled and user.twofa_secret:
        # NOTE(review): this hands the live 2FA secret to anyone holding a valid
        # access token; consider refusing or requiring a current TOTP code here.
        secret = user.twofa_secret
    else:
        secret = generate_totp_secret()
        user.twofa_secret = secret
        await db.commit()
        await db.refresh(user)
    otpauth_url = build_otpauth_uri(secret=secret, account_name=user.email, issuer=settings.app_name)
    return secret, otpauth_url


async def enable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn on 2FA after confirming *code* against the previously set-up secret.

    Raises:
        HTTPException: 400 if no secret has been set up or the code is invalid.
    """
    if not user.twofa_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not initialized")
    if not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
    user.twofa_enabled = True
    await db.commit()
    await db.refresh(user)


async def disable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn off 2FA and clear its secret.

    A valid *code* is required only when 2FA is fully enabled; otherwise any
    leftover state is cleared unconditionally.

    Raises:
        HTTPException: 400 if 2FA is enabled and *code* is invalid.
    """
    if user.twofa_enabled and user.twofa_secret and not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
    user.twofa_enabled = False
    user.twofa_secret = None
    await db.commit()
    await db.refresh(user)


async def request_password_reset(
    db: AsyncSession,
    payload: RequestPasswordResetRequest,
    email_service: EmailService,
) -> None:
    """Email a password-reset link, replacing any earlier reset tokens.

    Unknown emails return silently, so this call does not reveal whether the
    address is registered.

    Raises:
        HTTPException: 503 if the reset email could not be sent (the transaction is rolled back).
    """
    user = await get_user_by_email(db, payload.email)
    if not user:
        return

    reset_token = generate_random_token()
    expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.password_reset_token_expire_hours)
    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    await auth_repository.create_password_reset_token(db, user.id, reset_token, expires_at)
    try:
        await email_service.send_password_reset_email(user.email, reset_token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to send reset email",
        ) from exc
    await db.commit()


async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
    """Set a new password from a reset token and sign the user out everywhere.

    Consumes the user's reset tokens, revokes all refresh sessions and
    invalidates every access token issued up to now.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if its user no longer exists.
    """
    record = await auth_repository.get_password_reset_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")
    if _as_utc(record.expires_at) < datetime.now(timezone.utc):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")

    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    user.password_hash = hash_password(payload.new_password)
    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    # A reset usually follows a lost or compromised password, so sessions
    # opened with the old credentials must not survive it.
    await revoke_all_refresh_sessions_for_user(user_id=user.id)
    user.access_revoked_before = datetime.now(timezone.utc)
    await db.commit()


async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
    """FastAPI dependency: resolve the user behind a bearer access token.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` if the token is
            invalid, not an access token, references a missing user, or has
            been revoked.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise credentials_error from exc
    if payload.get("type") != "access":
        raise credentials_error

    user_id = payload.get("sub")
    if not user_id or not str(user_id).isdigit():
        raise credentials_error
    user = await get_user_by_id(db, int(user_id))
    if not user:
        raise credentials_error

    if _is_access_token_revoked(user, payload):
        raise credentials_error
    return user


async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
    """Resolve the user behind an access token for WebSocket handlers.

    Raises:
        HTTPException: 401 with a step-specific detail message when the token
            is invalid, of the wrong type, malformed, references a missing
            user, or has been revoked.
    """
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
    if payload.get("type") != "access":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    user_id = payload.get("sub")
    if not user_id or not str(user_id).isdigit():
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
    user = await get_user_by_id(db, int(user_id))
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")

    if _is_access_token_revoked(user, payload):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token was revoked")
    return user


def get_email_sender() -> EmailService:
    """FastAPI dependency returning the configured email service."""
    return get_email_service()


async def get_email_status(db: AsyncSession, email: str) -> EmailStatusResponse:
    """Report whether *email* is registered and, if so, its verification and 2FA flags."""
    user = await get_user_by_email(db, email)
    if not user:
        return EmailStatusResponse(email=email, registered=False)
    return EmailStatusResponse(
        email=email,
        registered=True,
        email_verified=user.email_verified,
        twofa_enabled=user.twofa_enabled,
    )