- store refresh session metadata in redis (ip/user-agent/created_at) - add auth APIs: list sessions, revoke one, revoke all - add web privacy UI for active sessions
This commit is contained in:
@@ -1,10 +1,22 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from app.utils.redis_client import get_redis_client
|
||||
|
||||
_fallback_tokens: dict[str, tuple[int, float]] = {}
|
||||
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}
|
||||
_fallback_user_sessions: dict[int, set[str]] = {}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RefreshSession:
|
||||
jti: str
|
||||
user_id: int
|
||||
created_at: float
|
||||
user_agent: str | None = None
|
||||
ip_address: str | None = None
|
||||
|
||||
|
||||
def _cleanup_fallback() -> None:
|
||||
@@ -12,15 +24,46 @@ def _cleanup_fallback() -> None:
|
||||
expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
|
||||
for jti in expired:
|
||||
_fallback_tokens.pop(jti, None)
|
||||
_fallback_token_meta.pop(jti, None)
|
||||
for user_id, session_ids in list(_fallback_user_sessions.items()):
|
||||
live = {jti for jti in session_ids if jti in _fallback_tokens}
|
||||
if live:
|
||||
_fallback_user_sessions[user_id] = live
|
||||
else:
|
||||
_fallback_user_sessions.pop(user_id, None)
|
||||
|
||||
|
||||
async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
|
||||
async def store_refresh_token_jti(
|
||||
*,
|
||||
user_id: int,
|
||||
jti: str,
|
||||
ttl_seconds: int,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> None:
|
||||
created_at = time.time()
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
|
||||
await redis.hset(
|
||||
f"auth:refresh_meta:{jti}",
|
||||
mapping={
|
||||
"user_id": str(user_id),
|
||||
"created_at": str(created_at),
|
||||
"ip_address": ip_address or "",
|
||||
"user_agent": user_agent or "",
|
||||
},
|
||||
)
|
||||
await redis.expire(f"auth:refresh_meta:{jti}", ttl_seconds)
|
||||
await redis.sadd(f"auth:user_refresh:{user_id}", jti)
|
||||
await redis.expire(f"auth:user_refresh:{user_id}", ttl_seconds + 86400)
|
||||
except RedisError:
|
||||
_cleanup_fallback()
|
||||
_fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)
|
||||
_fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, time.time() + ttl_seconds)
|
||||
if user_id not in _fallback_user_sessions:
|
||||
_fallback_user_sessions[user_id] = set()
|
||||
_fallback_user_sessions[user_id].add(jti)
|
||||
|
||||
|
||||
async def get_refresh_token_user_id(*, jti: str) -> int | None:
|
||||
@@ -41,6 +84,81 @@ async def get_refresh_token_user_id(*, jti: str) -> int | None:
|
||||
async def revoke_refresh_token_jti(*, jti: str) -> None:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
user_id = await redis.get(f"auth:refresh:{jti}")
|
||||
await redis.delete(f"auth:refresh:{jti}")
|
||||
await redis.delete(f"auth:refresh_meta:{jti}")
|
||||
if user_id and str(user_id).isdigit():
|
||||
await redis.srem(f"auth:user_refresh:{int(user_id)}", jti)
|
||||
except RedisError:
|
||||
user_info = _fallback_tokens.get(jti)
|
||||
_fallback_tokens.pop(jti, None)
|
||||
_fallback_token_meta.pop(jti, None)
|
||||
if user_info:
|
||||
user_id = user_info[0]
|
||||
if user_id in _fallback_user_sessions:
|
||||
_fallback_user_sessions[user_id].discard(jti)
|
||||
if not _fallback_user_sessions[user_id]:
|
||||
_fallback_user_sessions.pop(user_id, None)
|
||||
|
||||
|
||||
async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
session_ids = await redis.smembers(f"auth:user_refresh:{user_id}")
|
||||
sessions: list[RefreshSession] = []
|
||||
stale: list[str] = []
|
||||
for raw_jti in session_ids:
|
||||
jti = raw_jti.decode("utf-8") if isinstance(raw_jti, bytes) else str(raw_jti)
|
||||
owner = await redis.get(f"auth:refresh:{jti}")
|
||||
if not owner or not str(owner).isdigit() or int(owner) != user_id:
|
||||
stale.append(jti)
|
||||
continue
|
||||
meta = await redis.hgetall(f"auth:refresh_meta:{jti}")
|
||||
created_at_raw = meta.get("created_at") if isinstance(meta, dict) else None
|
||||
if isinstance(created_at_raw, bytes):
|
||||
created_at_raw = created_at_raw.decode("utf-8")
|
||||
created_at = float(created_at_raw) if created_at_raw else time.time()
|
||||
ip_raw = meta.get("ip_address") if isinstance(meta, dict) else ""
|
||||
ua_raw = meta.get("user_agent") if isinstance(meta, dict) else ""
|
||||
ip_address = ip_raw.decode("utf-8") if isinstance(ip_raw, bytes) else (str(ip_raw) if ip_raw else None)
|
||||
user_agent = ua_raw.decode("utf-8") if isinstance(ua_raw, bytes) else (str(ua_raw) if ua_raw else None)
|
||||
sessions.append(
|
||||
RefreshSession(
|
||||
jti=jti,
|
||||
user_id=user_id,
|
||||
created_at=created_at,
|
||||
ip_address=ip_address or None,
|
||||
user_agent=user_agent or None,
|
||||
)
|
||||
)
|
||||
for jti in stale:
|
||||
await redis.srem(f"auth:user_refresh:{user_id}", jti)
|
||||
sessions.sort(key=lambda item: item.created_at, reverse=True)
|
||||
return sessions
|
||||
except RedisError:
|
||||
_cleanup_fallback()
|
||||
items = []
|
||||
for jti in _fallback_user_sessions.get(user_id, set()):
|
||||
info = _fallback_token_meta.get(jti)
|
||||
if not info:
|
||||
continue
|
||||
uid, ip_address, user_agent, created_at, _exp_at = info
|
||||
if uid != user_id:
|
||||
continue
|
||||
items.append(
|
||||
RefreshSession(
|
||||
jti=jti,
|
||||
user_id=user_id,
|
||||
created_at=created_at,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
)
|
||||
items.sort(key=lambda item: item.created_at, reverse=True)
|
||||
return items
|
||||
|
||||
|
||||
async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
|
||||
sessions = await list_refresh_sessions_for_user(user_id=user_id)
|
||||
for session in sessions:
|
||||
await revoke_refresh_token_jti(jti=session.jti)
|
||||
|
||||
Reference in New Issue
Block a user