auth: force disconnect realtime on revoke-all sessions
All checks were successful
CI / test (push) Successful in 26s

This commit is contained in:
2026-03-08 19:04:23 +03:00
parent 7e38123d4a
commit 1c9855b34c
5 changed files with 72 additions and 19 deletions

View File

@@ -42,6 +42,7 @@ from app.auth.service import (
)
from app.database.session import get_db
from app.email.service import EmailService
from app.realtime.service import realtime_gateway
from app.config.settings import settings
from app.utils.rate_limit import enforce_ip_rate_limit
from app.users.models import User
@@ -193,6 +194,7 @@ async def revoke_all_sessions(
current_user: User = Depends(get_current_user),
) -> None:
await revoke_all_user_sessions(db, user_id=current_user.id)
await realtime_gateway.disconnect_user(current_user.id)
@router.post("/2fa/setup", response_model=TwoFactorSetupRead)

View File

@@ -162,6 +162,17 @@ class RealtimeGateway:
payload={"chat_id": chat_id},
)
async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None:
user_connections = self._connections.get(user_id, {})
if not user_connections:
return
for context in list(user_connections.values()):
try:
await context.websocket.close(code=code, reason=reason)
except Exception:
continue
await self._cleanup_disconnected_user(user_id)
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
chat_id = self._extract_chat_id(channel)
if chat_id is None:
@@ -242,23 +253,7 @@ class RealtimeGateway:
for connection_id in disconnected:
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
affected_chat_ids: list[int] = []
for chat_id, subscribers in list(self._chat_subscribers.items()):
if user_id in subscribers:
affected_chat_ids.append(chat_id)
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
became_offline = await mark_user_offline(user_id)
await self._persist_last_seen(user_id)
if became_offline:
await self._broadcast_presence(
affected_chat_ids,
user_id=user_id,
is_online=False,
last_seen_at=datetime.now(timezone.utc),
)
await self._cleanup_disconnected_user(user_id)
@staticmethod
def _extract_chat_id(channel: str) -> int | None:
@@ -296,5 +291,24 @@ class RealtimeGateway:
payload["last_seen_at"] = last_seen_at.isoformat()
await self._publish_chat_event(chat_id, event=event_name, payload=payload)
async def _cleanup_disconnected_user(self, user_id: int) -> None:
self._connections.pop(user_id, None)
affected_chat_ids: list[int] = []
for chat_id, subscribers in list(self._chat_subscribers.items()):
if user_id in subscribers:
affected_chat_ids.append(chat_id)
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
became_offline = await mark_user_offline(user_id)
await self._persist_last_seen(user_id)
if became_offline:
await self._broadcast_presence(
affected_chat_ids,
user_id=user_id,
is_online=False,
last_seen_at=datetime.now(timezone.utc),
)
realtime_gateway = RealtimeGateway()