- store refresh session metadata in redis (ip/user-agent/created_at) - add auth APIs: list sessions, revoke one, revoke all - add web privacy UI for active sessions
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -11,12 +13,17 @@ from app.auth.schemas import (
|
||||
ResendVerificationRequest,
|
||||
ResetPasswordRequest,
|
||||
TokenResponse,
|
||||
SessionRead,
|
||||
VerifyEmailRequest,
|
||||
)
|
||||
from app.auth.service import (
|
||||
get_current_user,
|
||||
get_email_sender,
|
||||
get_request_metadata,
|
||||
login_user,
|
||||
list_user_sessions,
|
||||
revoke_all_user_sessions,
|
||||
revoke_user_session,
|
||||
refresh_tokens,
|
||||
register_user,
|
||||
request_password_reset,
|
||||
@@ -56,7 +63,8 @@ async def login(payload: LoginRequest, request: Request, db: AsyncSession = Depe
|
||||
scope="auth_login",
|
||||
limit=settings.login_rate_limit_per_minute,
|
||||
)
|
||||
return await login_user(db, payload)
|
||||
ip_address, user_agent = get_request_metadata(request)
|
||||
return await login_user(db, payload, ip_address=ip_address, user_agent=user_agent)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
@@ -70,7 +78,8 @@ async def refresh(
|
||||
scope="auth_refresh",
|
||||
limit=settings.refresh_rate_limit_per_minute,
|
||||
)
|
||||
return await refresh_tokens(db, payload)
|
||||
ip_address, user_agent = get_request_metadata(request)
|
||||
return await refresh_tokens(db, payload, ip_address=ip_address, user_agent=user_agent)
|
||||
|
||||
|
||||
@router.post("/verify-email", response_model=MessageResponse)
|
||||
@@ -120,3 +129,29 @@ async def reset_password_endpoint(payload: ResetPasswordRequest, db: AsyncSessio
|
||||
@router.get("/me", response_model=AuthUserResponse)
|
||||
async def me(current_user: User = Depends(get_current_user)) -> AuthUserResponse:
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[SessionRead])
|
||||
async def list_sessions(current_user: User = Depends(get_current_user)) -> list[SessionRead]:
|
||||
sessions = await list_user_sessions(current_user.id)
|
||||
out: list[SessionRead] = []
|
||||
for item in sessions:
|
||||
out.append(
|
||||
SessionRead(
|
||||
jti=item.jti,
|
||||
created_at=datetime.fromtimestamp(item.created_at, tz=timezone.utc),
|
||||
ip_address=item.ip_address,
|
||||
user_agent=item.user_agent,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@router.delete("/sessions/{jti}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def revoke_session(jti: str, current_user: User = Depends(get_current_user)) -> None:
|
||||
await revoke_user_session(user_id=current_user.id, jti=jti)
|
||||
|
||||
|
||||
@router.delete("/sessions", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def revoke_all_sessions(current_user: User = Depends(get_current_user)) -> None:
|
||||
await revoke_all_user_sessions(user_id=current_user.id)
|
||||
|
||||
@@ -58,3 +58,10 @@ class AuthUserResponse(BaseModel):
|
||||
email_verified: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SessionRead(BaseModel):
|
||||
jti: str
|
||||
created_at: datetime
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
|
||||
@@ -2,11 +2,19 @@ from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import repository as auth_repository
|
||||
from app.auth.token_store import get_refresh_token_user_id, revoke_refresh_token_jti, store_refresh_token_jti
|
||||
from app.auth.token_store import (
|
||||
RefreshSession,
|
||||
get_refresh_token_user_id,
|
||||
list_refresh_sessions_for_user,
|
||||
revoke_all_refresh_sessions_for_user,
|
||||
revoke_refresh_token_jti,
|
||||
store_refresh_token_jti,
|
||||
)
|
||||
from app.auth.schemas import (
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
@@ -111,7 +119,13 @@ async def resend_verification_email(
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
|
||||
async def login_user(
|
||||
db: AsyncSession,
|
||||
payload: LoginRequest,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TokenResponse:
|
||||
user = await get_user_by_email(db, payload.email)
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
@@ -121,14 +135,26 @@ async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
|
||||
|
||||
refresh_jti = str(uuid4())
|
||||
refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
|
||||
await store_refresh_token_jti(user_id=user.id, jti=refresh_jti, ttl_seconds=_refresh_ttl_seconds())
|
||||
await store_refresh_token_jti(
|
||||
user_id=user.id,
|
||||
jti=refresh_jti,
|
||||
ttl_seconds=_refresh_ttl_seconds(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(str(user.id)),
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> TokenResponse:
|
||||
async def refresh_tokens(
|
||||
db: AsyncSession,
|
||||
payload: RefreshTokenRequest,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> TokenResponse:
|
||||
credentials_error = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
@@ -155,13 +181,40 @@ async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> Toke
|
||||
|
||||
await revoke_refresh_token_jti(jti=refresh_jti)
|
||||
new_jti = str(uuid4())
|
||||
await store_refresh_token_jti(user_id=int(user_id), jti=new_jti, ttl_seconds=_refresh_ttl_seconds())
|
||||
await store_refresh_token_jti(
|
||||
user_id=int(user_id),
|
||||
jti=new_jti,
|
||||
ttl_seconds=_refresh_ttl_seconds(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(str(user_id)),
|
||||
refresh_token=create_refresh_token(str(user_id), jti=new_jti),
|
||||
)
|
||||
|
||||
|
||||
async def list_user_sessions(user_id: int) -> list[RefreshSession]:
|
||||
return await list_refresh_sessions_for_user(user_id=user_id)
|
||||
|
||||
|
||||
async def revoke_user_session(*, user_id: int, jti: str) -> None:
|
||||
active_user_id = await get_refresh_token_user_id(jti=jti)
|
||||
if active_user_id is None or active_user_id != user_id:
|
||||
return
|
||||
await revoke_refresh_token_jti(jti=jti)
|
||||
|
||||
|
||||
async def revoke_all_user_sessions(*, user_id: int) -> None:
|
||||
await revoke_all_refresh_sessions_for_user(user_id=user_id)
|
||||
|
||||
|
||||
def get_request_metadata(request: Request) -> tuple[str | None, str | None]:
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
return ip_address, user_agent
|
||||
|
||||
|
||||
async def request_password_reset(
|
||||
db: AsyncSession,
|
||||
payload: RequestPasswordResetRequest,
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from app.utils.redis_client import get_redis_client
|
||||
|
||||
_fallback_tokens: dict[str, tuple[int, float]] = {}
|
||||
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}
|
||||
_fallback_user_sessions: dict[int, set[str]] = {}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RefreshSession:
|
||||
jti: str
|
||||
user_id: int
|
||||
created_at: float
|
||||
user_agent: str | None = None
|
||||
ip_address: str | None = None
|
||||
|
||||
|
||||
def _cleanup_fallback() -> None:
|
||||
@@ -12,15 +24,46 @@ def _cleanup_fallback() -> None:
|
||||
expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
|
||||
for jti in expired:
|
||||
_fallback_tokens.pop(jti, None)
|
||||
_fallback_token_meta.pop(jti, None)
|
||||
for user_id, session_ids in list(_fallback_user_sessions.items()):
|
||||
live = {jti for jti in session_ids if jti in _fallback_tokens}
|
||||
if live:
|
||||
_fallback_user_sessions[user_id] = live
|
||||
else:
|
||||
_fallback_user_sessions.pop(user_id, None)
|
||||
|
||||
|
||||
async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
|
||||
async def store_refresh_token_jti(
|
||||
*,
|
||||
user_id: int,
|
||||
jti: str,
|
||||
ttl_seconds: int,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
created_at = time.time()
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
|
||||
await redis.hset(
|
||||
f"auth:refresh_meta:{jti}",
|
||||
mapping={
|
||||
"user_id": str(user_id),
|
||||
"created_at": str(created_at),
|
||||
"ip_address": ip_address or "",
|
||||
"user_agent": user_agent or "",
|
||||
},
|
||||
)
|
||||
await redis.expire(f"auth:refresh_meta:{jti}", ttl_seconds)
|
||||
await redis.sadd(f"auth:user_refresh:{user_id}", jti)
|
||||
await redis.expire(f"auth:user_refresh:{user_id}", ttl_seconds + 86400)
|
||||
except RedisError:
|
||||
_cleanup_fallback()
|
||||
_fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)
|
||||
_fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, time.time() + ttl_seconds)
|
||||
if user_id not in _fallback_user_sessions:
|
||||
_fallback_user_sessions[user_id] = set()
|
||||
_fallback_user_sessions[user_id].add(jti)
|
||||
|
||||
|
||||
async def get_refresh_token_user_id(*, jti: str) -> int | None:
|
||||
@@ -41,6 +84,81 @@ async def get_refresh_token_user_id(*, jti: str) -> int | None:
|
||||
async def revoke_refresh_token_jti(*, jti: str) -> None:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
user_id = await redis.get(f"auth:refresh:{jti}")
|
||||
await redis.delete(f"auth:refresh:{jti}")
|
||||
await redis.delete(f"auth:refresh_meta:{jti}")
|
||||
if user_id and str(user_id).isdigit():
|
||||
await redis.srem(f"auth:user_refresh:{int(user_id)}", jti)
|
||||
except RedisError:
|
||||
user_info = _fallback_tokens.get(jti)
|
||||
_fallback_tokens.pop(jti, None)
|
||||
_fallback_token_meta.pop(jti, None)
|
||||
if user_info:
|
||||
user_id = user_info[0]
|
||||
if user_id in _fallback_user_sessions:
|
||||
_fallback_user_sessions[user_id].discard(jti)
|
||||
if not _fallback_user_sessions[user_id]:
|
||||
_fallback_user_sessions.pop(user_id, None)
|
||||
|
||||
|
||||
async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
session_ids = await redis.smembers(f"auth:user_refresh:{user_id}")
|
||||
sessions: list[RefreshSession] = []
|
||||
stale: list[str] = []
|
||||
for raw_jti in session_ids:
|
||||
jti = raw_jti.decode("utf-8") if isinstance(raw_jti, bytes) else str(raw_jti)
|
||||
owner = await redis.get(f"auth:refresh:{jti}")
|
||||
if not owner or not str(owner).isdigit() or int(owner) != user_id:
|
||||
stale.append(jti)
|
||||
continue
|
||||
meta = await redis.hgetall(f"auth:refresh_meta:{jti}")
|
||||
created_at_raw = meta.get("created_at") if isinstance(meta, dict) else None
|
||||
if isinstance(created_at_raw, bytes):
|
||||
created_at_raw = created_at_raw.decode("utf-8")
|
||||
created_at = float(created_at_raw) if created_at_raw else time.time()
|
||||
ip_raw = meta.get("ip_address") if isinstance(meta, dict) else ""
|
||||
ua_raw = meta.get("user_agent") if isinstance(meta, dict) else ""
|
||||
ip_address = ip_raw.decode("utf-8") if isinstance(ip_raw, bytes) else (str(ip_raw) if ip_raw else None)
|
||||
user_agent = ua_raw.decode("utf-8") if isinstance(ua_raw, bytes) else (str(ua_raw) if ua_raw else None)
|
||||
sessions.append(
|
||||
RefreshSession(
|
||||
jti=jti,
|
||||
user_id=user_id,
|
||||
created_at=created_at,
|
||||
ip_address=ip_address or None,
|
||||
user_agent=user_agent or None,
|
||||
)
|
||||
)
|
||||
for jti in stale:
|
||||
await redis.srem(f"auth:user_refresh:{user_id}", jti)
|
||||
sessions.sort(key=lambda item: item.created_at, reverse=True)
|
||||
return sessions
|
||||
except RedisError:
|
||||
_cleanup_fallback()
|
||||
items = []
|
||||
for jti in _fallback_user_sessions.get(user_id, set()):
|
||||
info = _fallback_token_meta.get(jti)
|
||||
if not info:
|
||||
continue
|
||||
uid, ip_address, user_agent, created_at, _exp_at = info
|
||||
if uid != user_id:
|
||||
continue
|
||||
items.append(
|
||||
RefreshSession(
|
||||
jti=jti,
|
||||
user_id=user_id,
|
||||
created_at=created_at,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
)
|
||||
items.sort(key=lambda item: item.created_at, reverse=True)
|
||||
return items
|
||||
|
||||
|
||||
async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
|
||||
sessions = await list_refresh_sessions_for_user(user_id=user_id)
|
||||
for session in sessions:
|
||||
await revoke_refresh_token_jti(jti=session.jti)
|
||||
|
||||
Reference in New Issue
Block a user