"""Refresh-token session bookkeeping backed by Redis, with an in-process fallback.

Each refresh token is identified by its JTI. Three Redis keys are maintained:

* ``auth:refresh:{jti}``        -> owning user id (string), expires with the token
* ``auth:refresh_meta:{jti}``   -> hash with user_id / created_at / ip_address / user_agent
* ``auth:user_refresh:{uid}``   -> set of JTIs that belong to the user

When a Redis call raises ``RedisError`` the same information is kept in
module-level dictionaries instead. That fallback state is local to the
current process.
"""

import time
from dataclasses import dataclass
from typing import Any

from redis.exceptions import RedisError

from app.utils.redis_client import get_redis_client

_TOKEN_PREFIX = "auth:refresh:"
_META_PREFIX = "auth:refresh_meta:"
_USER_SESSIONS_PREFIX = "auth:user_refresh:"

# The per-user index set is kept alive this much longer than the token TTL.
_USER_INDEX_EXTRA_TTL_SECONDS = 86400
# COUNT hint passed to SCAN when rebuilding a missing per-user index.
_SCAN_COUNT = 200

# jti -> (user_id, expires_at)
_fallback_tokens: dict[str, tuple[int, float]] = {}
# jti -> (user_id, ip_address, user_agent, created_at, expires_at)
_fallback_token_meta: dict[str, tuple[int, str | None, str | None, float, float]] = {}
# user_id -> set of jti
_fallback_user_sessions: dict[int, set[str]] = {}


@dataclass(slots=True)
class RefreshSession:
    """One active refresh-token session.

    Attributes:
        jti: Identifier of the refresh token.
        user_id: Owner of the token.
        created_at: Unix timestamp (seconds) recorded when the token was stored.
        user_agent: User agent recorded at creation, if any.
        ip_address: IP address recorded at creation, if any.
    """

    jti: str
    user_id: int
    created_at: float
    user_agent: str | None = None
    ip_address: str | None = None


def _token_key(jti: str) -> str:
    """Return the Redis key that maps *jti* to its owning user id."""
    return f"{_TOKEN_PREFIX}{jti}"


def _meta_key(jti: str) -> str:
    """Return the Redis key of the metadata hash for *jti*."""
    return f"{_META_PREFIX}{jti}"


def _user_sessions_key(user_id: int) -> str:
    """Return the Redis key of the set of JTIs owned by *user_id*."""
    return f"{_USER_SESSIONS_PREFIX}{user_id}"


def _as_text(raw: object) -> str | None:
    """Return *raw* as ``str``, decoding ``bytes`` as UTF-8; ``None`` stays ``None``.

    Whether the Redis client returns ``str`` or ``bytes`` depends on how it is
    configured, which is not visible from this module, so both are accepted.
    """
    if raw is None:
        return None
    if isinstance(raw, bytes):
        return raw.decode("utf-8")
    return str(raw)


def _parse_user_id(raw: object) -> int | None:
    """Parse a stored owner value into an ``int``; return ``None`` if missing or malformed."""
    text = _as_text(raw)
    # ASCII digits only: ``str.isdigit`` alone accepts characters such as
    # superscripts that ``int`` rejects with ValueError.
    if not text or not (text.isascii() and text.isdigit()):
        return None
    return int(text)


def _cleanup_fallback() -> None:
    """Drop expired tokens from the in-process fallback and prune the per-user index."""
    now = time.time()
    expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
    for jti in expired:
        _fallback_tokens.pop(jti, None)
        _fallback_token_meta.pop(jti, None)

    # Iterate over a snapshot because entries are replaced/removed in the loop.
    for user_id, session_ids in list(_fallback_user_sessions.items()):
        live = {jti for jti in session_ids if jti in _fallback_tokens}
        if live:
            _fallback_user_sessions[user_id] = live
        else:
            _fallback_user_sessions.pop(user_id, None)


async def store_refresh_token_jti(
    *,
    user_id: int,
    jti: str,
    ttl_seconds: int,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> None:
    """Record a newly issued refresh token.

    Args:
        user_id: Owner of the token.
        jti: Identifier of the refresh token.
        ttl_seconds: Lifetime of the token in seconds.
        ip_address: Client IP to remember with the session, if known.
        user_agent: Client user agent to remember with the session, if known.

    If any Redis call raises ``RedisError`` the token is recorded in the
    in-process fallback instead (Redis writes that already succeeded are not
    rolled back).
    """
    created_at = time.time()
    try:
        redis = get_redis_client()
        meta_key = _meta_key(jti)
        user_key = _user_sessions_key(user_id)

        await redis.set(_token_key(jti), str(user_id), ex=ttl_seconds)
        await redis.hset(
            meta_key,
            mapping={
                "user_id": str(user_id),
                "created_at": str(created_at),
                # Missing values are stored as empty strings and read back as None.
                "ip_address": ip_address or "",
                "user_agent": user_agent or "",
            },
        )
        await redis.expire(meta_key, ttl_seconds)
        await redis.sadd(user_key, jti)
        await redis.expire(user_key, ttl_seconds + _USER_INDEX_EXTRA_TTL_SECONDS)
    except RedisError:
        _cleanup_fallback()
        # One timestamp so the token and its metadata expire at the same instant.
        expires_at = time.time() + ttl_seconds
        _fallback_tokens[jti] = (user_id, expires_at)
        _fallback_token_meta[jti] = (user_id, ip_address, user_agent, created_at, expires_at)
        _fallback_user_sessions.setdefault(user_id, set()).add(jti)


async def get_refresh_token_user_id(*, jti: str) -> int | None:
    """Return the user id that owns *jti*, or ``None`` if the token is unknown.

    On ``RedisError`` the in-process fallback is consulted (after purging
    expired entries).
    """
    try:
        redis = get_redis_client()
        return _parse_user_id(await redis.get(_token_key(jti)))
    except RedisError:
        _cleanup_fallback()
        entry = _fallback_tokens.get(jti)
        return entry[0] if entry else None


async def revoke_refresh_token_jti(*, jti: str) -> None:
    """Delete a refresh token, its metadata, and its entry in the owner's index.

    On ``RedisError`` the token is removed from the in-process fallback instead.
    """
    try:
        redis = get_redis_client()
        # Read the owner first: it is needed to locate the per-user index set.
        owner_id = _parse_user_id(await redis.get(_token_key(jti)))
        await redis.delete(_token_key(jti))
        await redis.delete(_meta_key(jti))
        if owner_id is not None:
            await redis.srem(_user_sessions_key(owner_id), jti)
    except RedisError:
        entry = _fallback_tokens.pop(jti, None)
        _fallback_token_meta.pop(jti, None)
        if entry is None:
            return
        owner_id = entry[0]
        owned = _fallback_user_sessions.get(owner_id)
        if owned is None:
            return
        owned.discard(jti)
        if not owned:
            _fallback_user_sessions.pop(owner_id, None)


async def _scan_session_ids_for_user(redis: Any, user_id: int) -> list[str]:
    """Find the JTIs owned by *user_id* by scanning every ``auth:refresh:*`` key.

    Used only when the per-user index set is empty. This walks the whole
    keyspace with SCAN and issues one GET per matching key.
    """
    discovered: list[str] = []
    cursor = 0
    while True:
        cursor, keys = await redis.scan(cursor=cursor, match=f"{_TOKEN_PREFIX}*", count=_SCAN_COUNT)
        for raw_key in keys:
            key = _as_text(raw_key)
            if not key or not key.startswith(_TOKEN_PREFIX):
                continue
            jti = key.removeprefix(_TOKEN_PREFIX)
            if not jti:
                continue
            if _parse_user_id(await redis.get(key)) == user_id:
                discovered.append(jti)
        # A returned cursor of 0 marks the end of the SCAN iteration.
        if cursor == 0:
            break
    return discovered


async def _load_session(redis: Any, user_id: int, jti: str) -> RefreshSession | None:
    """Build the session for *jti*, or return ``None`` if it is not a live token of *user_id*."""
    if _parse_user_id(await redis.get(_token_key(jti))) != user_id:
        return None

    raw_meta = await redis.hgetall(_meta_key(jti))
    # Normalise both field names and values so a bytes-returning client works too.
    meta: dict[str | None, str | None] = {}
    if isinstance(raw_meta, dict):
        meta = {_as_text(field): _as_text(value) for field, value in raw_meta.items()}

    created_at_text = meta.get("created_at")
    try:
        created_at = float(created_at_text) if created_at_text else time.time()
    except ValueError:
        # Malformed timestamp: treat it like a missing one rather than failing the listing.
        created_at = time.time()

    return RefreshSession(
        jti=jti,
        user_id=user_id,
        created_at=created_at,
        # Empty strings are how "unknown" is stored; surface them as None.
        ip_address=meta.get("ip_address") or None,
        user_agent=meta.get("user_agent") or None,
    )


def _list_fallback_sessions(user_id: int) -> list[RefreshSession]:
    """Return the sessions of *user_id* held in the in-process fallback, newest first."""
    _cleanup_fallback()
    sessions: list[RefreshSession] = []
    for jti in _fallback_user_sessions.get(user_id, set()):
        info = _fallback_token_meta.get(jti)
        if not info:
            continue
        owner_id, ip_address, user_agent, created_at, _expires_at = info
        if owner_id != user_id:
            continue
        sessions.append(
            RefreshSession(
                jti=jti,
                user_id=user_id,
                created_at=created_at,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
    sessions.sort(key=lambda item: item.created_at, reverse=True)
    return sessions


async def list_refresh_sessions_for_user(*, user_id: int) -> list[RefreshSession]:
    """Return the live refresh sessions of *user_id*, newest first.

    The per-user index set is read first. If it is empty, the index is rebuilt
    by scanning all token keys. Index entries whose token no longer exists (or
    belongs to someone else) are removed from the index. On ``RedisError`` the
    in-process fallback is listed instead.
    """
    try:
        redis = get_redis_client()
        user_key = _user_sessions_key(user_id)

        members = await redis.smembers(user_key)
        session_ids = [jti for jti in (_as_text(member) for member in members or ()) if jti is not None]
        if not session_ids:
            session_ids = await _scan_session_ids_for_user(redis, user_id)
            if session_ids:
                await redis.sadd(user_key, *session_ids)

        sessions: list[RefreshSession] = []
        stale: list[str] = []
        for jti in session_ids:
            session = await _load_session(redis, user_id, jti)
            if session is None:
                stale.append(jti)
            else:
                sessions.append(session)

        if stale:
            await redis.srem(user_key, *stale)

        sessions.sort(key=lambda item: item.created_at, reverse=True)
        return sessions
    except RedisError:
        return _list_fallback_sessions(user_id)


async def revoke_all_refresh_sessions_for_user(*, user_id: int) -> None:
    """Revoke every refresh session currently listed for *user_id*."""
    for session in await list_refresh_sessions_for_user(user_id=user_id):
        await revoke_refresh_token_jti(jti=session.jti)