"""Refresh-token JTI registry backed by Redis, with an in-process fallback.

Each stored refresh token's JTI is kept under ``auth:refresh:<jti>`` with the
owning user id as the value and a TTL. When a Redis call raises
``RedisError`` the functions fall back to the module-level
``_fallback_tokens`` dict instead.

NOTE(review): the fallback dict lives only in this process. With several
worker processes, a token stored during a Redis outage is only visible to the
worker that stored it. It is never copied into Redis once Redis recovers, so
lookups that reach Redis again will not find it.
"""

import logging
import time

from redis.exceptions import RedisError

from app.utils.redis_client import get_redis_client

logger = logging.getLogger(__name__)

_KEY_PREFIX = "auth:refresh:"

# jti -> (user_id, absolute expiry as a time.time() timestamp)
_fallback_tokens: dict[str, tuple[int, float]] = {}


def _redis_key(jti: str) -> str:
    """Return the Redis key under which refresh token *jti* is stored."""
    return f"{_KEY_PREFIX}{jti}"


def _cleanup_fallback() -> None:
    """Remove fallback entries whose expiry time has already passed."""
    now = time.time()
    expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
    for jti in expired:
        _fallback_tokens.pop(jti, None)


def _parse_user_id(value: object) -> int | None:
    """Convert a raw Redis reply into a user id, or None if it is not one.

    Accepts ``str`` as well as ``bytes``/``bytearray``: whether the client
    returns bytes depends on how it was configured (not visible here). Only
    plain ASCII digits are accepted, so values like ``"²"`` (for which
    ``str.isdigit`` is True but ``int`` fails) are rejected instead of raising.
    """
    if value is None:
        return None
    if isinstance(value, (bytes, bytearray)):
        text = bytes(value).decode("utf-8", errors="replace")
    else:
        text = str(value)
    if not (text.isascii() and text.isdigit()):
        return None
    return int(text)


async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
    """Record that refresh token *jti* belongs to *user_id* for *ttl_seconds*.

    If Redis raises ``RedisError``, the entry is kept in the in-process
    fallback store with the same TTL instead.
    """
    try:
        redis = get_redis_client()
        await redis.set(_redis_key(jti), str(user_id), ex=ttl_seconds)
    except RedisError as exc:
        logger.warning(
            "Redis unavailable while storing refresh token; using in-memory fallback: %s",
            exc,
        )
        _cleanup_fallback()
        _fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)


async def get_refresh_token_user_id(*, jti: str) -> int | None:
    """Return the user id that owns refresh token *jti*, or None if unknown.

    The fallback store is consulted only when the Redis call raises
    ``RedisError``; a plain Redis miss returns None.
    """
    try:
        redis = get_redis_client()
        value = await redis.get(_redis_key(jti))
    except RedisError as exc:
        logger.warning(
            "Redis unavailable while reading refresh token; using in-memory fallback: %s",
            exc,
        )
        _cleanup_fallback()
        data = _fallback_tokens.get(jti)
        return data[0] if data else None
    return _parse_user_id(value)


async def revoke_refresh_token_jti(*, jti: str) -> None:
    """Revoke refresh token *jti* in Redis and in the in-process fallback."""
    # Always drop the local copy. Otherwise a token stored here during an
    # earlier outage and revoked after Redis recovered would be accepted
    # again by get_refresh_token_user_id during a later outage.
    _fallback_tokens.pop(jti, None)
    try:
        redis = get_redis_client()
        await redis.delete(_redis_key(jti))
    except RedisError as exc:
        # The delete did not reach Redis, so any copy already stored there
        # stays valid until its TTL expires.
        logger.warning("Redis unavailable while revoking refresh token: %s", exc)