Files
Messenger/app/auth/service.py
2026-03-07 21:31:38 +03:00

196 lines
7.2 KiB
Python

from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import repository as auth_repository
from app.auth.schemas import (
LoginRequest,
RegisterRequest,
RequestPasswordResetRequest,
ResendVerificationRequest,
ResetPasswordRequest,
TokenResponse,
VerifyEmailRequest,
)
from app.config.settings import settings
from app.database.session import get_db
from app.email.service import EmailService, get_email_service
from app.users.models import User
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
from app.utils.security import (
create_access_token,
create_refresh_token,
decode_token,
generate_random_token,
hash_password,
verify_password,
)
# Bearer-token scheme for the OAuth2 password flow; tokenUrl points at this API's /auth/login route.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="{}/auth/login".format(settings.api_v1_prefix),
)
async def register_user(
    db: AsyncSession,
    payload: RegisterRequest,
    email_service: EmailService,
) -> None:
    """Create a new account and email it a freshly issued verification token.

    Args:
        db: Async database session; it is committed before the email is sent.
        payload: Registration data (email, username, password).
        email_service: Service used to deliver the verification email.

    Raises:
        HTTPException: 409 if the email is already registered or the username is taken.
    """
    if await get_user_by_email(db, payload.email):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
    if await get_user_by_username(db, payload.username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")

    hashed = hash_password(payload.password)
    new_user = await create_user(
        db,
        email=payload.email,
        username=payload.username,
        password_hash=hashed,
    )

    token = generate_random_token()
    issued_at = datetime.now(timezone.utc)
    token_expiry = issued_at + timedelta(hours=settings.email_verification_token_expire_hours)

    # Drop any earlier verification tokens for this user before storing the new one.
    await auth_repository.delete_email_verification_tokens_for_user(db, new_user.id)
    await auth_repository.create_email_verification_token(db, new_user.id, token, token_expiry)
    await db.commit()

    # The email goes out only after the commit, so the token it carries is already persisted.
    await email_service.send_verification_email(payload.email, token)
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
    """Mark the owner of a valid, unexpired verification token as email-verified.

    On success, all of the user's verification tokens are removed and the session is committed.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if its user no longer exists.
    """
    token_row = await auth_repository.get_email_verification_token(db, payload.token)
    if not token_row:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")

    current_time = datetime.now(timezone.utc)
    deadline = token_row.expires_at
    if not deadline.tzinfo:
        # A naive timestamp is interpreted as UTC so it can be compared with an aware "now".
        deadline = deadline.replace(tzinfo=timezone.utc)
    if deadline < current_time:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")

    owner = await get_user_by_id(db, token_row.user_id)
    if not owner:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    owner.email_verified = True
    await auth_repository.delete_email_verification_tokens_for_user(db, owner.id)
    await db.commit()
async def resend_verification_email(
    db: AsyncSession,
    payload: ResendVerificationRequest,
    email_service: EmailService,
) -> None:
    """Replace the user's verification token and email the new one.

    Does nothing (and raises nothing) when no user has ``payload.email`` or the
    user is already verified.
    """
    account = await get_user_by_email(db, payload.email)
    needs_email = account and not account.email_verified
    if not needs_email:
        return

    fresh_token = generate_random_token()
    issued_at = datetime.now(timezone.utc)
    token_expiry = issued_at + timedelta(hours=settings.email_verification_token_expire_hours)

    # Drop previously issued tokens before storing the new one.
    await auth_repository.delete_email_verification_tokens_for_user(db, account.id)
    await auth_repository.create_email_verification_token(db, account.id, fresh_token, token_expiry)
    await db.commit()

    await email_service.send_verification_email(account.email, fresh_token)
# Lazily computed hash of a random secret. It is checked against on the unknown-email login path so
# that path does the same password-verification work as a real account (see login_user).
_dummy_password_hash = None


def _get_dummy_password_hash():
    """Return a cached ``hash_password`` result for a random, never-issued secret.

    It is computed on first use. Concurrent first calls may each compute one,
    which is harmless because any such hash serves the purpose.
    """
    global _dummy_password_hash
    if _dummy_password_hash is None:
        _dummy_password_hash = hash_password(generate_random_token())
    return _dummy_password_hash


async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
    """Authenticate by email and password and issue an access/refresh token pair.

    Args:
        db: Async database session.
        payload: Login credentials (email and password).

    Returns:
        TokenResponse containing access and refresh tokens whose subject is the user's id.

    Raises:
        HTTPException: 401 for an unknown email or wrong password (same detail for both);
            403 if the credentials are valid but the email is not yet verified.
    """
    user = await get_user_by_email(db, payload.email)
    if not user:
        # Run a password check against a throwaway hash anyway. Without it, an
        # unknown email returns noticeably faster than a wrong password, and the
        # response time reveals which emails are registered. This assumes
        # verify_password dominates the cost of the request.
        verify_password(payload.password, _get_dummy_password_hash())
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not verify_password(payload.password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not user.email_verified:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
    return TokenResponse(
        access_token=create_access_token(str(user.id)),
        refresh_token=create_refresh_token(str(user.id)),
    )
async def request_password_reset(
    db: AsyncSession,
    payload: RequestPasswordResetRequest,
    email_service: EmailService,
) -> None:
    """Issue a new password-reset token and email it to the account's address.

    When no account matches ``payload.email`` the function returns without error,
    which gives the same outcome as the success path.
    """
    account = await get_user_by_email(db, payload.email)
    if account:
        fresh_token = generate_random_token()
        issued_at = datetime.now(timezone.utc)
        token_expiry = issued_at + timedelta(hours=settings.password_reset_token_expire_hours)

        # Drop any earlier reset tokens so only the newly issued one is stored.
        await auth_repository.delete_password_reset_tokens_for_user(db, account.id)
        await auth_repository.create_password_reset_token(db, account.id, fresh_token, token_expiry)
        await db.commit()

        await email_service.send_password_reset_email(account.email, fresh_token)
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
    """Set a new password for the owner of a valid, unexpired reset token.

    On success, all of the user's reset tokens are removed and the session is committed.

    Raises:
        HTTPException: 400 if the token is unknown or expired; 404 if its user no longer exists.
    """
    token_row = await auth_repository.get_password_reset_token(db, payload.token)
    if not token_row:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")

    current_time = datetime.now(timezone.utc)
    deadline = token_row.expires_at
    if not deadline.tzinfo:
        # A naive timestamp is interpreted as UTC so it can be compared with an aware "now".
        deadline = deadline.replace(tzinfo=timezone.utc)
    if deadline < current_time:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")

    owner = await get_user_by_id(db, token_row.user_id)
    if not owner:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    owner.password_hash = hash_password(payload.new_password)
    await auth_repository.delete_password_reset_tokens_for_user(db, owner.id)
    await db.commit()
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
    """Resolve the authenticated user from a bearer access token.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.
        db: Async database session provided by ``get_db``.

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token cannot be
            decoded, is not an access token, has a missing or non-numeric ``sub``,
            or refers to a user that does not exist.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise credentials_error from exc
    if payload.get("type") != "access":
        raise credentials_error
    user_id = payload.get("sub")
    # Use isdecimal(), not isdigit(): isdigit() also accepts characters such as "²"
    # that int() rejects, which would escape as an unhandled ValueError (HTTP 500).
    # isdecimal() accepts exactly the digit characters int() can parse.
    if not user_id or not str(user_id).isdecimal():
        raise credentials_error
    user = await get_user_by_id(db, int(str(user_id)))
    if not user:
        raise credentials_error
    return user
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
    """Resolve the user for a WebSocket connection from an access token.

    Args:
        token: Raw access token supplied by the WebSocket client.
        db: Async database session.

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with a specific detail if the token is undecodable,
            is not an access token, has a missing or non-numeric ``sub``,
            or refers to a user that does not exist.
    """
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
    if payload.get("type") != "access":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
    user_id = payload.get("sub")
    # Use isdecimal(), not isdigit(): isdigit() accepts characters like "²" that
    # int() rejects, which would raise an unhandled ValueError instead of a 401.
    if not user_id or not str(user_id).isdecimal():
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
    user = await get_user_by_id(db, int(str(user_id)))
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    return user
def get_email_sender() -> EmailService:
    """Return the email service produced by ``get_email_service``.

    Presumably used as a route dependency; confirm against callers.
    """
    sender = get_email_service()
    return sender