auth: force disconnect realtime on revoke-all sessions
All checks were successful
CI / test (push) Successful in 26s

This commit is contained in:
2026-03-08 19:04:23 +03:00
parent 7e38123d4a
commit 1c9855b34c
5 changed files with 72 additions and 19 deletions

View File

@@ -42,6 +42,7 @@ from app.auth.service import (
)
from app.database.session import get_db
from app.email.service import EmailService
from app.realtime.service import realtime_gateway
from app.config.settings import settings
from app.utils.rate_limit import enforce_ip_rate_limit
from app.users.models import User
@@ -193,6 +194,7 @@ async def revoke_all_sessions(
current_user: User = Depends(get_current_user),
) -> None:
await revoke_all_user_sessions(db, user_id=current_user.id)
await realtime_gateway.disconnect_user(current_user.id)
@router.post("/2fa/setup", response_model=TwoFactorSetupRead)

View File

@@ -162,6 +162,17 @@ class RealtimeGateway:
payload={"chat_id": chat_id},
)
async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None:
user_connections = self._connections.get(user_id, {})
if not user_connections:
return
for context in list(user_connections.values()):
try:
await context.websocket.close(code=code, reason=reason)
except Exception:
continue
await self._cleanup_disconnected_user(user_id)
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
chat_id = self._extract_chat_id(channel)
if chat_id is None:
@@ -242,23 +253,7 @@ class RealtimeGateway:
for connection_id in disconnected:
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
affected_chat_ids: list[int] = []
for chat_id, subscribers in list(self._chat_subscribers.items()):
if user_id in subscribers:
affected_chat_ids.append(chat_id)
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
became_offline = await mark_user_offline(user_id)
await self._persist_last_seen(user_id)
if became_offline:
await self._broadcast_presence(
affected_chat_ids,
user_id=user_id,
is_online=False,
last_seen_at=datetime.now(timezone.utc),
)
await self._cleanup_disconnected_user(user_id)
@staticmethod
def _extract_chat_id(channel: str) -> int | None:
@@ -296,5 +291,24 @@ class RealtimeGateway:
payload["last_seen_at"] = last_seen_at.isoformat()
await self._publish_chat_event(chat_id, event=event_name, payload=payload)
async def _cleanup_disconnected_user(self, user_id: int) -> None:
self._connections.pop(user_id, None)
affected_chat_ids: list[int] = []
for chat_id, subscribers in list(self._chat_subscribers.items()):
if user_id in subscribers:
affected_chat_ids.append(chat_id)
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
became_offline = await mark_user_offline(user_id)
await self._persist_last_seen(user_id)
if became_offline:
await self._broadcast_presence(
affected_chat_ids,
user_id=user_id,
is_online=False,
last_seen_at=datetime.now(timezone.utc),
)
realtime_gateway = RealtimeGateway()

View File

@@ -539,7 +539,7 @@ Response: `204`
Auth required.
Response: `204`
Behavior: revokes all refresh sessions and invalidates all access tokens issued before this request.
Behavior: revokes all refresh sessions, invalidates all access tokens issued before this request, and force-closes active realtime WebSocket connections for the user.
### POST `/api/v1/auth/2fa/setup`

View File

@@ -38,7 +38,7 @@ Legend:
29. Archive - `DONE`
30. Blacklist - `DONE`
31. Privacy - `PARTIAL` (avatar/last-seen/group-invites + PM policy `everyone|contacts|nobody`; remaining edge UX/matrix hardening)
32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; UX/TOTP recovery flow ongoing)
32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; revoke-all now force-disconnects active realtime sessions; UX/TOTP recovery flow ongoing)
33. Realtime Events - `DONE` (connect/disconnect/send/receive/typing/read/delivered/online/offline + chat/message updates)
34. Sync - `PARTIAL` (cross-device via backend state + realtime; reconciliation improved for loaded chats/messages, chat-info panel now hot-refreshes on `chat_updated`)
35. Additional - `PARTIAL` (drafts/link preview partial/autoload media basic)

View File

@@ -78,3 +78,40 @@ async def test_refresh_token_rotation(client, db_session):
old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
assert old_refresh_reuse.status_code == 401
async def test_revoke_all_sessions_invalidates_access_and_refresh(client, db_session):
payload = {
"email": "carol@example.com",
"name": "Carol",
"username": "carol",
"password": "strongpass123",
}
await client.post("/api/v1/auth/register", json=payload)
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
verify_token = token_row.scalar_one().token
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
login_response = await client.post(
"/api/v1/auth/login",
json={"email": payload["email"], "password": payload["password"]},
)
assert login_response.status_code == 200
tokens = login_response.json()
access_token = tokens["access_token"]
refresh_token = tokens["refresh_token"]
revoke_all_response = await client.delete(
"/api/v1/auth/sessions",
headers={"Authorization": f"Bearer {access_token}"},
)
assert revoke_all_response.status_code == 204
me_response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {access_token}"},
)
assert me_response.status_code == 401
refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
assert refresh_response.status_code == 401