Files
Messenger/app/realtime/service.py
benya d6cd0e719c
Some checks are pending
CI / test (push) Has started running
fix(realtime): flush activity state during forced disconnect cleanup
2026-03-08 20:02:46 +03:00

371 lines
14 KiB
Python

import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.database.session import AsyncSessionLocal
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
from app.messages.service import create_chat_message, mark_message_status
from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
from app.users.repository import update_user_last_seen_now
class RealtimeGateway:
    """Track WebSocket connections and fan out chat events to them.

    Connections and chat subscriptions are kept in memory for this process.
    Chat events are published to Redis through ``RedisRealtimeRepository`` once
    :meth:`start` has succeeded, and delivered in-process otherwise. The gateway
    also remembers in which chats each user is typing / recording voice /
    recording video, so the matching ``*_stop`` events can be published when the
    user leaves.
    """

    # Reports best-effort failures that are deliberately not propagated.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        self._repo = RedisRealtimeRepository()
        # Background task running ``self._repo.consume``; None until started.
        self._consume_task: asyncio.Task | None = None
        # True while chat events are routed through Redis instead of in-process.
        self._distributed_enabled = False
        # user_id -> {connection_id -> ConnectionContext} for sockets on this process.
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        # chat_id -> ids of locally connected users that receive the chat's events.
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
        # user_id -> chat ids with an unmatched typing / voice / video "start" event.
        self._typing_chats_by_user: dict[int, set[int]] = defaultdict(set)
        self._recording_voice_chats_by_user: dict[int, set[int]] = defaultdict(set)
        self._recording_video_chats_by_user: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect to Redis and run the event consumer in a background task.

        Best effort: on any failure the gateway stays in local mode, where chat
        events only reach sockets connected to this process. The failure is
        logged, not raised.
        """
        try:
            await self._repo.connect()
            # Also replace a consumer task that has already finished; keeping a dead
            # one would publish events to Redis that this process never delivers.
            if self._consume_task is None or self._consume_task.done():
                self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
            self._distributed_enabled = True
        except Exception:
            self._logger.warning("Realtime Redis backend unavailable; using local delivery only", exc_info=True)
            self._distributed_enabled = False

    async def stop(self) -> None:
        """Cancel the consumer task, close the repository and return to local mode."""
        # Switch to in-process delivery first: once the consumer is cancelled,
        # events published to Redis would no longer reach this process's sockets.
        self._distributed_enabled = False
        if self._consume_task is not None:
            self._consume_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._consume_task
            self._consume_task = None
        await self._repo.close()

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new socket for ``user_id`` and subscribe the user to their chats.

        Broadcasts ``user_online`` to ``user_chat_ids`` when ``mark_user_online``
        reports the user as newly online, then sends a ``connect`` event carrying
        the new connection id to this socket only.

        Returns:
            The generated connection id; pass it back to :meth:`unregister`.
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self.add_chat_subscription(chat_id=chat_id, user_id=user_id)
        became_online = await mark_user_online(user_id)
        if became_online:
            await self._broadcast_presence(user_chat_ids, user_id=user_id, is_online=True, last_seen_at=None)
        # The handshake is specific to this socket: the user's other sockets must
        # not receive (and possibly adopt) another connection's id.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
            connection_id=connection_id,
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Forget one socket of ``user_id``; tear the user down if it was the last one.

        When no local connection remains, the user is unsubscribed from
        ``user_chat_ids``, open typing/recording activity is flushed,
        ``mark_user_offline`` is called, ``last_seen`` is persisted and, if
        ``mark_user_offline`` reports a transition, ``user_offline`` is broadcast
        to ``user_chat_ids``.
        """
        user_connections = self._connections.get(user_id, {})
        user_connections.pop(connection_id, None)
        if user_connections:
            return
        # Tear down local state before awaiting anything, so a socket the same user
        # opens meanwhile (e.g. a page reload) is not removed together with this one.
        self._connections.pop(user_id, None)
        for chat_id in user_chat_ids:
            self.remove_chat_subscription(chat_id=chat_id, user_id=user_id)
        await self._flush_user_activity(user_id)
        became_offline = await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)
        if became_offline:
            await self._broadcast_presence(
                user_chat_ids,
                user_id=user_id,
                is_online=False,
                last_seen_at=datetime.now(timezone.utc),
            )

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Create a chat message from a socket payload and publish ``receive_message``.

        The stored ``client_message_id`` falls back to ``payload.temp_id``; the
        published event echoes both values exactly as the client sent them.
        """
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(
                chat_id=payload.chat_id,
                type=payload.type,
                text=payload.text,
                client_message_id=payload.client_message_id or payload.temp_id,
                reply_to_message_id=payload.reply_to_message_id,
            ),
        )
        await self.publish_message_created(
            message=message,
            sender_id=user_id,
            temp_id=payload.temp_id,
            client_message_id=payload.client_message_id,
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Record a typing/recording start or stop and publish it to the chat.

        Membership is checked with ``ensure_chat_membership`` first (presumably it
        raises for non-members — confirm in ``app.chats.service``).
        """
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        self._update_user_activity(user_id=user_id, chat_id=payload.chat_id, event=event)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> dict[str, int]:
        """Store a message status update and publish it to the chat.

        ``event`` is used both as the status passed to ``mark_message_status`` and
        as the published event name.

        Returns:
            The receipt state from ``mark_message_status``; its
            ``last_delivered_message_id`` and ``last_read_message_id`` entries are
            included in the published payload.
        """
        receipt_state = await mark_message_status(
            db,
            user_id=user_id,
            payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
        )
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
                "last_delivered_message_id": receipt_state["last_delivered_message_id"],
                "last_read_message_id": receipt_state["last_read_message_id"],
            },
        )
        return receipt_state

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the ids of the chats ``user_id`` belongs to (via ``list_user_chat_ids``)."""
        return await list_user_chat_ids(db, user_id=user_id)

    def add_chat_subscription(self, *, chat_id: int, user_id: int) -> None:
        """Make ``user_id``'s local sockets receive events of ``chat_id``."""
        self._chat_subscribers[chat_id].add(user_id)

    def remove_chat_subscription(self, *, chat_id: int, user_id: int) -> None:
        """Stop delivering ``chat_id`` events to ``user_id``; drop the chat entry when empty."""
        subscribers = self._chat_subscribers.get(chat_id)
        if not subscribers:
            return
        subscribers.discard(user_id)
        if not subscribers:
            self._chat_subscribers.pop(chat_id, None)

    async def publish_chat_updated(self, *, chat_id: int) -> None:
        """Publish ``chat_updated`` for ``chat_id``."""
        await self._publish_chat_event(
            chat_id,
            event="chat_updated",
            payload={"chat_id": chat_id},
        )

    async def publish_chat_deleted(self, *, chat_id: int) -> None:
        """Publish ``chat_deleted`` for ``chat_id``."""
        await self._publish_chat_event(
            chat_id,
            event="chat_deleted",
            payload={"chat_id": chat_id},
        )

    async def disconnect_user(self, user_id: int, *, code: int = 4401, reason: str = "Session revoked") -> None:
        """Close every local socket of ``user_id`` and tear down the user's state.

        Errors raised while closing a socket are ignored. Does nothing when the
        user has no connections on this process.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        # Snapshot: the dict may change while we await each close.
        for context in list(user_connections.values()):
            with contextlib.suppress(Exception):
                await context.websocket.close(code=code, reason=reason)
        await self._cleanup_disconnected_user(user_id)

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Deliver a chat event to every user subscribed to it on this process.

        ``payload`` is expected in the shape built by :meth:`_publish_chat_event`
        (``event`` / ``payload`` keys). The outgoing timestamp is the delivery time
        here; the published ``timestamp`` is not read. Per-user errors are
        collected by ``gather`` and not re-raised.
        """
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        subscribers = self._chat_subscribers.get(chat_id)
        if not subscribers:
            return
        event = OutgoingRealtimeEvent(
            event=payload.get("event", "error"),
            payload=payload.get("payload", {}),
            timestamp=datetime.now(timezone.utc),
        )
        await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Publish ``event`` on channel ``chat:<chat_id>``.

        In distributed mode the event goes to Redis (presumably coming back to
        :meth:`_handle_redis_event` through the consumer task); otherwise it is
        handed to :meth:`_handle_redis_event` directly.
        """
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        if self._distributed_enabled:
            await self._repo.publish_event(f"chat:{chat_id}", event_payload)
            return
        await self._handle_redis_event(f"chat:{chat_id}", event_payload)

    def _activity_tables(self) -> tuple[tuple[str, dict[int, set[int]]], ...]:
        """Return ``(event prefix, user_id -> chat ids)`` pairs in flush order."""
        return (
            ("typing", self._typing_chats_by_user),
            ("recording_voice", self._recording_voice_chats_by_user),
            ("recording_video", self._recording_video_chats_by_user),
        )

    def _update_user_activity(self, *, user_id: int, chat_id: int, event: str) -> None:
        """Record a ``<activity>_start`` / ``<activity>_stop`` event of ``user_id`` in ``chat_id``.

        Events for other names are ignored.
        """
        for prefix, table in self._activity_tables():
            if event == f"{prefix}_start":
                table[user_id].add(chat_id)
                return
            if event == f"{prefix}_stop":
                chat_ids = table.get(user_id)
                if chat_ids is not None:
                    chat_ids.discard(chat_id)
                    # Drop empty entries so idle users do not accumulate keys.
                    if not chat_ids:
                        del table[user_id]
                return

    async def _flush_user_activity(self, user_id: int) -> None:
        """Publish ``<activity>_stop`` for every chat where ``user_id`` is still active.

        All activity tables are cleared for the user before anything is
        published, so activity recorded while these events are in flight is left
        untouched.
        """
        pending = [(f"{prefix}_stop", list(table.pop(user_id, set()))) for prefix, table in self._activity_tables()]
        for stop_event, chat_ids in pending:
            for chat_id in chat_ids:
                await self._publish_chat_event(
                    chat_id,
                    event=stop_event,
                    payload={"chat_id": chat_id, "user_id": user_id},
                )

    async def publish_message_created(
        self,
        *,
        message,
        sender_id: int,
        temp_id: str | None = None,
        client_message_id: str | None = None,
    ) -> None:
        """Publish ``receive_message`` with ``message`` serialized through ``MessageRead``."""
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="receive_message",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
                "temp_id": temp_id,
                "client_message_id": client_message_id,
                "sender_id": sender_id,
            },
        )

    async def publish_message_updated(self, *, message) -> None:
        """Publish ``message_updated`` with ``message`` serialized through ``MessageRead``."""
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="message_updated",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
            },
        )

    async def publish_message_deleted(self, *, chat_id: int, message_id: int) -> None:
        """Publish ``message_deleted`` for ``message_id`` in ``chat_id``."""
        await self._publish_chat_event(
            chat_id,
            event="message_deleted",
            payload={
                "chat_id": chat_id,
                "message_id": message_id,
            },
        )

    async def _send_user_event(
        self,
        user_id: int,
        event: OutgoingRealtimeEvent,
        *,
        connection_id: str | None = None,
    ) -> None:
        """Send ``event`` to the local sockets of ``user_id``.

        Args:
            user_id: Recipient user.
            event: Event to serialize and send.
            connection_id: When given, send only to that connection.

        Sockets whose send raises are dropped; if that leaves the user without
        connections, the user is torn down via :meth:`_cleanup_disconnected_user`.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        message = event.model_dump(mode="json")
        # Snapshot the targets: register/unregister may mutate the dict while we
        # await a send, which would break iteration over the live view.
        targets = [
            (target_id, context)
            for target_id, context in user_connections.items()
            if connection_id is None or target_id == connection_id
        ]
        failed: list[str] = []
        for target_id, context in targets:
            try:
                await context.websocket.send_json(message)
            except Exception:
                failed.append(target_id)
        if not failed:
            return
        for target_id in failed:
            user_connections.pop(target_id, None)
        # Only tear down if this dict is still the registered one; if it was
        # removed or replaced meanwhile, the user was already cleaned up (and may
        # have reconnected with a fresh registration).
        if not user_connections and self._connections.get(user_id) is user_connections:
            await self._cleanup_disconnected_user(user_id)

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Parse a ``chat:<id>`` channel name; return None for anything else.

        ``str.isdecimal`` is used instead of ``isdigit`` because ``isdigit`` also
        accepts characters such as ``"²"`` that make ``int()`` raise ValueError.
        """
        prefix = "chat:"
        if not channel.startswith(prefix):
            return None
        raw_id = channel[len(prefix):]
        if not raw_id.isdecimal():
            return None
        return int(raw_id)

    async def _persist_last_seen(self, user_id: int) -> None:
        """Call ``update_user_last_seen_now`` for ``user_id`` in its own session and commit.

        Best effort: failures are logged and otherwise ignored so disconnect
        handling always completes.
        """
        try:
            async with AsyncSessionLocal() as db:
                await update_user_last_seen_now(db, user_id=user_id)
                await db.commit()
        except Exception:
            self._logger.exception("Failed to persist last_seen for user %s", user_id)

    async def _broadcast_presence(
        self,
        chat_ids: list[int],
        *,
        user_id: int,
        is_online: bool,
        last_seen_at: datetime | None,
    ) -> None:
        """Publish ``user_online`` / ``user_offline`` for ``user_id`` to each chat.

        ``last_seen_at`` is included (ISO 8601) only when it is not None.
        """
        event_name = "user_online" if is_online else "user_offline"
        for chat_id in chat_ids:
            payload = {
                "chat_id": chat_id,
                "user_id": user_id,
                "is_online": is_online,
            }
            if last_seen_at is not None:
                payload["last_seen_at"] = last_seen_at.isoformat()
            await self._publish_chat_event(chat_id, event=event_name, payload=payload)

    async def _cleanup_disconnected_user(self, user_id: int) -> None:
        """Tear down all local state of ``user_id`` after their sockets went away.

        Called by :meth:`disconnect_user` and when every send to the user failed.
        The user's chats are found by scanning the subscription table. After the
        local teardown, open activity is flushed, ``mark_user_offline`` is called,
        ``last_seen`` is persisted and, if ``mark_user_offline`` reports a
        transition, ``user_offline`` is broadcast to the affected chats.
        """
        # Detach everything before the first await. If the user's (closed) sockets
        # were still registered, the stop events published below would be routed
        # back to them, fail in _send_user_event and re-enter this cleanup,
        # repeating mark_user_offline / last_seen / presence broadcasts.
        self._connections.pop(user_id, None)
        affected_chat_ids = [
            chat_id for chat_id, subscribers in self._chat_subscribers.items() if user_id in subscribers
        ]
        for chat_id in affected_chat_ids:
            self.remove_chat_subscription(chat_id=chat_id, user_id=user_id)
        await self._flush_user_activity(user_id)
        became_offline = await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)
        if became_offline:
            await self._broadcast_presence(
                affected_chat_ids,
                user_id=user_id,
                is_online=False,
                last_seen_at=datetime.now(timezone.utc),
            )
# Shared gateway instance, constructed once when this module is first imported.
realtime_gateway: RealtimeGateway = RealtimeGateway()