auth(2fa): add one-time recovery codes with regenerate/status APIs
All checks were successful
CI / test (push) Successful in 40s
All checks were successful
CI / test (push) Successful in 40s
This commit is contained in:
@@ -16,6 +16,8 @@ from app.auth.schemas import (
|
||||
TokenResponse,
|
||||
SessionRead,
|
||||
TwoFactorCodeRequest,
|
||||
TwoFactorRecoveryCodesRead,
|
||||
TwoFactorRecoveryStatusRead,
|
||||
TwoFactorSetupRead,
|
||||
VerifyEmailRequest,
|
||||
)
|
||||
@@ -37,6 +39,8 @@ from app.auth.service import (
|
||||
resend_verification_email,
|
||||
reset_password,
|
||||
setup_twofa,
|
||||
regenerate_twofa_recovery_codes,
|
||||
get_twofa_recovery_codes_remaining,
|
||||
verify_email,
|
||||
oauth2_scheme,
|
||||
)
|
||||
@@ -224,3 +228,20 @@ async def disable_2fa(
|
||||
) -> MessageResponse:
|
||||
await disable_twofa(db, current_user, code=payload.code)
|
||||
return MessageResponse(message="2FA disabled")
|
||||
|
||||
|
||||
@router.post("/2fa/recovery-codes/regenerate", response_model=TwoFactorRecoveryCodesRead)
|
||||
async def regenerate_2fa_recovery_codes(
|
||||
payload: TwoFactorCodeRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TwoFactorRecoveryCodesRead:
|
||||
codes = await regenerate_twofa_recovery_codes(db, current_user, code=payload.code)
|
||||
return TwoFactorRecoveryCodesRead(codes=codes)
|
||||
|
||||
|
||||
@router.get("/2fa/recovery-codes/status", response_model=TwoFactorRecoveryStatusRead)
|
||||
async def get_2fa_recovery_codes_status(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TwoFactorRecoveryStatusRead:
|
||||
return TwoFactorRecoveryStatusRead(remaining_codes=get_twofa_recovery_codes_remaining(current_user))
|
||||
|
||||
@@ -15,6 +15,7 @@ class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
otp_code: str | None = Field(default=None, min_length=6, max_length=8)
|
||||
recovery_code: str | None = Field(default=None, min_length=6, max_length=32)
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
@@ -86,6 +87,14 @@ class TwoFactorCodeRequest(BaseModel):
|
||||
code: str = Field(min_length=6, max_length=8)
|
||||
|
||||
|
||||
class TwoFactorRecoveryCodesRead(BaseModel):
|
||||
codes: list[str]
|
||||
|
||||
|
||||
class TwoFactorRecoveryStatusRead(BaseModel):
|
||||
remaining_codes: int
|
||||
|
||||
|
||||
class EmailStatusResponse(BaseModel):
|
||||
email: EmailStr
|
||||
registered: bool
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import json
|
||||
import secrets
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
@@ -135,10 +137,15 @@ async def login_user(
|
||||
if not user.email_verified:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
|
||||
if user.twofa_enabled:
|
||||
if not payload.otp_code:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2FA code required")
|
||||
if not user.twofa_secret or not verify_totp_code(user.twofa_secret, payload.otp_code):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code")
|
||||
if payload.recovery_code:
|
||||
recovery_used = await _consume_twofa_recovery_code(db, user=user, recovery_code=payload.recovery_code)
|
||||
if not recovery_used:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid recovery code")
|
||||
else:
|
||||
if not payload.otp_code:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2FA code required")
|
||||
if not user.twofa_secret or not verify_totp_code(user.twofa_secret, payload.otp_code):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code")
|
||||
|
||||
refresh_jti = str(uuid4())
|
||||
refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
|
||||
@@ -286,6 +293,7 @@ async def disable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
|
||||
if not user.twofa_enabled or not user.twofa_secret:
|
||||
user.twofa_enabled = False
|
||||
user.twofa_secret = None
|
||||
user.twofa_recovery_codes_hashes = None
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return
|
||||
@@ -293,10 +301,78 @@ async def disable_twofa(db: AsyncSession, user: User, *, code: str) -> None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
|
||||
user.twofa_enabled = False
|
||||
user.twofa_secret = None
|
||||
user.twofa_recovery_codes_hashes = None
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
def _normalize_recovery_code(code: str) -> str:
|
||||
return "".join(ch for ch in code.upper() if ch.isalnum())
|
||||
|
||||
|
||||
def _generate_recovery_codes(count: int = 8) -> list[str]:
|
||||
alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
codes: list[str] = []
|
||||
for _ in range(max(1, count)):
|
||||
raw = "".join(secrets.choice(alphabet) for _ in range(10))
|
||||
codes.append(f"{raw[:5]}-{raw[5:]}")
|
||||
return codes
|
||||
|
||||
|
||||
def _load_twofa_recovery_hashes(user: User) -> list[str]:
|
||||
raw = user.twofa_recovery_codes_hashes
|
||||
if not raw:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(parsed, list):
|
||||
return []
|
||||
hashes: list[str] = []
|
||||
for item in parsed:
|
||||
if isinstance(item, str) and item:
|
||||
hashes.append(item)
|
||||
return hashes
|
||||
|
||||
|
||||
async def regenerate_twofa_recovery_codes(db: AsyncSession, user: User, *, code: str) -> list[str]:
|
||||
if not user.twofa_enabled or not user.twofa_secret:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled")
|
||||
if not verify_totp_code(user.twofa_secret, code):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid 2FA code")
|
||||
codes = _generate_recovery_codes()
|
||||
normalized = [_normalize_recovery_code(item) for item in codes]
|
||||
user.twofa_recovery_codes_hashes = json.dumps([hash_password(item) for item in normalized], ensure_ascii=True)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return codes
|
||||
|
||||
|
||||
def get_twofa_recovery_codes_remaining(user: User) -> int:
|
||||
return len(_load_twofa_recovery_hashes(user))
|
||||
|
||||
|
||||
async def _consume_twofa_recovery_code(db: AsyncSession, *, user: User, recovery_code: str) -> bool:
|
||||
prepared = _normalize_recovery_code(recovery_code)
|
||||
if not prepared:
|
||||
return False
|
||||
hashes = _load_twofa_recovery_hashes(user)
|
||||
if not hashes:
|
||||
return False
|
||||
match_index = -1
|
||||
for idx, code_hash in enumerate(hashes):
|
||||
if verify_password(prepared, code_hash):
|
||||
match_index = idx
|
||||
break
|
||||
if match_index < 0:
|
||||
return False
|
||||
del hashes[match_index]
|
||||
user.twofa_recovery_codes_hashes = json.dumps(hashes, ensure_ascii=True) if hashes else None
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def request_password_reset(
|
||||
db: AsyncSession,
|
||||
payload: RequestPasswordResetRequest,
|
||||
|
||||
Reference in New Issue
Block a user