auth: force disconnect realtime on revoke-all sessions
All checks were successful
CI / test (push) Successful in 26s
All checks were successful
CI / test (push) Successful in 26s
This commit is contained in:
@@ -42,6 +42,7 @@ from app.auth.service import (
|
||||
)
|
||||
from app.database.session import get_db
|
||||
from app.email.service import EmailService
|
||||
from app.realtime.service import realtime_gateway
|
||||
from app.config.settings import settings
|
||||
from app.utils.rate_limit import enforce_ip_rate_limit
|
||||
from app.users.models import User
|
||||
@@ -193,6 +194,7 @@ async def revoke_all_sessions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
await revoke_all_user_sessions(db, user_id=current_user.id)
|
||||
await realtime_gateway.disconnect_user(current_user.id)
|
||||
|
||||
|
||||
@router.post("/2fa/setup", response_model=TwoFactorSetupRead)
|
||||
|
||||
@@ -162,6 +162,17 @@ class RealtimeGateway:
|
||||
payload={"chat_id": chat_id},
|
||||
)
|
||||
|
||||
async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None:
|
||||
user_connections = self._connections.get(user_id, {})
|
||||
if not user_connections:
|
||||
return
|
||||
for context in list(user_connections.values()):
|
||||
try:
|
||||
await context.websocket.close(code=code, reason=reason)
|
||||
except Exception:
|
||||
continue
|
||||
await self._cleanup_disconnected_user(user_id)
|
||||
|
||||
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
||||
chat_id = self._extract_chat_id(channel)
|
||||
if chat_id is None:
|
||||
@@ -242,23 +253,7 @@ class RealtimeGateway:
|
||||
for connection_id in disconnected:
|
||||
user_connections.pop(connection_id, None)
|
||||
if not user_connections:
|
||||
self._connections.pop(user_id, None)
|
||||
affected_chat_ids: list[int] = []
|
||||
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||
if user_id in subscribers:
|
||||
affected_chat_ids.append(chat_id)
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
became_offline = await mark_user_offline(user_id)
|
||||
await self._persist_last_seen(user_id)
|
||||
if became_offline:
|
||||
await self._broadcast_presence(
|
||||
affected_chat_ids,
|
||||
user_id=user_id,
|
||||
is_online=False,
|
||||
last_seen_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self._cleanup_disconnected_user(user_id)
|
||||
|
||||
@staticmethod
|
||||
def _extract_chat_id(channel: str) -> int | None:
|
||||
@@ -296,5 +291,24 @@ class RealtimeGateway:
|
||||
payload["last_seen_at"] = last_seen_at.isoformat()
|
||||
await self._publish_chat_event(chat_id, event=event_name, payload=payload)
|
||||
|
||||
async def _cleanup_disconnected_user(self, user_id: int) -> None:
|
||||
self._connections.pop(user_id, None)
|
||||
affected_chat_ids: list[int] = []
|
||||
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||
if user_id in subscribers:
|
||||
affected_chat_ids.append(chat_id)
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
became_offline = await mark_user_offline(user_id)
|
||||
await self._persist_last_seen(user_id)
|
||||
if became_offline:
|
||||
await self._broadcast_presence(
|
||||
affected_chat_ids,
|
||||
user_id=user_id,
|
||||
is_online=False,
|
||||
last_seen_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
realtime_gateway = RealtimeGateway()
|
||||
|
||||
Reference in New Issue
Block a user