"""Authentication service functions.

Covers registration, email verification, login, password reset, and resolving
the current user from a bearer token (HTTP and WebSocket variants).
"""

from datetime import datetime, timedelta, timezone

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import repository as auth_repository
from app.auth.schemas import (
    LoginRequest,
    RegisterRequest,
    RequestPasswordResetRequest,
    ResendVerificationRequest,
    ResetPasswordRequest,
    TokenResponse,
    VerifyEmailRequest,
)
from app.config.settings import settings
from app.database.session import get_db
from app.email.service import EmailService, get_email_service
from app.users.models import User
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
from app.utils.security import (
    create_access_token,
    create_refresh_token,
    decode_token,
    generate_random_token,
    hash_password,
    verify_password,
)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")


def _is_expired(expires_at: datetime) -> bool:
    """Return True when *expires_at* lies strictly before the current UTC time.

    A naive datetime is interpreted as UTC so it can be compared with the
    timezone-aware "now" without raising ``TypeError``.
    """
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at < datetime.now(timezone.utc)


def _parse_user_id(subject: object) -> int:
    """Convert a token ``sub`` claim into an integer user id.

    Raises:
        ValueError: if the claim is missing/falsy or is not a plain decimal
            number.
    """
    text = str(subject) if subject else ""
    # ``isdecimal`` (not ``isdigit``): ``isdigit`` also accepts characters such
    # as "²" that ``int()`` rejects, which would surface as an unhandled
    # ValueError instead of an authentication failure.
    if not text.isdecimal():
        raise ValueError("Token subject is not a numeric user id")
    return int(text)


async def _issue_email_verification_token(db: AsyncSession, user_id: int) -> str:
    """Replace the user's email verification tokens with a fresh one.

    Deletes existing verification tokens for *user_id* and stores a new token
    expiring ``settings.email_verification_token_expire_hours`` from now.
    The caller is responsible for committing the session.

    Returns:
        The newly generated token string.
    """
    verification_token = generate_random_token()
    expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
    await auth_repository.delete_email_verification_tokens_for_user(db, user_id)
    await auth_repository.create_email_verification_token(db, user_id, verification_token, expires_at)
    return verification_token


async def register_user(
    db: AsyncSession,
    payload: RegisterRequest,
    email_service: EmailService,
) -> None:
    """Create a new user and email them a verification token.

    Raises:
        HTTPException: 409 if the email or the username is already in use.
    """
    existing_email = await get_user_by_email(db, payload.email)
    if existing_email:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")

    existing_username = await get_user_by_username(db, payload.username)
    if existing_username:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")

    user = await create_user(
        db,
        email=payload.email,
        username=payload.username,
        password_hash=hash_password(payload.password),
    )

    verification_token = await _issue_email_verification_token(db, user.id)
    await db.commit()

    # The email is sent only after the commit has succeeded.
    await email_service.send_verification_email(payload.email, verification_token)


async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
    """Mark the token owner's email as verified and drop their verification tokens.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if the
            token's user no longer exists.
    """
    record = await auth_repository.get_email_verification_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")

    if _is_expired(record.expires_at):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")

    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    user.email_verified = True
    await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
    await db.commit()


async def resend_verification_email(
    db: AsyncSession,
    payload: ResendVerificationRequest,
    email_service: EmailService,
) -> None:
    """Issue and email a new verification token for an unverified user.

    Returns silently (no error, no email) when the address is unknown or the
    user is already verified.
    """
    user = await get_user_by_email(db, payload.email)
    if not user or user.email_verified:
        return

    verification_token = await _issue_email_verification_token(db, user.id)
    await db.commit()

    await email_service.send_verification_email(user.email, verification_token)


async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
    """Authenticate by email and password and return access/refresh tokens.

    Raises:
        HTTPException: 401 if the email is unknown or the password does not
            match; 403 if the user's email is not verified.
    """
    user = await get_user_by_email(db, payload.email)
    # NOTE(review): verify_password is skipped when no user is found, so the
    # two failure paths may differ in response time -- confirm whether that
    # matters for this deployment.
    if not user or not verify_password(payload.password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")

    if not user.email_verified:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")

    subject = str(user.id)
    return TokenResponse(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject),
    )


async def request_password_reset(
    db: AsyncSession,
    payload: RequestPasswordResetRequest,
    email_service: EmailService,
) -> None:
    """Issue and email a password reset token.

    Returns silently (no error, no email) when the address is unknown.
    Existing reset tokens for the user are deleted before the new one is
    stored.
    """
    user = await get_user_by_email(db, payload.email)
    if not user:
        return

    reset_token = generate_random_token()
    expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.password_reset_token_expire_hours)

    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    await auth_repository.create_password_reset_token(db, user.id, reset_token, expires_at)
    await db.commit()

    await email_service.send_password_reset_email(user.email, reset_token)


async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
    """Set a new password for the owner of a reset token and drop their reset tokens.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if the
            token's user no longer exists.
    """
    record = await auth_repository.get_password_reset_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")

    if _is_expired(record.expires_at):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")

    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    user.password_hash = hash_password(payload.new_password)
    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    await db.commit()


async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
    """Resolve the user identified by a bearer access token (FastAPI dependency).

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header if the
            token cannot be decoded, is not of type ``"access"``, has a
            missing or non-numeric ``sub`` claim, or names a user that does
            not exist.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise credentials_error from exc

    if payload.get("type") != "access":
        raise credentials_error

    try:
        user_id = _parse_user_id(payload.get("sub"))
    except ValueError as exc:
        raise credentials_error from exc

    user = await get_user_by_id(db, user_id)
    if not user:
        raise credentials_error
    return user


async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
    """Resolve the user identified by an access token passed explicitly.

    Performs the same checks as ``get_current_user`` but takes the token and
    session as plain arguments and reports a distinct detail message per
    failure.

    Raises:
        HTTPException: 401 if the token cannot be decoded, is not of type
            ``"access"``, has a missing or non-numeric ``sub`` claim, or names
            a user that does not exist.
    """
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc

    if payload.get("type") != "access":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    try:
        user_id = _parse_user_id(payload.get("sub"))
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload") from exc

    user = await get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    return user


def get_email_sender() -> EmailService:
    """Return the email service obtained from ``get_email_service()``."""
    return get_email_service()