Files
Messenger/app/realtime/service.py
benya e6a271f8be
Some checks failed
CI / test (push) Failing after 22s
feat(chat): add presence metadata and improve web chat core
- add user last_seen_at with alembic migration and persist on realtime disconnect
- extend chat serialization with private online/last_seen, group members/online, channel subscribers
- add Redis batch presence lookup helper
- update web chat list/header to display status counters and last-seen labels
- improve delivery receipt handling using last_delivered/last_read boundaries
- include chat info panel and related API/type updates
2026-03-08 02:02:09 +03:00

224 lines
8.6 KiB
Python

import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.database.session import AsyncSessionLocal
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
from app.messages.service import create_chat_message, mark_message_status
from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
from app.users.repository import update_user_last_seen_now
class RealtimeGateway:
    """Deliver chat events to the WebSocket connections held by this process.

    Events are published through ``RedisRealtimeRepository`` while the Redis
    connection is up (``_distributed_enabled``) and are dispatched directly to
    the local connections otherwise.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        """Create an idle gateway with no connections and no consumer task."""
        self._repo = RedisRealtimeRepository()
        # Task running ``self._repo.consume``; ``None`` while not started.
        self._consume_task: asyncio.Task | None = None
        # Selects Redis publishing (True) versus direct local dispatch (False).
        self._distributed_enabled = False
        # user_id -> {connection_id -> ConnectionContext}
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        # chat_id -> ids of users registered for that chat on this instance
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect the repository and start the consumer task.

        Any failure is logged and leaves the gateway in local-only mode
        instead of propagating to the caller.
        """
        try:
            await self._repo.connect()
            if not self._consume_task:
                self._consume_task = asyncio.create_task(
                    self._repo.consume(self._handle_redis_event)
                )
            self._distributed_enabled = True
        except Exception:
            # Deliberately best effort: events are then dispatched locally.
            self._logger.warning(
                "Realtime repository unavailable; using local dispatch only",
                exc_info=True,
            )
            self._distributed_enabled = False

    async def stop(self) -> None:
        """Cancel the consumer task, close the repository and disable Redis publishing."""
        if self._consume_task:
            self._consume_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                try:
                    await self._consume_task
                except Exception:
                    # The consumer had already failed; report it but keep
                    # shutting down so the repository is still closed.
                    self._logger.exception("Realtime consumer task ended with an error")
            self._consume_task = None
        await self._repo.close()
        self._distributed_enabled = False

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new WebSocket for ``user_id`` and announce the connection.

        Args:
            user_id: Owner of the connection.
            websocket: Socket that events for this user are written to.
            user_chat_ids: Chats whose events the user should receive.

        Returns:
            The generated connection id (a UUID4 string).
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self._chat_subscribers[chat_id].add(user_id)
        await mark_user_online(user_id)
        # NOTE: _send_user_event writes to every open connection of the user,
        # so already connected sockets also receive this "connect" event.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Forget one connection and tear down user state once none remain.

        Args:
            user_id: Owner of the connection.
            connection_id: Id previously returned by ``register``.
            user_chat_ids: Chats the user was registered for.
        """
        user_connections = self._connections.get(user_id, {})
        user_connections.pop(connection_id, None)
        if user_connections:
            # Other sockets of this user are still open: keep the chat
            # subscriptions so they continue to receive events.
            return

        self._connections.pop(user_id, None)
        for chat_id in user_chat_ids:
            subscribers = self._chat_subscribers.get(chat_id)
            if not subscribers:
                continue
            subscribers.discard(user_id)
            if not subscribers:
                self._chat_subscribers.pop(chat_id, None)
        # NOTE(review): called once, when the last local connection is gone,
        # while mark_user_online runs on every register(); confirm this
        # matches how the presence helpers count connections.
        await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Store a message sent by ``user_id`` and broadcast it to the chat.

        ``payload.temp_id`` is used as the client message id when
        ``payload.client_message_id`` is empty.
        """
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(
                chat_id=payload.chat_id,
                type=payload.type,
                text=payload.text,
                client_message_id=payload.client_message_id or payload.temp_id,
                reply_to_message_id=payload.reply_to_message_id,
            ),
        )
        await self.publish_message_created(
            message=message,
            sender_id=user_id,
            temp_id=payload.temp_id,
            client_message_id=payload.client_message_id,
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Broadcast ``event`` for ``user_id`` to the chat in ``payload``.

        ``ensure_chat_membership`` is awaited first; presumably it raises for
        non-members (confirm against its implementation).
        """
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> dict[str, int]:
        """Record a message status change and broadcast the receipt boundaries.

        Args:
            db: Session passed through to ``mark_message_status``.
            user_id: User reporting the status.
            payload: Identifies the chat and the message.
            event: Used both as the stored status and as the event name.

        Returns:
            The mapping returned by ``mark_message_status``; its
            ``last_delivered_message_id`` and ``last_read_message_id`` entries
            are included in the broadcast.
        """
        receipt_state = await mark_message_status(
            db,
            user_id=user_id,
            payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
        )
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
                "last_delivered_message_id": receipt_state["last_delivered_message_id"],
                "last_read_message_id": receipt_state["last_read_message_id"],
            },
        )
        return receipt_state

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the result of ``list_user_chat_ids`` for ``user_id``."""
        return await list_user_chat_ids(db, user_id=user_id)

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Forward an event from ``channel`` to the chat's local subscribers.

        Channels that are not of the form ``chat:<id>`` and chats without
        local subscribers are ignored. The outgoing event is stamped with the
        current UTC time; a ``timestamp`` entry in ``payload`` is not read.
        """
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        recipients = tuple(self._chat_subscribers.get(chat_id, ()))
        if not recipients:
            return
        event = OutgoingRealtimeEvent(
            event=payload.get("event", "error"),
            payload=payload.get("payload", {}),
            timestamp=datetime.now(timezone.utc),
        )
        results = await asyncio.gather(
            *(self._send_user_event(user_id, event) for user_id in recipients),
            return_exceptions=True,
        )
        # One failing recipient must not affect the others, but the failure
        # should still be visible in the logs.
        for user_id, result in zip(recipients, results):
            if isinstance(result, BaseException):
                self._logger.error(
                    "Failed to deliver %r event to user %s",
                    event.event,
                    user_id,
                    exc_info=result,
                )

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Publish an event on ``chat:<chat_id>``, via Redis when enabled, else locally."""
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        channel = f"chat:{chat_id}"
        if self._distributed_enabled:
            await self._repo.publish_event(channel, event_payload)
            return
        await self._handle_redis_event(channel, event_payload)

    async def publish_message_created(
        self,
        *,
        message,
        sender_id: int,
        temp_id: str | None = None,
        client_message_id: str | None = None,
    ) -> None:
        """Broadcast a ``receive_message`` event for ``message`` to its chat.

        ``message`` is serialised through ``MessageRead``; ``temp_id`` and
        ``client_message_id`` are echoed unchanged in the event payload.
        """
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="receive_message",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
                "temp_id": temp_id,
                "client_message_id": client_message_id,
                "sender_id": sender_id,
            },
        )

    async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
        """Send ``event`` to every open connection of ``user_id``.

        Connections whose send raises are dropped. When no connection is
        left afterwards, the user is removed from all chat subscriptions,
        marked offline and their last-seen time is persisted.
        """
        user_connections = self._connections.get(user_id, {})
        if not user_connections:
            return

        # Serialise once instead of once per connection.
        message = event.model_dump(mode="json")
        disconnected: list[str] = []
        # Iterate over a snapshot: register()/unregister() may change the dict
        # while this coroutine is suspended in send_json().
        for connection_id, context in list(user_connections.items()):
            try:
                await context.websocket.send_json(message)
            except Exception:
                disconnected.append(connection_id)

        for connection_id in disconnected:
            user_connections.pop(connection_id, None)
        if user_connections:
            return

        self._connections.pop(user_id, None)
        # The user's chat ids are not known here, so every chat is scanned.
        for chat_id, subscribers in list(self._chat_subscribers.items()):
            subscribers.discard(user_id)
            if not subscribers:
                self._chat_subscribers.pop(chat_id, None)
        await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Return the numeric id from a ``chat:<id>`` channel name, else ``None``."""
        if not channel.startswith("chat:"):
            return None
        chat_id = channel.split(":", maxsplit=1)[1]
        # str.isdigit() alone also accepts characters such as "²" that int()
        # rejects, so restrict the id to ASCII digits.
        if not (chat_id.isascii() and chat_id.isdigit()):
            return None
        return int(chat_id)

    async def _persist_last_seen(self, user_id: int) -> None:
        """Store the user's last-seen time in a fresh session; failures are logged, not raised."""
        try:
            async with AsyncSessionLocal() as db:
                await update_user_last_seen_now(db, user_id=user_id)
                await db.commit()
        except Exception:
            # Best effort: a failed write must not break disconnect handling.
            self._logger.exception("Failed to persist last_seen for user %s", user_id)
# Module-level gateway instance, created when this module is imported.
realtime_gateway = RealtimeGateway()