Files
Messenger/app/realtime/service.py
benya e1d0375392
Some checks failed
CI / test (push) Failing after 24s
feat: add reply/forward/pin message flow across backend and web
- add reply_to/forwarded_from message fields and chat pinned_message field

- add forward and pin APIs plus reply support in message create

- wire web actions: Reply, Fwd, Pin and reply composer state

- fix spam policy bug: allow repeated identical messages, keep rate limiting
2026-03-08 00:28:43 +03:00

212 lines
8.1 KiB
Python

import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
from app.messages.service import create_chat_message, mark_message_status
from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
class RealtimeGateway:
    """Deliver chat events to WebSocket clients connected to this process.

    Keeps two in-memory indexes: open connections per user, and local
    subscribers per chat. While Redis is connected and the consumer task has
    been started ("distributed" mode), chat events are published to Redis and
    local delivery happens when the consumer hands them back to
    ``_handle_redis_event``. Otherwise events are delivered in-process.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        self._repo = RedisRealtimeRepository()
        self._consume_task: asyncio.Task | None = None
        # True only after start() connected Redis and launched the consumer.
        self._distributed_enabled = False
        # user_id -> {connection_id -> ConnectionContext} for sockets open on this process.
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        # chat_id -> ids of locally connected users subscribed to that chat.
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect to Redis and start the event consumer.

        A connection failure is not fatal. The error is logged and the gateway
        keeps delivering events in-process only.
        """
        try:
            await self._repo.connect()
        except Exception:
            self._distributed_enabled = False
            self._logger.exception("Realtime: Redis unavailable; using in-process delivery only")
            return
        # Also replace a consumer task that already finished (e.g. crashed), not only a missing one.
        if self._consume_task is None or self._consume_task.done():
            self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
            self._consume_task.add_done_callback(self._on_consume_done)
        self._distributed_enabled = True

    def _on_consume_done(self, task: asyncio.Task) -> None:
        """Log when the Redis consumer task ends other than by cancellation.

        Retrieving the exception here keeps a crashed consumer from failing silently.
        """
        if task.cancelled():
            return  # cancellation (as done by stop()) is intentional
        error = task.exception()
        if error is not None:
            self._logger.error("Realtime: Redis consumer stopped with an error", exc_info=error)
        else:
            self._logger.warning("Realtime: Redis consumer exited")

    async def stop(self) -> None:
        """Cancel the Redis consumer, close the repository and leave distributed mode."""
        # Leave distributed mode first so events published during shutdown are delivered locally.
        self._distributed_enabled = False
        if self._consume_task:
            self._consume_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._consume_task
            self._consume_task = None
        await self._repo.close()

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new socket for ``user_id`` and subscribe the user to ``user_chat_ids``.

        Marks the user online and sends a ``connect`` event carrying the new
        connection id to that socket only.

        Returns:
            The generated connection id; pass it back to :meth:`unregister`.
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self._chat_subscribers[chat_id].add(user_id)
        await mark_user_online(user_id)
        # The handshake belongs to the new socket only; the user's other sessions
        # must not receive another connection's id.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
            connection_id=connection_id,
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Forget one socket of ``user_id``.

        Chat subscriptions and online presence are only torn down when this was
        the user's last open connection on this process. Otherwise the user's
        other sessions keep receiving events.

        Args:
            user_id: Owner of the connection.
            connection_id: Id returned by :meth:`register`.
            user_chat_ids: Chats this connection subscribed to. Kept for API
                compatibility. Cleanup sweeps every chat, because the user's
                other connections may have subscribed to chats missing from
                this list.
        """
        user_connections = self._connections.get(user_id)
        if user_connections is not None:
            user_connections.pop(connection_id, None)
            if user_connections:
                return  # other sessions of this user are still open
            self._connections.pop(user_id, None)
        self._drop_user_subscriptions(user_id)
        await mark_user_offline(user_id)

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Persist a message sent over the socket and broadcast it as ``receive_message``.

        When the client sent no ``client_message_id``, ``temp_id`` is used in
        its place. The broadcast reports the same id the message was created with.
        """
        client_message_id = payload.client_message_id or payload.temp_id
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(
                chat_id=payload.chat_id,
                type=payload.type,
                text=payload.text,
                client_message_id=client_message_id,
                reply_to_message_id=payload.reply_to_message_id,
            ),
        )
        await self.publish_message_created(
            message=message,
            sender_id=user_id,
            temp_id=payload.temp_id,
            client_message_id=client_message_id,
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Broadcast a typing-style ``event`` to the chat after checking membership.

        ``ensure_chat_membership`` presumably raises for non-members (verify in
        app.chats.service). The broadcast payload is ``{chat_id, user_id}``.
        """
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> dict[str, int]:
        """Record a message status for ``user_id`` and broadcast the new receipt state.

        ``event`` is used both as the status passed to ``mark_message_status``
        and as the broadcast event name.

        Returns:
            The receipt state from ``mark_message_status``. It is read here for
            ``last_delivered_message_id`` and ``last_read_message_id``.
        """
        receipt_state = await mark_message_status(
            db,
            user_id=user_id,
            payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
        )
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
                "last_delivered_message_id": receipt_state["last_delivered_message_id"],
                "last_read_message_id": receipt_state["last_read_message_id"],
            },
        )
        return receipt_state

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the ids of the chats ``user_id`` belongs to."""
        return await list_user_chat_ids(db, user_id=user_id)

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Deliver a published chat event to this process's subscribers of that chat.

        Channels that are not ``chat:<digits>`` are ignored. Per-user delivery
        failures are logged and do not abort delivery to the other users.
        """
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        # Snapshot the subscribers: delivery can remove users from the live set.
        recipients = tuple(self._chat_subscribers.get(chat_id, ()))
        if not recipients:
            return
        event_name = payload.get("event", "error")
        event = OutgoingRealtimeEvent(
            event=event_name,
            payload=payload.get("payload", {}),
            timestamp=datetime.now(timezone.utc),
        )
        results = await asyncio.gather(
            *(self._send_user_event(user_id, event) for user_id in recipients),
            return_exceptions=True,
        )
        for user_id, result in zip(recipients, results):
            if isinstance(result, BaseException):
                self._logger.warning(
                    "Realtime: failed to deliver %r to user %s", event_name, user_id, exc_info=result
                )

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Publish an event for ``chat_id``.

        In distributed mode the event goes to the Redis channel ``chat:<chat_id>``.
        Otherwise it is handed straight to the local delivery path.
        """
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        if self._distributed_enabled:
            await self._repo.publish_event(f"chat:{chat_id}", event_payload)
            return
        await self._handle_redis_event(f"chat:{chat_id}", event_payload)

    async def publish_message_created(
        self,
        *,
        message,
        sender_id: int,
        temp_id: str | None = None,
        client_message_id: str | None = None,
    ) -> None:
        """Broadcast ``receive_message`` for a newly created message to its chat.

        ``message`` must be accepted by ``MessageRead.model_validate`` and expose ``chat_id``.
        """
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="receive_message",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
                "temp_id": temp_id,
                "client_message_id": client_message_id,
                "sender_id": sender_id,
            },
        )

    async def _send_user_event(
        self,
        user_id: int,
        event: OutgoingRealtimeEvent,
        *,
        connection_id: str | None = None,
    ) -> None:
        """Send ``event`` to the user's open sockets on this process.

        Args:
            user_id: Recipient user.
            event: Event to serialize and send.
            connection_id: If given, send only to that connection. By default
                every connection of the user receives the event.

        Sockets whose send raises are dropped. If that leaves the user with no
        connections, the user is removed from every chat's subscribers and
        marked offline.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        if connection_id is None:
            # Snapshot: each send awaits, and other tasks may (un)register sockets meanwhile.
            targets = list(user_connections.items())
        elif connection_id in user_connections:
            targets = [(connection_id, user_connections[connection_id])]
        else:
            return
        data = event.model_dump(mode="json")
        disconnected: list[str] = []
        for conn_id, context in targets:
            try:
                await context.websocket.send_json(data)
            except Exception:
                self._logger.debug("Realtime: dropping connection %s of user %s", conn_id, user_id, exc_info=True)
                disconnected.append(conn_id)
        if not disconnected:
            return
        for conn_id in disconnected:
            user_connections.pop(conn_id, None)
        # Skip cleanup if sockets remain, or if unregister()/register() replaced this
        # user's entry while we were awaiting (cleanup already done / user back online).
        if user_connections or self._connections.get(user_id) is not user_connections:
            return
        self._connections.pop(user_id, None)
        self._drop_user_subscriptions(user_id)
        await mark_user_offline(user_id)

    def _drop_user_subscriptions(self, user_id: int) -> None:
        """Remove ``user_id`` from every chat's subscriber set and prune emptied sets."""
        for chat_id, subscribers in list(self._chat_subscribers.items()):
            subscribers.discard(user_id)
            if not subscribers:
                self._chat_subscribers.pop(chat_id, None)

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Return the chat id from a ``chat:<digits>`` channel name, else None."""
        if not channel.startswith("chat:"):
            return None
        raw_id = channel.removeprefix("chat:")
        # str.isdigit() alone accepts characters such as "²" that int() rejects.
        if not (raw_id.isascii() and raw_id.isdigit()):
            return None
        return int(raw_id)
# Module-level singleton: every importer of this module shares this one gateway.
realtime_gateway: RealtimeGateway = RealtimeGateway()