From 1c9855b34cdd98993de4841e59502b99c6051c83 Mon Sep 17 00:00:00 2001 From: benya Date: Sun, 8 Mar 2026 19:04:23 +0300 Subject: [PATCH] auth: force disconnect realtime on revoke-all sessions --- app/auth/router.py | 2 ++ app/realtime/service.py | 48 ++++++++++++++++++++++------------- docs/api-reference.md | 2 +- docs/core-checklist-status.md | 2 +- tests/test_auth_flow.py | 37 +++++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 19 deletions(-) diff --git a/app/auth/router.py b/app/auth/router.py index 0487ede..f201db3 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -42,6 +42,7 @@ from app.auth.service import ( ) from app.database.session import get_db from app.email.service import EmailService +from app.realtime.service import realtime_gateway from app.config.settings import settings from app.utils.rate_limit import enforce_ip_rate_limit from app.users.models import User @@ -193,6 +194,7 @@ async def revoke_all_sessions( current_user: User = Depends(get_current_user), ) -> None: await revoke_all_user_sessions(db, user_id=current_user.id) + await realtime_gateway.disconnect_user(current_user.id) @router.post("/2fa/setup", response_model=TwoFactorSetupRead) diff --git a/app/realtime/service.py b/app/realtime/service.py index de31acf..ffda60b 100644 --- a/app/realtime/service.py +++ b/app/realtime/service.py @@ -162,6 +162,17 @@ class RealtimeGateway: payload={"chat_id": chat_id}, ) + async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None: + user_connections = self._connections.get(user_id, {}) + if not user_connections: + return + for context in list(user_connections.values()): + try: + await context.websocket.close(code=code, reason=reason) + except Exception: + continue + await self._cleanup_disconnected_user(user_id) + async def _handle_redis_event(self, channel: str, payload: dict) -> None: chat_id = self._extract_chat_id(channel) if chat_id is None: @@ -242,23 +253,7 @@ class RealtimeGateway: for connection_id in disconnected: user_connections.pop(connection_id, None) if not user_connections: - self._connections.pop(user_id, None) - affected_chat_ids: list[int] = [] - for chat_id, subscribers in list(self._chat_subscribers.items()): - if user_id in subscribers: - affected_chat_ids.append(chat_id) - subscribers.discard(user_id) - if not subscribers: - self._chat_subscribers.pop(chat_id, None) - became_offline = await mark_user_offline(user_id) - await self._persist_last_seen(user_id) - if became_offline: - await self._broadcast_presence( - affected_chat_ids, - user_id=user_id, - is_online=False, - last_seen_at=datetime.now(timezone.utc), - ) + await self._cleanup_disconnected_user(user_id) @staticmethod def _extract_chat_id(channel: str) -> int | None: @@ -296,5 +291,24 @@ class RealtimeGateway: payload["last_seen_at"] = last_seen_at.isoformat() await self._publish_chat_event(chat_id, event=event_name, payload=payload) + async def _cleanup_disconnected_user(self, user_id: int) -> None: + self._connections.pop(user_id, None) + affected_chat_ids: list[int] = [] + for chat_id, subscribers in list(self._chat_subscribers.items()): + if user_id in subscribers: + affected_chat_ids.append(chat_id) + subscribers.discard(user_id) + if not subscribers: + self._chat_subscribers.pop(chat_id, None) + became_offline = await mark_user_offline(user_id) + await self._persist_last_seen(user_id) + if became_offline: + await self._broadcast_presence( + affected_chat_ids, + user_id=user_id, + is_online=False, + last_seen_at=datetime.now(timezone.utc), + ) + realtime_gateway = RealtimeGateway() diff --git a/docs/api-reference.md b/docs/api-reference.md index 4dc7d5a..dd377e5 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -539,7 +539,7 @@ Response: `204` Auth required. Response: `204` -Behavior: revokes all refresh sessions and invalidates all access tokens issued before this request. +Behavior: revokes all refresh sessions, invalidates all access tokens issued before this request, and force-closes active realtime WebSocket connections for the user. ### POST `/api/v1/auth/2fa/setup` diff --git a/docs/core-checklist-status.md b/docs/core-checklist-status.md index 4479459..f6b287c 100644 --- a/docs/core-checklist-status.md +++ b/docs/core-checklist-status.md @@ -38,7 +38,7 @@ Legend: 29. Archive - `DONE` 30. Blacklist - `DONE` 31. Privacy - `PARTIAL` (avatar/last-seen/group-invites + PM policy `everyone|contacts|nobody`; remaining edge UX/matrix hardening) -32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; UX/TOTP recovery flow ongoing) +32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; revoke-all now force-disconnects active realtime sessions; UX/TOTP recovery flow ongoing) 33. Realtime Events - `DONE` (connect/disconnect/send/receive/typing/read/delivered/online/offline + chat/message updates) 34. Sync - `PARTIAL` (cross-device via backend state + realtime; reconciliation improved for loaded chats/messages, chat-info panel now hot-refreshes on `chat_updated`) 35. Additional - `PARTIAL` (drafts/link preview partial/autoload media basic) diff --git a/tests/test_auth_flow.py b/tests/test_auth_flow.py index 83811ff..2bb0e8e 100644 --- a/tests/test_auth_flow.py +++ b/tests/test_auth_flow.py @@ -78,3 +78,40 @@ async def test_refresh_token_rotation(client, db_session): old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token}) assert old_refresh_reuse.status_code == 401 + + +async def test_revoke_all_sessions_invalidates_access_and_refresh(client, db_session): + payload = { + "email": "carol@example.com", + "name": "Carol", + "username": "carol", + "password": "strongpass123", + } + await client.post("/api/v1/auth/register", json=payload) + token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc())) + verify_token = token_row.scalar_one().token + await client.post("/api/v1/auth/verify-email", json={"token": verify_token}) + + login_response = await client.post( + "/api/v1/auth/login", + json={"email": payload["email"], "password": payload["password"]}, + ) + assert login_response.status_code == 200 + tokens = login_response.json() + access_token = tokens["access_token"] + refresh_token = tokens["refresh_token"] + + revoke_all_response = await client.delete( + "/api/v1/auth/sessions", + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert revoke_all_response.status_code == 204 + + me_response = await client.get( + "/api/v1/auth/me", + headers={"Authorization": f"Bearer {access_token}"}, + ) + assert me_response.status_code == 401 + + refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token}) + assert refresh_response.status_code == 401