import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead
from app.messages.service import create_chat_message
from app.realtime.models import ConnectionContext
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import (
    ChatEventPayload,
    MessageStatusPayload,
    OutgoingRealtimeEvent,
    SendMessagePayload,
)

logger = logging.getLogger(__name__)

# Channel names used for chat events are "chat:<chat_id>".
_CHAT_CHANNEL_PREFIX = "chat:"


class RealtimeGateway:
    """Track this process's WebSocket connections and fan chat events out to them.

    When the Redis repository connects successfully, chat events are published to
    a per-chat Redis channel and delivered back through the consumer callback
    (presumably so other gateway processes subscribed to the same channels also
    receive them). Otherwise events are delivered directly to local connections.

    NOTE(review): the connection/subscription maps are mutated without locks; this
    assumes every caller runs on the same asyncio event loop.
    """

    def __init__(self) -> None:
        self._repo = RedisRealtimeRepository()
        # Background task running the Redis consumer; None when not running.
        self._consume_task: asyncio.Task | None = None
        # True only while events are routed through Redis.
        self._distributed_enabled = False
        # user_id -> {connection_id: ConnectionContext} for sockets open in this process.
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        # chat_id -> ids of users with at least one local connection subscribed to the chat.
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect to Redis and start the chat-event consumer.

        Failure is non-fatal: it is logged and the gateway falls back to
        delivering events only to connections held by this process.
        """
        try:
            await self._repo.connect()
            # Recreate the consumer if it was never started or has already exited;
            # a finished task would otherwise never be replaced.
            if self._consume_task is None or self._consume_task.done():
                self._consume_task = asyncio.create_task(
                    self._repo.consume(self._handle_redis_event)
                )
            self._distributed_enabled = True
        except Exception:
            logger.warning(
                "Realtime Redis backend unavailable; delivering events locally only",
                exc_info=True,
            )
            self._distributed_enabled = False

    async def stop(self) -> None:
        """Stop the Redis consumer and close the repository.

        The repository is closed, and distributed mode disabled, even if the
        consumer had already died with an error or closing itself raises.
        """
        task, self._consume_task = self._consume_task, None
        if task is not None:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                try:
                    await task
                except Exception:
                    # The consumer had failed before shutdown; report it, but still
                    # close the repository below instead of leaking it.
                    logger.exception("Realtime Redis consumer exited with an error")
        try:
            await self._repo.close()
        finally:
            self._distributed_enabled = False

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new connection for *user_id* and subscribe the user to *user_chat_ids*.

        A ``connect`` event carrying the generated connection id is sent to the new
        socket only.

        Returns:
            The generated connection id, to be passed back to ``unregister``.
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self._chat_subscribers[chat_id].add(user_id)
        # Target only the new socket: the payload is that socket's own id and must
        # not be echoed to the user's other open connections.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
            connection_id=connection_id,
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Forget one connection, and drop the user's chat subscriptions once none remain.

        Subscriptions are tracked per user rather than per connection, so they are
        kept while any other connection of the same user is still open.
        """
        user_connections = self._connections.get(user_id)
        if user_connections is not None:
            user_connections.pop(connection_id, None)
            if user_connections:
                # Other connections of this user still need chat events.
                return
            self._connections.pop(user_id, None)
        self._unsubscribe_user(user_id, user_chat_ids)

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Persist a chat message and broadcast it as a ``receive_message`` event.

        NOTE(review): unlike the typing/status handlers, membership is not checked
        here; presumably ``create_chat_message`` enforces it. TODO confirm.
        """
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
        )
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            payload.chat_id,
            event="receive_message",
            payload={
                "chat_id": payload.chat_id,
                "message": message_data,
                # Echoed back so the sender's client can match its optimistic copy.
                "temp_id": payload.temp_id,
                "sender_id": user_id,
            },
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Check chat membership, then broadcast *event* naming the chat and user."""
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> None:
        """Check chat membership, then broadcast a per-message status *event*."""
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
            },
        )

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the ids of chats *user_id* belongs to (delegates to the chats repository)."""
        return await list_user_chat_ids(db, user_id=user_id)

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Deliver a chat event to every local subscriber of the chat named by *channel*.

        Used both as the Redis consumer callback and for direct local delivery.
        *payload* is expected to have the shape built by ``_publish_chat_event``
        (``event``, ``payload``, ``timestamp``); missing keys fall back to defaults.
        """
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        # Snapshot: _send_user_event may prune this set while sends are awaited.
        recipients = tuple(self._chat_subscribers.get(chat_id, ()))
        if not recipients:
            return
        event_name = payload.get("event", "error")
        event = OutgoingRealtimeEvent(
            event=event_name,
            payload=payload.get("payload", {}),
            timestamp=self._parse_timestamp(payload.get("timestamp")),
        )
        results = await asyncio.gather(
            *(self._send_user_event(user_id, event) for user_id in recipients),
            return_exceptions=True,
        )
        # One recipient's failure must not stop delivery to the others, but it
        # should not vanish silently either.
        for user_id, result in zip(recipients, results):
            if isinstance(result, BaseException):
                logger.warning(
                    "Failed to deliver %s event to user %s",
                    event_name,
                    user_id,
                    exc_info=result,
                )

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Route a chat event through Redis when enabled, else deliver it locally."""
        channel = f"{_CHAT_CHANNEL_PREFIX}{chat_id}"
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        if self._distributed_enabled:
            await self._repo.publish_event(channel, event_payload)
            return
        await self._handle_redis_event(channel, event_payload)

    async def _send_user_event(
        self,
        user_id: int,
        event: OutgoingRealtimeEvent,
        *,
        connection_id: str | None = None,
    ) -> None:
        """Send *event* to the user's local connections, pruning ones that fail.

        Args:
            user_id: Recipient user.
            event: Event to serialize and send.
            connection_id: If given, send only to that connection of the user.

        If every connection of the user ends up pruned, the user is also removed
        from all chat subscriber sets.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        if connection_id is None:
            # Snapshot: register/unregister may mutate the dict while we await sends.
            targets = list(user_connections.items())
        else:
            context = user_connections.get(connection_id)
            targets = [] if context is None else [(connection_id, context)]

        message = event.model_dump(mode="json")  # identical for every connection
        disconnected: list[str] = []
        for target_id, context in targets:
            try:
                await context.websocket.send_json(message)
            except Exception:
                # Treat any send failure as a dead socket.
                logger.debug("Dropping connection %s of user %s", target_id, user_id, exc_info=True)
                disconnected.append(target_id)
        if not disconnected:
            return

        for target_id in disconnected:
            user_connections.pop(target_id, None)
        # Only clean up if this dict is still the registered one: during the awaits
        # above, unregister may have dropped it and a new register created another.
        if not user_connections and self._connections.get(user_id) is user_connections:
            self._connections.pop(user_id, None)
            self._unsubscribe_user(user_id, list(self._chat_subscribers))

    def _unsubscribe_user(self, user_id: int, chat_ids: list[int]) -> None:
        """Remove *user_id* from the subscriber sets of *chat_ids*, dropping empty sets."""
        for chat_id in chat_ids:
            subscribers = self._chat_subscribers.get(chat_id)
            if subscribers is None:
                continue
            subscribers.discard(user_id)
            if not subscribers:
                self._chat_subscribers.pop(chat_id, None)

    @staticmethod
    def _parse_timestamp(value: object) -> datetime:
        """Parse an ISO-8601 event timestamp, falling back to the current UTC time.

        Naive values are assumed to be UTC (``_publish_chat_event`` emits aware ones).
        """
        if isinstance(value, str):
            try:
                parsed = datetime.fromisoformat(value)
            except ValueError:
                pass
            else:
                return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc)

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Return the numeric chat id from a ``chat:<id>`` channel name, else None."""
        if not channel.startswith(_CHAT_CHANNEL_PREFIX):
            return None
        raw_id = channel[len(_CHAT_CHANNEL_PREFIX):]
        # isdigit() alone accepts characters such as "²" that int() rejects, which
        # would raise ValueError inside the consumer; require plain ASCII digits.
        if not (raw_id.isascii() and raw_id.isdigit()):
            return None
        return int(raw_id)


realtime_gateway = RealtimeGateway()