auth: force disconnect realtime on revoke-all sessions
All checks were successful
CI / test (push) Successful in 26s
All checks were successful
CI / test (push) Successful in 26s
This commit is contained in:
@@ -42,6 +42,7 @@ from app.auth.service import (
|
|||||||
)
|
)
|
||||||
from app.database.session import get_db
|
from app.database.session import get_db
|
||||||
from app.email.service import EmailService
|
from app.email.service import EmailService
|
||||||
|
from app.realtime.service import realtime_gateway
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.utils.rate_limit import enforce_ip_rate_limit
|
from app.utils.rate_limit import enforce_ip_rate_limit
|
||||||
from app.users.models import User
|
from app.users.models import User
|
||||||
@@ -193,6 +194,7 @@ async def revoke_all_sessions(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> None:
|
) -> None:
|
||||||
await revoke_all_user_sessions(db, user_id=current_user.id)
|
await revoke_all_user_sessions(db, user_id=current_user.id)
|
||||||
|
await realtime_gateway.disconnect_user(current_user.id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/2fa/setup", response_model=TwoFactorSetupRead)
|
@router.post("/2fa/setup", response_model=TwoFactorSetupRead)
|
||||||
|
|||||||
@@ -162,6 +162,17 @@ class RealtimeGateway:
|
|||||||
payload={"chat_id": chat_id},
|
payload={"chat_id": chat_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None:
|
||||||
|
user_connections = self._connections.get(user_id, {})
|
||||||
|
if not user_connections:
|
||||||
|
return
|
||||||
|
for context in list(user_connections.values()):
|
||||||
|
try:
|
||||||
|
await context.websocket.close(code=code, reason=reason)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
await self._cleanup_disconnected_user(user_id)
|
||||||
|
|
||||||
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
||||||
chat_id = self._extract_chat_id(channel)
|
chat_id = self._extract_chat_id(channel)
|
||||||
if chat_id is None:
|
if chat_id is None:
|
||||||
@@ -242,23 +253,7 @@ class RealtimeGateway:
|
|||||||
for connection_id in disconnected:
|
for connection_id in disconnected:
|
||||||
user_connections.pop(connection_id, None)
|
user_connections.pop(connection_id, None)
|
||||||
if not user_connections:
|
if not user_connections:
|
||||||
self._connections.pop(user_id, None)
|
await self._cleanup_disconnected_user(user_id)
|
||||||
affected_chat_ids: list[int] = []
|
|
||||||
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
|
||||||
if user_id in subscribers:
|
|
||||||
affected_chat_ids.append(chat_id)
|
|
||||||
subscribers.discard(user_id)
|
|
||||||
if not subscribers:
|
|
||||||
self._chat_subscribers.pop(chat_id, None)
|
|
||||||
became_offline = await mark_user_offline(user_id)
|
|
||||||
await self._persist_last_seen(user_id)
|
|
||||||
if became_offline:
|
|
||||||
await self._broadcast_presence(
|
|
||||||
affected_chat_ids,
|
|
||||||
user_id=user_id,
|
|
||||||
is_online=False,
|
|
||||||
last_seen_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_chat_id(channel: str) -> int | None:
|
def _extract_chat_id(channel: str) -> int | None:
|
||||||
@@ -296,5 +291,24 @@ class RealtimeGateway:
|
|||||||
payload["last_seen_at"] = last_seen_at.isoformat()
|
payload["last_seen_at"] = last_seen_at.isoformat()
|
||||||
await self._publish_chat_event(chat_id, event=event_name, payload=payload)
|
await self._publish_chat_event(chat_id, event=event_name, payload=payload)
|
||||||
|
|
||||||
|
async def _cleanup_disconnected_user(self, user_id: int) -> None:
|
||||||
|
self._connections.pop(user_id, None)
|
||||||
|
affected_chat_ids: list[int] = []
|
||||||
|
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||||
|
if user_id in subscribers:
|
||||||
|
affected_chat_ids.append(chat_id)
|
||||||
|
subscribers.discard(user_id)
|
||||||
|
if not subscribers:
|
||||||
|
self._chat_subscribers.pop(chat_id, None)
|
||||||
|
became_offline = await mark_user_offline(user_id)
|
||||||
|
await self._persist_last_seen(user_id)
|
||||||
|
if became_offline:
|
||||||
|
await self._broadcast_presence(
|
||||||
|
affected_chat_ids,
|
||||||
|
user_id=user_id,
|
||||||
|
is_online=False,
|
||||||
|
last_seen_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
realtime_gateway = RealtimeGateway()
|
realtime_gateway = RealtimeGateway()
|
||||||
|
|||||||
@@ -539,7 +539,7 @@ Response: `204`
|
|||||||
|
|
||||||
Auth required.
|
Auth required.
|
||||||
Response: `204`
|
Response: `204`
|
||||||
Behavior: revokes all refresh sessions and invalidates all access tokens issued before this request.
|
Behavior: revokes all refresh sessions, invalidates all access tokens issued before this request, and force-closes active realtime WebSocket connections for the user.
|
||||||
|
|
||||||
### POST `/api/v1/auth/2fa/setup`
|
### POST `/api/v1/auth/2fa/setup`
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ Legend:
|
|||||||
29. Archive - `DONE`
|
29. Archive - `DONE`
|
||||||
30. Blacklist - `DONE`
|
30. Blacklist - `DONE`
|
||||||
31. Privacy - `PARTIAL` (avatar/last-seen/group-invites + PM policy `everyone|contacts|nobody`; remaining edge UX/matrix hardening)
|
31. Privacy - `PARTIAL` (avatar/last-seen/group-invites + PM policy `everyone|contacts|nobody`; remaining edge UX/matrix hardening)
|
||||||
32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; UX/TOTP recovery flow ongoing)
|
32. Security - `PARTIAL` (sessions + revoke + 2FA base + access-session visibility; revoke-all now force-disconnects active realtime sessions; UX/TOTP recovery flow ongoing)
|
||||||
33. Realtime Events - `DONE` (connect/disconnect/send/receive/typing/read/delivered/online/offline + chat/message updates)
|
33. Realtime Events - `DONE` (connect/disconnect/send/receive/typing/read/delivered/online/offline + chat/message updates)
|
||||||
34. Sync - `PARTIAL` (cross-device via backend state + realtime; reconciliation improved for loaded chats/messages, chat-info panel now hot-refreshes on `chat_updated`)
|
34. Sync - `PARTIAL` (cross-device via backend state + realtime; reconciliation improved for loaded chats/messages, chat-info panel now hot-refreshes on `chat_updated`)
|
||||||
35. Additional - `PARTIAL` (drafts/link preview partial/autoload media basic)
|
35. Additional - `PARTIAL` (drafts/link preview partial/autoload media basic)
|
||||||
|
|||||||
@@ -78,3 +78,40 @@ async def test_refresh_token_rotation(client, db_session):
|
|||||||
|
|
||||||
old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
|
old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
|
||||||
assert old_refresh_reuse.status_code == 401
|
assert old_refresh_reuse.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_revoke_all_sessions_invalidates_access_and_refresh(client, db_session):
|
||||||
|
payload = {
|
||||||
|
"email": "carol@example.com",
|
||||||
|
"name": "Carol",
|
||||||
|
"username": "carol",
|
||||||
|
"password": "strongpass123",
|
||||||
|
}
|
||||||
|
await client.post("/api/v1/auth/register", json=payload)
|
||||||
|
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
|
||||||
|
verify_token = token_row.scalar_one().token
|
||||||
|
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
|
||||||
|
|
||||||
|
login_response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": payload["email"], "password": payload["password"]},
|
||||||
|
)
|
||||||
|
assert login_response.status_code == 200
|
||||||
|
tokens = login_response.json()
|
||||||
|
access_token = tokens["access_token"]
|
||||||
|
refresh_token = tokens["refresh_token"]
|
||||||
|
|
||||||
|
revoke_all_response = await client.delete(
|
||||||
|
"/api/v1/auth/sessions",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert revoke_all_response.status_code == 204
|
||||||
|
|
||||||
|
me_response = await client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert me_response.status_code == 401
|
||||||
|
|
||||||
|
refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
|
||||||
|
assert refresh_response.status_code == 401
|
||||||
|
|||||||
Reference in New Issue
Block a user