Some checks failed
CI / test (push) Failing after 18s
- Move voice/audio players to a single global audio engine with a shared volume.
- Stop and reset the previous track when switching to another media item.
- Keep playback alive across chat switches via a global audio element.
- List refresh sessions via a Redis SCAN fallback when the user's session set is missing.
185 lines
7.1 KiB
Python
185 lines
7.1 KiB
Python
import time
|
|
from dataclasses import dataclass
|
|
|
|
from redis.exceptions import RedisError
|
|
|
|
from app.utils.redis_client import get_redis_client
|
|
|
|
# In-process fallback stores for refresh tokens, used when Redis raises
# RedisError. They are plain module-level dicts, so they live only in this
# process's memory: not shared with other processes and lost on restart.

# jti -> (user_id, expires_at) where expires_at is a time.time() timestamp.
_fallback_tokens: dict[str, tuple[int, float]] = {}

# jti -> (user_id, ip_address, user_agent, created_at, expires_at).
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}

# user_id -> jtis of that user's fallback-stored tokens.
_fallback_user_sessions: dict[int, set[str]] = {}
|
|
|
|
|
|
@dataclass(slots=True)
class RefreshSession:
    """A refresh-token session as returned by ``list_refresh_sessions_for_user``."""

    # Refresh token identifier; the suffix of the ``auth:refresh:{jti}`` key.
    jti: str
    # Id of the user that owns the token.
    user_id: int
    # Creation time in seconds since the epoch (time.time() when the token was
    # stored; the listing substitutes the current time if metadata is missing).
    created_at: float
    # User agent passed to store_refresh_token_jti, or None if not recorded.
    user_agent: str | None = None
    # IP address passed to store_refresh_token_jti, or None if not recorded.
    ip_address: str | None = None
|
|
|
|
|
|
def _cleanup_fallback() -> None:
    """Evict expired tokens from the in-process fallback stores.

    A token counts as expired once its stored expiry timestamp is at or before
    the current ``time.time()``. Expired jtis are removed from both
    ``_fallback_tokens`` and ``_fallback_token_meta``. Each user's session set
    is then narrowed to the jtis still present in ``_fallback_tokens``, and
    users left with no sessions are dropped entirely.
    """
    cutoff = time.time()

    # Collect first, evict second: the dict cannot change size mid-iteration.
    doomed: list[str] = []
    for token_id, (_owner, expires_at) in _fallback_tokens.items():
        if expires_at <= cutoff:
            doomed.append(token_id)
    for token_id in doomed:
        _fallback_tokens.pop(token_id, None)
        _fallback_token_meta.pop(token_id, None)

    # Iterate over a snapshot of the user ids because entries may be deleted.
    for owner_id in list(_fallback_user_sessions):
        survivors = _fallback_user_sessions[owner_id].intersection(_fallback_tokens)
        if not survivors:
            del _fallback_user_sessions[owner_id]
        else:
            _fallback_user_sessions[owner_id] = survivors
|
|
|
|
|
|
async def store_refresh_token_jti(
    *,
    user_id: int,
    jti: str,
    ttl_seconds: int,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> None:
    """Record a newly issued refresh token so it can be validated, listed and revoked.

    In Redis this writes:

    * ``auth:refresh:{jti}`` -> owner id, expiring after *ttl_seconds*;
    * the ``auth:refresh_meta:{jti}`` hash (creation time, IP address, user
      agent; empty strings when absent), also expiring after *ttl_seconds*;
    * *jti* into the per-user index set ``auth:user_refresh:{user_id}``, whose
      TTL is reset to *ttl_seconds* + 86400.

    If Redis raises :class:`RedisError`, the token is recorded in the
    in-process fallback stores instead.
    """
    created_at = time.time()
    token_key = f"auth:refresh:{jti}"
    meta_key = f"auth:refresh_meta:{jti}"
    index_key = f"auth:user_refresh:{user_id}"
    try:
        redis = get_redis_client()
        await redis.set(token_key, str(user_id), ex=ttl_seconds)
        await redis.hset(
            meta_key,
            mapping={
                "user_id": str(user_id),
                "created_at": str(created_at),
                "ip_address": ip_address or "",
                "user_agent": user_agent or "",
            },
        )
        await redis.expire(meta_key, ttl_seconds)
        await redis.sadd(index_key, jti)
        await redis.expire(index_key, ttl_seconds + 86400)
    except RedisError:
        # NOTE(review): a failure part-way through leaves the earlier Redis
        # writes in place alongside this fallback copy; a transactional
        # pipeline would avoid that if the configured client supports it.
        _cleanup_fallback()
        # Compute the expiry once from created_at so both fallback records
        # agree with each other and with the recorded creation time.
        expires_at = created_at + ttl_seconds
        _fallback_tokens[jti] = (user_id, expires_at)
        _fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, expires_at)
        _fallback_user_sessions.setdefault(user_id, set()).add(jti)
|
|
|
|
|
|
async def get_refresh_token_user_id(*, jti: str) -> int | None:
    """Return the id of the user owning refresh token *jti*, or ``None``.

    Redis (``auth:refresh:{jti}``) is consulted first. The in-process fallback
    store is used only when Redis raises :class:`RedisError`. A missing or
    non-numeric stored value yields ``None``.
    """
    try:
        redis = get_redis_client()
        value = await redis.get(f"auth:refresh:{jti}")
    except RedisError:
        _cleanup_fallback()
        entry = _fallback_tokens.get(jti)
        return entry[0] if entry else None

    # A client without decode_responses returns bytes, and str(b"42") is
    # "b'42'", which would never pass isdigit(); decode before validating.
    if isinstance(value, bytes):
        value = value.decode("utf-8", errors="replace")
    if not value or not str(value).isdigit():
        return None
    return int(value)
|
|
|
|
|
|
async def revoke_refresh_token_jti(*, jti: str) -> None:
    """Revoke refresh token *jti* so it can no longer be redeemed.

    Deletes ``auth:refresh:{jti}`` and ``auth:refresh_meta:{jti}`` from Redis
    and removes *jti* from its owner's ``auth:user_refresh:{user_id}`` index.
    A :class:`RedisError` is swallowed.

    The in-process fallback copy is purged as well, whether or not Redis was
    reachable. Otherwise a token revoked while Redis is up would be accepted
    again by the fallback lookup in ``get_refresh_token_user_id`` during a
    later Redis outage.
    """
    try:
        redis = get_redis_client()
        token_key = f"auth:refresh:{jti}"
        owner = await redis.get(token_key)
        await redis.delete(token_key, f"auth:refresh_meta:{jti}")
        # Bytes replies (no decode_responses) would fail str(...).isdigit().
        if isinstance(owner, bytes):
            owner = owner.decode("utf-8", errors="replace")
        if owner and str(owner).isdigit():
            await redis.srem(f"auth:user_refresh:{int(owner)}", jti)
    except RedisError:
        # Redis unavailable: purging the fallback copy below is all we can do.
        pass
    finally:
        fallback_entry = _fallback_tokens.pop(jti, None)
        _fallback_token_meta.pop(jti, None)
        if fallback_entry is not None:
            owner_id = fallback_entry[0]
            owner_jtis = _fallback_user_sessions.get(owner_id)
            if owner_jtis is not None:
                owner_jtis.discard(jti)
                if not owner_jtis:
                    del _fallback_user_sessions[owner_id]
|
|
|
|
|
|
def _redis_text(value: object) -> str:
    """Return a Redis reply value as ``str``.

    ``bytes`` (returned by clients without ``decode_responses``) are decoded as
    UTF-8, replacing undecodable bytes; ``None`` becomes ``""``.
    """
    if value is None:
        return ""
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return str(value)


def _redis_owner_id(value: object) -> int | None:
    """Parse the owner id stored under ``auth:refresh:{jti}``; ``None`` if absent or malformed."""
    text = _redis_text(value)
    # isdecimal() (not isdigit()) so that every accepted string is int()-parsable.
    return int(text) if text.isdecimal() else None


async def _scan_refresh_jtis_for_user(client, user_id: int) -> set[str]:
    """Find jtis owned by *user_id* by scanning every ``auth:refresh:*`` key.

    Used only when the per-user index set is missing or empty. The scan walks
    the whole keyspace, so its cost grows with the total number of refresh
    keys, not with this user's sessions.
    """
    prefix = "auth:refresh:"
    found: set[str] = set()  # a set, because SCAN may return a key more than once
    cursor = 0
    while True:
        cursor, raw_keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=200)
        keys = [
            key
            for key in map(_redis_text, raw_keys)
            if key.startswith(prefix) and len(key) > len(prefix)
        ]
        if keys:  # MGET with no keys is a Redis error, so skip empty batches
            # One MGET per SCAN batch instead of one GET round trip per key.
            owners = await client.mget(keys)
            for key, owner in zip(keys, owners):
                if _redis_owner_id(owner) == user_id:
                    found.add(key.removeprefix(prefix))
        if int(cursor) == 0:
            break
    return found


def _refresh_session_from_meta(jti: str, user_id: int, raw_meta: object) -> RefreshSession:
    """Build a RefreshSession from an ``auth:refresh_meta:{jti}`` HGETALL reply."""
    meta: dict[str, object] = {}
    if isinstance(raw_meta, dict):
        # Normalise keys too: with bytes replies, meta.get("created_at") would miss.
        meta = {_redis_text(key): val for key, val in raw_meta.items()}
    created_text = _redis_text(meta.get("created_at"))
    try:
        created_at = float(created_text) if created_text else time.time()
    except ValueError:
        # Corrupt metadata should not break the whole listing.
        created_at = time.time()
    return RefreshSession(
        jti=jti,
        user_id=user_id,
        created_at=created_at,
        ip_address=_redis_text(meta.get("ip_address")) or None,
        user_agent=_redis_text(meta.get("user_agent")) or None,
    )


async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
    """Return *user_id*'s refresh sessions, newest ``created_at`` first.

    Redis path: reads the ``auth:user_refresh:{user_id}`` index set. If it is
    missing or empty, the set is rebuilt by scanning all ``auth:refresh:*``
    keys. Index entries whose token key is gone or owned by another user are
    skipped and removed from the index.

    When Redis raises :class:`RedisError`, the in-process fallback stores are
    listed instead.
    """
    try:
        redis = get_redis_client()
        index_key = f"auth:user_refresh:{user_id}"
        jtis = {_redis_text(raw) for raw in await redis.smembers(index_key)}
        if not jtis:
            jtis = await _scan_refresh_jtis_for_user(redis, user_id)
            if jtis:
                await redis.sadd(index_key, *jtis)

        sessions: list[RefreshSession] = []
        stale: list[str] = []
        for jti in jtis:
            owner = _redis_owner_id(await redis.get(f"auth:refresh:{jti}"))
            if owner != user_id:
                stale.append(jti)
                continue
            meta = await redis.hgetall(f"auth:refresh_meta:{jti}")
            sessions.append(_refresh_session_from_meta(jti, user_id, meta))

        if stale:
            await redis.srem(index_key, *stale)

        sessions.sort(key=lambda item: item.created_at, reverse=True)
        return sessions
    except RedisError:
        _cleanup_fallback()
        items: list[RefreshSession] = []
        for jti in _fallback_user_sessions.get(user_id, set()):
            info = _fallback_token_meta.get(jti)
            if info is None:
                continue
            owner_id, ip_address, user_agent, created_at, _expires_at = info
            if owner_id != user_id:
                continue
            items.append(
                RefreshSession(
                    jti=jti,
                    user_id=user_id,
                    created_at=created_at,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
            )
        items.sort(key=lambda item: item.created_at, reverse=True)
        return items
|
|
|
|
|
|
async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
    """Revoke every refresh session belonging to *user_id*.

    Each session reported by ``list_refresh_sessions_for_user`` is revoked
    individually. While Redis is reachable, that listing reads only Redis, so
    tokens recorded in the in-process fallback during an earlier outage would
    not be listed. They are purged explicitly here, so a later fallback lookup
    cannot accept them again.
    """
    sessions = await list_refresh_sessions_for_user(user_id=user_id)
    for session in sessions:
        await revoke_refresh_token_jti(jti=session.jti)

    for jti in _fallback_user_sessions.pop(user_id, set()):
        _fallback_tokens.pop(jti, None)
        _fallback_token_meta.pop(jti, None)
|