Files
Messenger/app/auth/token_store.py
benya 897defc39d
Some checks failed
CI / test (push) Failing after 18s
fix(audio,sessions): unify audio playback state and improve session discovery
- move voice/audio players to single global audio engine with shared volume

- stop/reset previous track when switching to another media

- keep playback alive across chat switches via global audio element

- list refresh sessions by redis scan fallback when user session set is missing
2026-03-08 11:48:13 +03:00

185 lines
7.1 KiB
Python

import time
from dataclasses import dataclass
from redis.exceptions import RedisError
from app.utils.redis_client import get_redis_client
_fallback_tokens: dict[str, tuple[int, float]] = {}
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}
_fallback_user_sessions: dict[int, set[str]] = {}
@dataclass(slots=True)
class RefreshSession:
jti: str
user_id: int
created_at: float
user_agent: str | None = None
ip_address: str | None = None
def _cleanup_fallback() -> None:
    """Evict expired tokens from the in-memory fallback store.

    A token whose expiry timestamp is at or before the current time is removed
    from both token dicts. Afterwards every user's session set is narrowed to
    the tokens that survived, and users left with no sessions are dropped.
    """
    cutoff = time.time()
    expired_jtis = [
        token_id
        for token_id, (_owner, expires_at) in _fallback_tokens.items()
        if expires_at <= cutoff
    ]
    for token_id in expired_jtis:
        for store in (_fallback_tokens, _fallback_token_meta):
            store.pop(token_id, None)

    # Iterate over a snapshot because entries are replaced or removed below.
    for owner_id, token_ids in list(_fallback_user_sessions.items()):
        surviving = token_ids.intersection(_fallback_tokens)
        if not surviving:
            _fallback_user_sessions.pop(owner_id, None)
        else:
            _fallback_user_sessions[owner_id] = surviving
async def store_refresh_token_jti(
    *,
    user_id: int,
    jti: str,
    ttl_seconds: int,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> None:
    """Persist a refresh-token id together with its session metadata.

    Redis layout written:
      * ``auth:refresh:{jti}`` -> owning user id, expiring after *ttl_seconds*;
      * ``auth:refresh_meta:{jti}`` -> hash of user_id / created_at / ip_address /
        user_agent (missing values stored as ""), same expiry;
      * ``auth:user_refresh:{user_id}`` -> set of the user's jtis, kept alive for
        at least ``ttl_seconds`` plus one day.

    If Redis raises RedisError, the token is recorded in the in-process fallback
    dicts instead (after evicting expired fallback entries).
    """
    created_at = time.time()
    index_key = f"auth:user_refresh:{user_id}"
    index_ttl = ttl_seconds + 86400  # one day of slack beyond the token lifetime
    try:
        redis = get_redis_client()
        await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
        await redis.hset(
            f"auth:refresh_meta:{jti}",
            mapping={
                "user_id": str(user_id),
                "created_at": str(created_at),
                "ip_address": ip_address or "",
                "user_agent": user_agent or "",
            },
        )
        await redis.expire(f"auth:refresh_meta:{jti}", ttl_seconds)
        await redis.sadd(index_key, jti)
        # Only ever *extend* the index lifetime. Resetting it unconditionally lets
        # a short-lived token shorten the TTL below that of a longer-lived sibling,
        # so the set can expire and lose the sibling's jti while it is still valid.
        # TTL is -1 for a key without expiry (a freshly created set), which is
        # always below index_ttl, so new sets are always armed.
        current_ttl = await redis.ttl(index_key)
        if current_ttl is None or int(current_ttl) < index_ttl:
            await redis.expire(index_key, index_ttl)
    except RedisError:
        _cleanup_fallback()
        expires_at = created_at + ttl_seconds
        _fallback_tokens[jti] = (user_id, expires_at)
        _fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, expires_at)
        _fallback_user_sessions.setdefault(user_id, set()).add(jti)
async def get_refresh_token_user_id(*, jti: str) -> int | None:
try:
redis = get_redis_client()
value = await redis.get(f"auth:refresh:{jti}")
if not value or not str(value).isdigit():
return None
return int(value)
except RedisError:
_cleanup_fallback()
data = _fallback_tokens.get(jti)
if not data:
return None
return data[0]
async def revoke_refresh_token_jti(*, jti: str) -> None:
    """Revoke the refresh token identified by *jti* (best-effort).

    In Redis, deletes ``auth:refresh:{jti}`` and ``auth:refresh_meta:{jti}`` and
    removes *jti* from its owner's ``auth:user_refresh:{user_id}`` set. A
    RedisError is swallowed.

    The in-memory fallback store is purged in every case, not only when Redis
    fails. A token issued during an outage lives only in the fallback dicts, and
    get_refresh_token_user_id reads them again on a later RedisError. Skipping
    the purge while Redis is healthy would let such a revoked token become
    valid again during the next outage.
    """
    try:
        redis = get_redis_client()
        owner = await redis.get(f"auth:refresh:{jti}")
        await redis.delete(f"auth:refresh:{jti}")
        await redis.delete(f"auth:refresh_meta:{jti}")
        # Accept bytes replies too; str(b"42") would never look numeric.
        if isinstance(owner, bytes):
            owner = owner.decode("utf-8", errors="replace")
        if owner and str(owner).isdecimal():
            await redis.srem(f"auth:user_refresh:{int(owner)}", jti)
    except RedisError:
        # Redis is unavailable; the fallback purge below still applies.
        pass

    token = _fallback_tokens.pop(jti, None)
    _fallback_token_meta.pop(jti, None)
    if token is not None:
        owner_id = token[0]
        user_sessions = _fallback_user_sessions.get(owner_id)
        if user_sessions is not None:
            user_sessions.discard(jti)
            if not user_sessions:
                _fallback_user_sessions.pop(owner_id, None)
def _redis_text(value: object) -> str:
    """Return a Redis reply as ``str``: bytes decoded as UTF-8, None as ""."""
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return "" if value is None else str(value)


def _is_owned_by(raw_owner: object, user_id: int) -> bool:
    """Return True when a raw ``auth:refresh:{jti}`` value equals *user_id*."""
    owner = _redis_text(raw_owner)
    # isdecimal() guarantees the int() conversion cannot raise.
    return owner.isdecimal() and int(owner) == user_id


async def _discover_refresh_jtis(redis, user_id: int) -> set[str]:
    """Find *user_id*'s jtis by scanning every ``auth:refresh:*`` key.

    Slow path, used only when the per-user index is empty or missing: it walks
    the whole keyspace and issues one GET per matching key. *redis* is the
    client from get_redis_client(). SCAN may report the same key more than
    once, so results are collected in a set.
    """
    prefix = "auth:refresh:"
    found: set[str] = set()
    cursor = 0
    while True:
        cursor, keys = await redis.scan(cursor=cursor, match="auth:refresh:*", count=200)
        for raw_key in keys:
            key = _redis_text(raw_key)
            jti = key.removeprefix(prefix)
            if not key.startswith(prefix) or not jti:
                continue
            if _is_owned_by(await redis.get(key), user_id):
                found.add(jti)
        if int(cursor) == 0:
            break
    return found


async def _load_refresh_session(redis, user_id: int, jti: str) -> RefreshSession | None:
    """Build *jti*'s session, or return None if the token is gone or not *user_id*'s."""
    if not _is_owned_by(await redis.get(f"auth:refresh:{jti}"), user_id):
        return None
    raw_meta = await redis.hgetall(f"auth:refresh_meta:{jti}")
    # Normalize keys and values so bytes-returning clients work too.
    meta = (
        {_redis_text(k): _redis_text(v) for k, v in raw_meta.items()}
        if isinstance(raw_meta, dict)
        else {}
    )
    # A missing or malformed timestamp falls back to "now" instead of failing
    # the whole listing.
    try:
        created_at = float(meta["created_at"]) if meta.get("created_at") else time.time()
    except ValueError:
        created_at = time.time()
    return RefreshSession(
        jti=jti,
        user_id=user_id,
        created_at=created_at,
        # store_refresh_token_jti stores "" for unknown values; map back to None.
        ip_address=meta.get("ip_address") or None,
        user_agent=meta.get("user_agent") or None,
    )


def _list_fallback_sessions(user_id: int) -> list[RefreshSession]:
    """List *user_id*'s sessions from the in-memory fallback store, newest first."""
    _cleanup_fallback()
    sessions: list[RefreshSession] = []
    for jti in _fallback_user_sessions.get(user_id, set()):
        info = _fallback_token_meta.get(jti)
        if info is None:
            continue
        owner_id, ip_address, user_agent, created_at, _expires_at = info
        if owner_id != user_id:
            continue
        sessions.append(
            RefreshSession(
                jti=jti,
                user_id=user_id,
                created_at=created_at,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
    sessions.sort(key=lambda item: item.created_at, reverse=True)
    return sessions


async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
    """Return the user's live refresh sessions, newest first.

    Reads the jti index ``auth:user_refresh:{user_id}``. If it is empty or
    missing, it is rebuilt from a full ``auth:refresh:*`` SCAN and re-populated.
    Each indexed jti is then re-validated against ``auth:refresh:{jti}``.
    Entries whose token is gone or owned by another user are pruned from the
    index. Both ``str`` and ``bytes`` replies are accepted. On RedisError the
    in-memory fallback store is listed instead.
    """
    try:
        redis = get_redis_client()
        index_key = f"auth:user_refresh:{user_id}"
        session_ids = {_redis_text(item) for item in (await redis.smembers(index_key) or ())}
        if not session_ids:
            session_ids = await _discover_refresh_jtis(redis, user_id)
            if session_ids:
                # NOTE(review): the rebuilt set gets no TTL here; it is pruned by
                # the stale removal below and re-armed by store_refresh_token_jti.
                await redis.sadd(index_key, *session_ids)

        sessions: list[RefreshSession] = []
        stale: list[str] = []
        for jti in session_ids:
            session = await _load_refresh_session(redis, user_id, jti)
            if session is None:
                stale.append(jti)
            else:
                sessions.append(session)
        if stale:
            await redis.srem(index_key, *stale)
        sessions.sort(key=lambda item: item.created_at, reverse=True)
        return sessions
    except RedisError:
        return _list_fallback_sessions(user_id)
async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
    """Revoke every refresh session currently listed for *user_id*.

    Sessions are enumerated with list_refresh_sessions_for_user and then revoked
    one by one, in listing order, via revoke_refresh_token_jti.

    NOTE(review): if Redis is unreachable, the listing comes from the in-memory
    fallback only, so tokens held in Redis are not revoked by this call.
    """
    sessions = await list_refresh_sessions_for_user(user_id=user_id)
    token_ids = [session.jti for session in sessions]
    for token_id in token_ids:
        await revoke_refresh_token_jti(jti=token_id)