Files
Messenger/app/auth/service.py
benya fb812c9a39
All checks were successful
CI / test (push) Successful in 40s
auth(2fa): add one-time recovery codes with regenerate/status APIs
2026-03-08 19:16:15 +03:00

491 lines
18 KiB
Python

from datetime import datetime, timedelta, timezone
import json
import secrets
from uuid import uuid4
from fastapi import Depends, HTTPException, status
from fastapi import Request
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import repository as auth_repository
from app.auth.token_store import (
RefreshSession,
get_refresh_token_user_id,
list_refresh_sessions_for_user,
revoke_all_refresh_sessions_for_user,
revoke_refresh_token_jti,
store_refresh_token_jti,
)
from app.auth.schemas import (
EmailStatusResponse,
LoginRequest,
RefreshTokenRequest,
RegisterRequest,
RequestPasswordResetRequest,
ResendVerificationRequest,
ResetPasswordRequest,
TokenResponse,
VerifyEmailRequest,
)
from app.config.settings import settings
from app.database.session import get_db
from app.email.service import EmailDeliveryError, EmailService, get_email_service
from app.users.models import User
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
from app.utils.totp import build_otpauth_uri, generate_totp_secret, verify_totp_code
from app.utils.security import (
create_access_token,
create_refresh_token,
decode_token,
generate_random_token,
hash_password,
verify_password,
)
# OAuth2 password-flow bearer scheme; its token URL is this API's login route.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="{}/auth/login".format(settings.api_v1_prefix),
)
def _refresh_ttl_seconds() -> int:
    """Return the refresh-token lifetime in seconds.

    Derived from ``settings.refresh_token_expire_days`` (days -> hours -> minutes -> seconds).
    """
    days = settings.refresh_token_expire_days
    return days * 24 * 60 * 60
async def register_user(
    db: AsyncSession,
    payload: RegisterRequest,
    email_service: EmailService,
) -> None:
    """Create a new account and email it a verification token.

    Raises:
        HTTPException: 409 if the email or username is already in use;
            503 if the verification email cannot be delivered (the session
            is rolled back in that case).
    """
    if await get_user_by_email(db, payload.email):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
    if await get_user_by_username(db, payload.username):
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")

    new_user = await create_user(
        db,
        email=payload.email,
        name=payload.name,
        username=payload.username,
        password_hash=hash_password(payload.password),
    )

    token = generate_random_token()
    lifetime = timedelta(hours=settings.email_verification_token_expire_hours)
    token_expiry = datetime.now(timezone.utc) + lifetime
    await auth_repository.delete_email_verification_tokens_for_user(db, new_user.id)
    await auth_repository.create_email_verification_token(db, new_user.id, token, token_expiry)

    # NOTE(review): if create_user commits internally, this rollback will not
    # remove the user row — confirm create_user only flushes.
    try:
        await email_service.send_verification_email(payload.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to send verification email") from exc
    else:
        await db.commit()
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
    """Mark a user's email as verified using a verification token.

    Naive ``expires_at`` values are interpreted as UTC. On success all of
    the user's verification tokens are deleted and the session is committed.

    Raises:
        HTTPException: 400 for an unknown or expired token; 404 if the
            token's user no longer exists.
    """
    token_row = await auth_repository.get_email_verification_token(db, payload.token)
    if not token_row:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")

    current_time = datetime.now(timezone.utc)
    expiry = token_row.expires_at
    if not expiry.tzinfo:
        expiry = expiry.replace(tzinfo=timezone.utc)
    if current_time > expiry:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")

    account = await get_user_by_id(db, token_row.user_id)
    if not account:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    account.email_verified = True
    await auth_repository.delete_email_verification_tokens_for_user(db, account.id)
    await db.commit()
async def resend_verification_email(
    db: AsyncSession,
    payload: ResendVerificationRequest,
    email_service: EmailService,
) -> None:
    """Issue a fresh verification token and email it to an unverified user.

    Silently does nothing when the email is unknown or already verified.
    Any previous verification tokens for the user are deleted first.

    Raises:
        HTTPException: 503 if the email cannot be delivered (the session is
            rolled back in that case).
    """
    account = await get_user_by_email(db, payload.email)
    if account is None or not account or account.email_verified:
        return

    token = generate_random_token()
    lifetime = timedelta(hours=settings.email_verification_token_expire_hours)
    token_expiry = datetime.now(timezone.utc) + lifetime
    await auth_repository.delete_email_verification_tokens_for_user(db, account.id)
    await auth_repository.create_email_verification_token(db, account.id, token, token_expiry)

    try:
        await email_service.send_verification_email(account.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to send verification email") from exc
    else:
        await db.commit()
async def login_user(
    db: AsyncSession,
    payload: LoginRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Authenticate with email/password (plus a second factor when 2FA is on).

    Order of checks: password, verified email, then — only when
    ``user.twofa_enabled`` — either a one-time recovery code (consumed on
    success, and preferred when both are supplied) or a TOTP ``otp_code``.
    On success a new refresh-token JTI is stored together with the client's
    IP address and user agent.

    Returns:
        A ``TokenResponse`` with fresh access and refresh tokens.

    Raises:
        HTTPException: 401 for bad credentials, a missing or invalid 2FA code,
            or an invalid recovery code; 403 when the email is not verified.
    """
    user = await get_user_by_email(db, payload.email)
    # NOTE(review): an unknown email skips verify_password, so response time
    # may hint at whether the email is registered.
    if not (user and verify_password(payload.password, user.password_hash)):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not user.email_verified:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")

    if user.twofa_enabled:
        if payload.recovery_code:
            consumed = await _consume_twofa_recovery_code(db, user=user, recovery_code=payload.recovery_code)
            if not consumed:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid recovery code")
        elif not payload.otp_code:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2FA code required")
        elif not user.twofa_secret or not verify_totp_code(user.twofa_secret, payload.otp_code):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code")

    session_jti = str(uuid4())
    subject = str(user.id)
    issued_refresh = create_refresh_token(subject, jti=session_jti)
    await store_refresh_token_jti(
        user_id=user.id,
        jti=session_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    return TokenResponse(access_token=create_access_token(subject), refresh_token=issued_refresh)
async def refresh_tokens(
    db: AsyncSession,
    payload: RefreshTokenRequest,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> TokenResponse:
    """Rotate a refresh token: revoke the presented one and issue a new pair.

    The presented token must decode, have ``type == "refresh"``, carry a
    purely numeric ``sub`` and a ``jti``, and that JTI must still be recorded
    in the token store for the same user, who must still exist.

    Returns:
        A ``TokenResponse`` with a new access token and a new refresh token
        whose JTI is stored with the client's IP address and user agent.

    Raises:
        HTTPException: 401 "Invalid refresh token" for any validation failure.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid refresh token",
    )
    try:
        token_payload = decode_token(payload.refresh_token)
    except ValueError as exc:
        raise credentials_error from exc
    if token_payload.get("type") != "refresh":
        raise credentials_error

    user_id = token_payload.get("sub")
    refresh_jti = token_payload.get("jti")
    # isdecimal(), not isdigit(): characters such as "²" pass isdigit() but
    # make int() raise ValueError, which would surface as a 500 instead of 401.
    if not user_id or not str(user_id).isdecimal() or not refresh_jti:
        raise credentials_error
    parsed_user_id = int(user_id)

    # NOTE(review): lookup and revoke of the old JTI are separate awaits, so two
    # concurrent requests with the same refresh token may both rotate it —
    # confirm whether the token store can do an atomic check-and-revoke.
    active_user_id = await get_refresh_token_user_id(jti=refresh_jti)
    if active_user_id is None or active_user_id != parsed_user_id:
        raise credentials_error
    user = await get_user_by_id(db, parsed_user_id)
    if not user:
        raise credentials_error

    await revoke_refresh_token_jti(jti=refresh_jti)
    new_jti = str(uuid4())
    await store_refresh_token_jti(
        user_id=user.id,
        jti=new_jti,
        ttl_seconds=_refresh_ttl_seconds(),
        ip_address=ip_address,
        user_agent=user_agent,
    )
    subject = str(user.id)
    return TokenResponse(
        access_token=create_access_token(subject),
        refresh_token=create_refresh_token(subject, jti=new_jti),
    )
async def list_user_sessions(user_id: int) -> list[RefreshSession]:
    """Return the refresh sessions the token store holds for ``user_id``."""
    sessions = await list_refresh_sessions_for_user(user_id=user_id)
    return sessions
async def revoke_user_session(*, user_id: int, jti: str) -> None:
    """Revoke one refresh session, but only if it belongs to ``user_id``.

    Unknown JTIs and JTIs owned by another user are ignored silently.
    """
    owner_id = await get_refresh_token_user_id(jti=jti)
    if owner_id is not None and owner_id == user_id:
        await revoke_refresh_token_jti(jti=jti)
async def revoke_all_user_sessions(db: AsyncSession, *, user_id: int) -> None:
    """Sign a user out everywhere.

    Drops every refresh session in the token store, then stamps
    ``access_revoked_before`` with the current UTC time so that access tokens
    issued up to now are rejected by the current-user dependencies. The
    session is committed even if the user row is not found.
    """
    await revoke_all_refresh_sessions_for_user(user_id=user_id)
    account = await get_user_by_id(db, user_id)
    if account:
        account.access_revoked_before = datetime.now(timezone.utc)
    await db.commit()
def _token_issued_at(payload: dict) -> datetime | None:
raw_iat = payload.get("iat")
if isinstance(raw_iat, datetime):
return raw_iat if raw_iat.tzinfo else raw_iat.replace(tzinfo=timezone.utc)
if isinstance(raw_iat, (int, float)):
try:
return datetime.fromtimestamp(float(raw_iat), tz=timezone.utc)
except Exception:
return None
if isinstance(raw_iat, str):
try:
parsed = datetime.fromisoformat(raw_iat.replace("Z", "+00:00"))
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
except Exception:
return None
return None
def get_request_metadata(request: Request) -> tuple[str | None, str | None]:
    """Return ``(client_ip, user_agent)`` for a request; either may be ``None``."""
    peer = request.client
    host = peer.host if peer else None
    return host, request.headers.get("user-agent")
def get_access_session_info(token: str) -> tuple[str, datetime] | None:
    """Return ``(jti, issued_at)`` for a valid access token, else ``None``.

    ``None`` is returned when decoding fails, the token is not an access
    token, or its ``jti`` is missing/empty/not a string. When ``iat`` cannot
    be read, the current UTC time stands in for the issue time.
    """
    try:
        claims = decode_token(token)
    except ValueError:
        return None
    if claims.get("type") != "access":
        return None
    session_id = claims.get("jti")
    if not (isinstance(session_id, str) and session_id):
        return None
    return session_id, _token_issued_at(claims) or datetime.now(timezone.utc)
async def setup_twofa(db: AsyncSession, user: User) -> tuple[str, str]:
    """Prepare TOTP 2FA for a user who has not enabled it yet.

    Reuses a pending (not yet enabled) secret if one exists, otherwise
    generates and stores a new one.

    Returns:
        ``(secret, otpauth_url)`` for display / QR-code provisioning.

    Raises:
        HTTPException: 400 if 2FA is already enabled with a secret.
    """
    if user.twofa_enabled and user.twofa_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled")
    totp_secret = user.twofa_secret
    if not totp_secret:
        totp_secret = generate_totp_secret()
        user.twofa_secret = totp_secret
    await db.commit()
    await db.refresh(user)
    provisioning_uri = build_otpauth_uri(secret=totp_secret, account_name=user.email, issuer=settings.app_name)
    return totp_secret, provisioning_uri
async def enable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn on 2FA after the user proves the pending secret with a TOTP code.

    Raises:
        HTTPException: 400 if no secret has been set up, or the code is wrong.
    """
    pending_secret = user.twofa_secret
    if not pending_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not initialized")
    if not verify_totp_code(pending_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
    user.twofa_enabled = True
    await db.commit()
    await db.refresh(user)
async def disable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
    """Turn off 2FA and wipe its secret and recovery-code hashes.

    When 2FA is active (enabled *and* a secret is stored), ``code`` must be a
    valid TOTP code. Otherwise any leftover or partial 2FA state is cleared
    and ``code`` is ignored.

    Raises:
        HTTPException: 400 if 2FA is active and ``code`` does not verify.
    """
    twofa_active = user.twofa_enabled and user.twofa_secret
    if twofa_active and not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
    user.twofa_enabled = False
    user.twofa_secret = None
    user.twofa_recovery_codes_hashes = None
    await db.commit()
    await db.refresh(user)
def _normalize_recovery_code(code: str) -> str:
return "".join(ch for ch in code.upper() if ch.isalnum())
def _generate_recovery_codes(count: int = 8) -> list[str]:
alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
codes: list[str] = []
for _ in range(max(1, count)):
raw = "".join(secrets.choice(alphabet) for _ in range(10))
codes.append(f"{raw[:5]}-{raw[5:]}")
return codes
def _load_twofa_recovery_hashes(user: User) -> list[str]:
    """Decode the user's stored recovery-code hashes from their JSON column.

    Best-effort: an empty column, undecodable JSON, or a non-list value all
    yield ``[]``. Non-string and empty entries are dropped.
    """
    serialized = user.twofa_recovery_codes_hashes
    if not serialized:
        return []
    try:
        decoded = json.loads(serialized)
    except Exception:
        # Deliberately broad: corrupt data must read as "no codes", not crash login.
        return []
    if not isinstance(decoded, list):
        return []
    return [entry for entry in decoded if isinstance(entry, str) and entry]
async def regenerate_twofa_recovery_codes(db: AsyncSession, user: User, *, code: str) -> list[str]:
    """Replace the user's recovery codes with a freshly generated set.

    Requires 2FA to be enabled and a valid TOTP ``code``. Only password hashes
    of the normalized codes are stored; the plain codes are returned once so
    they can be shown to the user.

    Raises:
        HTTPException: 400 if 2FA is not enabled or ``code`` is invalid.
    """
    if not (user.twofa_enabled and user.twofa_secret):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled")
    if not verify_totp_code(user.twofa_secret, code):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
    plain_codes = _generate_recovery_codes()
    digests = [hash_password(_normalize_recovery_code(plain)) for plain in plain_codes]
    user.twofa_recovery_codes_hashes = json.dumps(digests, ensure_ascii=True)
    await db.commit()
    await db.refresh(user)
    return plain_codes
def get_twofa_recovery_codes_remaining(user: User) -> int:
    """Return how many unused recovery-code hashes are stored for ``user``."""
    stored_hashes = _load_twofa_recovery_hashes(user)
    return len(stored_hashes)
async def _consume_twofa_recovery_code(db: AsyncSession, *, user: User, recovery_code: str) -> bool:
    """Check a recovery code and, on a match, remove it so it cannot be reused.

    Returns:
        ``True`` (after committing the shortened hash list, or ``None`` when
        the last code was used) if the normalized code matches a stored hash,
        otherwise ``False`` with nothing written.
    """
    # NOTE(review): read-modify-write without a row lock; two concurrent logins
    # with the same code could both succeed — confirm if that matters here.
    candidate = _normalize_recovery_code(recovery_code)
    if not candidate:
        return False
    stored = _load_twofa_recovery_hashes(user)
    hit = next((pos for pos, digest in enumerate(stored) if verify_password(candidate, digest)), None)
    if hit is None:
        return False
    stored.pop(hit)
    user.twofa_recovery_codes_hashes = json.dumps(stored, ensure_ascii=True) if stored else None
    await db.commit()
    return True
async def request_password_reset(
    db: AsyncSession,
    payload: RequestPasswordResetRequest,
    email_service: EmailService,
) -> None:
    """Email a password-reset token to the account owning ``payload.email``.

    Unknown emails are ignored silently. Earlier reset tokens for the user are
    deleted before the new one is created.

    Raises:
        HTTPException: 503 if the email cannot be delivered (the session is
            rolled back in that case).
    """
    account = await get_user_by_email(db, payload.email)
    if not account:
        return

    token = generate_random_token()
    lifetime = timedelta(hours=settings.password_reset_token_expire_hours)
    token_expiry = datetime.now(timezone.utc) + lifetime
    await auth_repository.delete_password_reset_tokens_for_user(db, account.id)
    await auth_repository.create_password_reset_token(db, account.id, token, token_expiry)

    try:
        await email_service.send_password_reset_email(account.email, token)
    except EmailDeliveryError as exc:
        await db.rollback()
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to send reset email") from exc
    else:
        await db.commit()
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
    """Set a new password using a one-time reset token and sign the user out everywhere.

    Naive ``expires_at`` values are interpreted as UTC. On success the
    password hash is replaced, all refresh sessions are revoked, access tokens
    issued up to now are invalidated via ``access_revoked_before``, and all of
    the user's reset tokens are deleted.

    Raises:
        HTTPException: 400 for an unknown or expired token; 404 if the
            token's user no longer exists.
    """
    record = await auth_repository.get_password_reset_token(db, payload.token)
    if not record:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")
    now = datetime.now(timezone.utc)
    expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
    if expires_at < now:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")
    user = await get_user_by_id(db, record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    user.password_hash = hash_password(payload.new_password)
    # A reset often follows a suspected compromise, so sessions opened with the
    # old password must not outlive it (same steps as revoke_all_user_sessions).
    # Refresh sessions are revoked before the commit: if the commit fails, the
    # user is merely logged out, which is the safe direction.
    await revoke_all_refresh_sessions_for_user(user_id=user.id)
    user.access_revoked_before = datetime.now(timezone.utc)
    await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
    await db.commit()
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
    """FastAPI dependency returning the user behind a bearer access token.

    The token must decode, have ``type == "access"``, carry a purely numeric
    ``sub`` naming an existing user, and — when the user has an
    ``access_revoked_before`` timestamp and the token has a readable ``iat`` —
    must have been issued strictly after that timestamp.

    Raises:
        HTTPException: 401 "Could not validate credentials" with a
            ``WWW-Authenticate: Bearer`` header for any failure above.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise credentials_error from exc
    if payload.get("type") != "access":
        raise credentials_error
    user_id = payload.get("sub")
    # isdecimal(), not isdigit(): characters such as "²" pass isdigit() but
    # make int() raise ValueError, which would surface as a 500 instead of 401.
    if not user_id or not str(user_id).isdecimal():
        raise credentials_error
    user = await get_user_by_id(db, int(user_id))
    if not user:
        raise credentials_error
    issued_at = _token_issued_at(payload)
    # NOTE(review): a token with no readable ``iat`` skips this revocation check
    # entirely — confirm create_access_token always sets ``iat``.
    if user.access_revoked_before is not None and issued_at is not None:
        revoked_before = (
            user.access_revoked_before
            if user.access_revoked_before.tzinfo
            else user.access_revoked_before.replace(tzinfo=timezone.utc)
        )
        if issued_at <= revoked_before:
            raise credentials_error
    return user
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
    """Resolve the user behind an access token for a WebSocket connection.

    Applies the same checks as ``get_current_user``, but each failure gets
    its own 401 detail and no ``WWW-Authenticate`` header is set.

    Raises:
        HTTPException: 401 with detail "Invalid token", "Invalid token type",
            "Invalid token payload", "User not found" or "Token was revoked".
    """
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
    if payload.get("type") != "access":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
    user_id = payload.get("sub")
    # isdecimal(), not isdigit(): characters such as "²" pass isdigit() but
    # make int() raise ValueError, which would escape as an unhandled error.
    if not user_id or not str(user_id).isdecimal():
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
    user = await get_user_by_id(db, int(user_id))
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    issued_at = _token_issued_at(payload)
    # NOTE(review): a token with no readable ``iat`` skips this revocation check.
    if user.access_revoked_before is not None and issued_at is not None:
        revoked_before = (
            user.access_revoked_before
            if user.access_revoked_before.tzinfo
            else user.access_revoked_before.replace(tzinfo=timezone.utc)
        )
        if issued_at <= revoked_before:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token was revoked")
    return user
def get_email_sender() -> EmailService:
    """Return the email service obtained from ``get_email_service()``."""
    sender = get_email_service()
    return sender
async def get_email_status(db: AsyncSession, email: str) -> EmailStatusResponse:
    """Report whether ``email`` is registered and, if so, its verification/2FA flags.

    NOTE(review): this reveals account existence and 2FA status to the caller;
    confirm the endpoint exposing it is intended to and is rate-limited.
    """
    account = await get_user_by_email(db, email)
    if account:
        return EmailStatusResponse(
            email=email,
            registered=True,
            email_verified=account.email_verified,
            twofa_enabled=account.twofa_enabled,
        )
    return EmailStatusResponse(email=email, registered=False)