feat(auth): add active sessions management
Some checks failed
CI / test (push) Failing after 33s

- store refresh session metadata in redis (ip/user-agent/created_at)

- add auth APIs: list sessions, revoke one, revoke all

- add web privacy UI for active sessions
This commit is contained in:
2026-03-08 11:41:03 +03:00
parent da73b79ee7
commit e685a38be6
7 changed files with 309 additions and 11 deletions

View File

@@ -1,3 +1,5 @@
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
@@ -11,12 +13,17 @@ from app.auth.schemas import (
ResendVerificationRequest,
ResetPasswordRequest,
TokenResponse,
SessionRead,
VerifyEmailRequest,
)
from app.auth.service import (
get_current_user,
get_email_sender,
get_request_metadata,
login_user,
list_user_sessions,
revoke_all_user_sessions,
revoke_user_session,
refresh_tokens,
register_user,
request_password_reset,
@@ -56,7 +63,8 @@ async def login(payload: LoginRequest, request: Request, db: AsyncSession = Depe
scope="auth_login",
limit=settings.login_rate_limit_per_minute,
)
return await login_user(db, payload)
ip_address, user_agent = get_request_metadata(request)
return await login_user(db, payload, ip_address=ip_address, user_agent=user_agent)
@router.post("/refresh", response_model=TokenResponse)
@@ -70,7 +78,8 @@ async def refresh(
scope="auth_refresh",
limit=settings.refresh_rate_limit_per_minute,
)
return await refresh_tokens(db, payload)
ip_address, user_agent = get_request_metadata(request)
return await refresh_tokens(db, payload, ip_address=ip_address, user_agent=user_agent)
@router.post("/verify-email", response_model=MessageResponse)
@@ -120,3 +129,29 @@ async def reset_password_endpoint(payload: ResetPasswordRequest, db: AsyncSessio
@router.get("/me", response_model=AuthUserResponse)
async def me(current_user: User = Depends(get_current_user)) -> AuthUserResponse:
return current_user
@router.get("/sessions", response_model=list[SessionRead])
async def list_sessions(current_user: User = Depends(get_current_user)) -> list[SessionRead]:
sessions = await list_user_sessions(current_user.id)
out: list[SessionRead] = []
for item in sessions:
out.append(
SessionRead(
jti=item.jti,
created_at=datetime.fromtimestamp(item.created_at, tz=timezone.utc),
ip_address=item.ip_address,
user_agent=item.user_agent,
)
)
return out
@router.delete("/sessions/{jti}", status_code=status.HTTP_204_NO_CONTENT)
async def revoke_session(jti: str, current_user: User = Depends(get_current_user)) -> None:
await revoke_user_session(user_id=current_user.id, jti=jti)
@router.delete("/sessions", status_code=status.HTTP_204_NO_CONTENT)
async def revoke_all_sessions(current_user: User = Depends(get_current_user)) -> None:
await revoke_all_user_sessions(user_id=current_user.id)

View File

@@ -58,3 +58,10 @@ class AuthUserResponse(BaseModel):
email_verified: bool
created_at: datetime
updated_at: datetime
class SessionRead(BaseModel):
jti: str
created_at: datetime
ip_address: str | None = None
user_agent: str | None = None

View File

@@ -2,11 +2,19 @@ from datetime import datetime, timedelta, timezone
from uuid import uuid4
from fastapi import Depends, HTTPException, status
from fastapi import Request
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import repository as auth_repository
from app.auth.token_store import get_refresh_token_user_id, revoke_refresh_token_jti, store_refresh_token_jti
from app.auth.token_store import (
RefreshSession,
get_refresh_token_user_id,
list_refresh_sessions_for_user,
revoke_all_refresh_sessions_for_user,
revoke_refresh_token_jti,
store_refresh_token_jti,
)
from app.auth.schemas import (
LoginRequest,
RefreshTokenRequest,
@@ -111,7 +119,13 @@ async def resend_verification_email(
await db.commit()
async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
async def login_user(
db: AsyncSession,
payload: LoginRequest,
*,
ip_address: str | None = None,
user_agent: str | None = None,
) -> TokenResponse:
user = await get_user_by_email(db, payload.email)
if not user or not verify_password(payload.password, user.password_hash):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
@@ -121,14 +135,26 @@ async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
refresh_jti = str(uuid4())
refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
await store_refresh_token_jti(user_id=user.id, jti=refresh_jti, ttl_seconds=_refresh_ttl_seconds())
await store_refresh_token_jti(
user_id=user.id,
jti=refresh_jti,
ttl_seconds=_refresh_ttl_seconds(),
ip_address=ip_address,
user_agent=user_agent,
)
return TokenResponse(
access_token=create_access_token(str(user.id)),
refresh_token=refresh_token,
)
async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> TokenResponse:
async def refresh_tokens(
db: AsyncSession,
payload: RefreshTokenRequest,
*,
ip_address: str | None = None,
user_agent: str | None = None,
) -> TokenResponse:
credentials_error = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
@@ -155,13 +181,40 @@ async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> Toke
await revoke_refresh_token_jti(jti=refresh_jti)
new_jti = str(uuid4())
await store_refresh_token_jti(user_id=int(user_id), jti=new_jti, ttl_seconds=_refresh_ttl_seconds())
await store_refresh_token_jti(
user_id=int(user_id),
jti=new_jti,
ttl_seconds=_refresh_ttl_seconds(),
ip_address=ip_address,
user_agent=user_agent,
)
return TokenResponse(
access_token=create_access_token(str(user_id)),
refresh_token=create_refresh_token(str(user_id), jti=new_jti),
)
async def list_user_sessions(user_id: int) -> list[RefreshSession]:
return await list_refresh_sessions_for_user(user_id=user_id)
async def revoke_user_session(*, user_id: int, jti: str) -> None:
active_user_id = await get_refresh_token_user_id(jti=jti)
if active_user_id is None or active_user_id != user_id:
return
await revoke_refresh_token_jti(jti=jti)
async def revoke_all_user_sessions(*, user_id: int) -> None:
await revoke_all_refresh_sessions_for_user(user_id=user_id)
def get_request_metadata(request: Request) -> tuple[str | None, str | None]:
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
return ip_address, user_agent
async def request_password_reset(
db: AsyncSession,
payload: RequestPasswordResetRequest,

View File

@@ -1,10 +1,22 @@
import time
from dataclasses import dataclass
from redis.exceptions import RedisError
from app.utils.redis_client import get_redis_client
_fallback_tokens: dict[str, tuple[int, float]] = {}
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}
_fallback_user_sessions: dict[int, set[str]] = {}
@dataclass(slots=True)
class RefreshSession:
jti: str
user_id: int
created_at: float
user_agent: str | None = None
ip_address: str | None = None
def _cleanup_fallback() -> None:
@@ -12,15 +24,46 @@ def _cleanup_fallback() -> None:
expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
for jti in expired:
_fallback_tokens.pop(jti, None)
_fallback_token_meta.pop(jti, None)
for user_id, session_ids in list(_fallback_user_sessions.items()):
live = {jti for jti in session_ids if jti in _fallback_tokens}
if live:
_fallback_user_sessions[user_id] = live
else:
_fallback_user_sessions.pop(user_id, None)
async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
async def store_refresh_token_jti(
*,
user_id: int,
jti: str,
ttl_seconds: int,
ip_address: str | None = None,
user_agent: str | None = None,
) -> None:
created_at = time.time()
try:
redis = get_redis_client()
await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
await redis.hset(
f"auth:refresh_meta:{jti}",
mapping={
"user_id": str(user_id),
"created_at": str(created_at),
"ip_address": ip_address or "",
"user_agent": user_agent or "",
},
)
await redis.expire(f"auth:refresh_meta:{jti}", ttl_seconds)
await redis.sadd(f"auth:user_refresh:{user_id}", jti)
await redis.expire(f"auth:user_refresh:{user_id}", ttl_seconds + 86400)
except RedisError:
_cleanup_fallback()
_fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)
_fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, time.time() + ttl_seconds)
if user_id not in _fallback_user_sessions:
_fallback_user_sessions[user_id] = set()
_fallback_user_sessions[user_id].add(jti)
async def get_refresh_token_user_id(*, jti: str) -> int | None:
@@ -41,6 +84,81 @@ async def get_refresh_token_user_id(*, jti: str) -> int | None:
async def revoke_refresh_token_jti(*, jti: str) -> None:
try:
redis = get_redis_client()
user_id = await redis.get(f"auth:refresh:{jti}")
await redis.delete(f"auth:refresh:{jti}")
await redis.delete(f"auth:refresh_meta:{jti}")
if user_id and str(user_id).isdigit():
await redis.srem(f"auth:user_refresh:{int(user_id)}", jti)
except RedisError:
user_info = _fallback_tokens.get(jti)
_fallback_tokens.pop(jti, None)
_fallback_token_meta.pop(jti, None)
if user_info:
user_id = user_info[0]
if user_id in _fallback_user_sessions:
_fallback_user_sessions[user_id].discard(jti)
if not _fallback_user_sessions[user_id]:
_fallback_user_sessions.pop(user_id, None)
async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
try:
redis = get_redis_client()
session_ids = await redis.smembers(f"auth:user_refresh:{user_id}")
sessions: list[RefreshSession] = []
stale: list[str] = []
for raw_jti in session_ids:
jti = raw_jti.decode("utf-8") if isinstance(raw_jti, bytes) else str(raw_jti)
owner = await redis.get(f"auth:refresh:{jti}")
if not owner or not str(owner).isdigit() or int(owner) != user_id:
stale.append(jti)
continue
meta = await redis.hgetall(f"auth:refresh_meta:{jti}")
created_at_raw = meta.get("created_at") if isinstance(meta, dict) else None
if isinstance(created_at_raw, bytes):
created_at_raw = created_at_raw.decode("utf-8")
created_at = float(created_at_raw) if created_at_raw else time.time()
ip_raw = meta.get("ip_address") if isinstance(meta, dict) else ""
ua_raw = meta.get("user_agent") if isinstance(meta, dict) else ""
ip_address = ip_raw.decode("utf-8") if isinstance(ip_raw, bytes) else (str(ip_raw) if ip_raw else None)
user_agent = ua_raw.decode("utf-8") if isinstance(ua_raw, bytes) else (str(ua_raw) if ua_raw else None)
sessions.append(
RefreshSession(
jti=jti,
user_id=user_id,
created_at=created_at,
ip_address=ip_address or None,
user_agent=user_agent or None,
)
)
for jti in stale:
await redis.srem(f"auth:user_refresh:{user_id}", jti)
sessions.sort(key=lambda item: item.created_at, reverse=True)
return sessions
except RedisError:
_cleanup_fallback()
items = []
for jti in _fallback_user_sessions.get(user_id, set()):
info = _fallback_token_meta.get(jti)
if not info:
continue
uid, ip_address, user_agent, created_at, _exp_at = info
if uid != user_id:
continue
items.append(
RefreshSession(
jti=jti,
user_id=user_id,
created_at=created_at,
ip_address=ip_address,
user_agent=user_agent,
)
)
items.sort(key=lambda item: item.created_at, reverse=True)
return items
async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
sessions = await list_refresh_sessions_for_user(user_id=user_id)
for session in sessions:
await revoke_refresh_token_jti(jti=session.jti)