import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.database.session import AsyncSessionLocal
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
from app.messages.service import create_chat_message, mark_message_status
from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
from app.users.repository import update_user_last_seen_now

logger = logging.getLogger(__name__)


class RealtimeGateway:
    """Fan chat events out to the WebSocket connections held by this process.

    Connections are tracked per user; chat subscriptions are tracked per user (not per
    connection). When ``start()`` connects to Redis, chat events are published through
    the Redis repository and delivered by the consumer task via ``_handle_redis_event``;
    otherwise they are delivered in-process directly.
    """

    def __init__(self) -> None:
        self._repo = RedisRealtimeRepository()
        self._consume_task: asyncio.Task | None = None
        # True only while the Redis consumer is running; gates publish vs. local delivery.
        self._distributed_enabled = False
        # user_id -> {connection_id: ConnectionContext} for sockets owned by this process.
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        # chat_id -> user_ids on this process that should receive the chat's events.
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect to Redis and start consuming; fall back to in-process delivery on failure."""
        try:
            await self._repo.connect()
        except Exception:
            logger.exception("Redis unavailable; realtime events will be delivered in-process only")
            self._distributed_enabled = False
            return
        # Restart the consumer if a previous one has already finished (e.g. crashed).
        if self._consume_task is None or self._consume_task.done():
            self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
            self._consume_task.add_done_callback(self._on_consume_task_done)
        self._distributed_enabled = True

    def _on_consume_task_done(self, task: asyncio.Task) -> None:
        """Fall back to in-process delivery when the Redis consumer stops on its own."""
        if task.cancelled():
            return  # stop() cancelled it deliberately
        # Retrieve the exception so asyncio does not report it as never retrieved.
        exc = task.exception()
        if task is not self._consume_task:
            return  # superseded by a newer consumer or already stopped
        # Publishing to Redis with nobody consuming would silently drop every event.
        self._distributed_enabled = False
        if exc is not None:
            logger.error("Redis realtime consumer crashed; using in-process delivery", exc_info=exc)
        else:
            logger.warning("Redis realtime consumer exited; using in-process delivery")

    async def stop(self) -> None:
        """Stop the Redis consumer and close the repository."""
        # Switch off distributed mode first so events raised during shutdown go in-process.
        self._distributed_enabled = False
        task, self._consume_task = self._consume_task, None
        if task is not None:
            task.cancel()
            # A consumer that already crashed re-raises here; it was logged by the done
            # callback and must not prevent the repository from being closed.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await task
        await self._repo.close()

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new socket for ``user_id``, subscribe it to its chats and return its id.

        Broadcasts ``user_online`` when ``mark_user_online`` reports the user became
        online, then sends a ``connect`` event to the new socket only.
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self.add_chat_subscription(chat_id=chat_id, user_id=user_id)
        became_online = await mark_user_online(user_id)
        if became_online:
            await self._broadcast_presence(user_chat_ids, user_id=user_id, is_online=True, last_seen_at=None)
        # The payload identifies this socket, so other sockets of the same user must not get it.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
            connection_id=connection_id,
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Forget one socket of ``user_id``, update presence and persist ``last_seen``."""
        user_connections = self._connections.get(user_id, {})
        user_connections.pop(connection_id, None)
        if not user_connections:
            self._connections.pop(user_id, None)
            # Subscriptions are per user: drop them only once the user's last socket is
            # gone, otherwise the user's remaining sockets would stop receiving events.
            for chat_id in user_chat_ids:
                self.remove_chat_subscription(chat_id=chat_id, user_id=user_id)
        became_offline = await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)
        if became_offline:
            await self._broadcast_presence(
                user_chat_ids,
                user_id=user_id,
                is_online=False,
                last_seen_at=datetime.now(timezone.utc),
            )

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Persist a message sent over the socket and broadcast ``receive_message`` to its chat.

        ``temp_id`` is stored as the client message id when ``client_message_id`` is absent.
        """
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(
                chat_id=payload.chat_id,
                type=payload.type,
                text=payload.text,
                client_message_id=payload.client_message_id or payload.temp_id,
                reply_to_message_id=payload.reply_to_message_id,
            ),
        )
        await self.publish_message_created(
            message=message,
            sender_id=user_id,
            temp_id=payload.temp_id,
            client_message_id=payload.client_message_id,
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Check chat membership, then broadcast ``event`` (e.g. typing) to the chat."""
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> dict[str, int]:
        """Record a delivery/read receipt (``event`` is used as the status) and broadcast it.

        Returns the receipt state from ``mark_message_status``, which must contain the
        ``last_delivered_message_id`` and ``last_read_message_id`` keys.
        """
        receipt_state = await mark_message_status(
            db,
            user_id=user_id,
            payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
        )
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
                "last_delivered_message_id": receipt_state["last_delivered_message_id"],
                "last_read_message_id": receipt_state["last_read_message_id"],
            },
        )
        return receipt_state

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the chat ids ``user_id`` belongs to, via the chats repository."""
        return await list_user_chat_ids(db, user_id=user_id)

    def add_chat_subscription(self, *, chat_id: int, user_id: int) -> None:
        """Subscribe ``user_id`` to events of ``chat_id`` on this process."""
        self._chat_subscribers[chat_id].add(user_id)

    def remove_chat_subscription(self, *, chat_id: int, user_id: int) -> None:
        """Unsubscribe ``user_id`` from ``chat_id``; drop the chat entry once it is empty."""
        subscribers = self._chat_subscribers.get(chat_id)
        if not subscribers:
            return
        subscribers.discard(user_id)
        if not subscribers:
            self._chat_subscribers.pop(chat_id, None)

    async def publish_chat_updated(self, *, chat_id: int) -> None:
        """Broadcast a ``chat_updated`` event to the chat's subscribers."""
        await self._publish_chat_event(
            chat_id,
            event="chat_updated",
            payload={"chat_id": chat_id},
        )

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Deliver an event from channel ``chat:<id>`` to every local subscriber of that chat."""
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        # Snapshot: delivery may drop dead users from the live subscriber set.
        subscribers = tuple(self._chat_subscribers.get(chat_id, ()))
        if not subscribers:
            return
        event = OutgoingRealtimeEvent(
            event=payload.get("event", "error"),
            payload=payload.get("payload", {}),
            timestamp=datetime.now(timezone.utc),
        )
        results = await asyncio.gather(
            *(self._send_user_event(user_id, event) for user_id in subscribers),
            return_exceptions=True,
        )
        # One failing recipient must not stop the others, but the failure should be visible.
        for user_id, result in zip(subscribers, results):
            if isinstance(result, Exception):
                logger.error("Failed to deliver %r to user %s", event.event, user_id, exc_info=result)

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Publish to Redis when distributed mode is on, otherwise deliver locally."""
        channel = f"chat:{chat_id}"
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        if self._distributed_enabled:
            await self._repo.publish_event(channel, event_payload)
            return
        await self._handle_redis_event(channel, event_payload)

    async def publish_message_created(
        self,
        *,
        message,
        sender_id: int,
        temp_id: str | None = None,
        client_message_id: str | None = None,
    ) -> None:
        """Broadcast ``receive_message`` with the serialized ``MessageRead`` of ``message``."""
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="receive_message",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
                "temp_id": temp_id,
                "client_message_id": client_message_id,
                "sender_id": sender_id,
            },
        )

    async def _send_user_event(
        self,
        user_id: int,
        event: OutgoingRealtimeEvent,
        *,
        connection_id: str | None = None,
    ) -> None:
        """Send ``event`` to the user's sockets (or only ``connection_id``), pruning dead ones.

        If pruning removes the user's last socket, the user is dropped from every chat
        subscription, marked offline, ``last_seen`` is persisted, and ``user_offline`` is
        broadcast when presence reports the user went offline.
        NOTE(review): if the socket handler also calls ``unregister`` for a pruned socket,
        ``mark_user_offline`` runs twice for it — confirm presence tolerates that.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        # Snapshot: send_json awaits, and register/unregister may mutate the dict meanwhile.
        if connection_id is None:
            targets = list(user_connections.items())
        else:
            context = user_connections.get(connection_id)
            targets = [] if context is None else [(connection_id, context)]

        message = event.model_dump(mode="json")  # serialize once for all sockets
        disconnected: list[str] = []
        for target_id, context in targets:
            try:
                await context.websocket.send_json(message)
            except Exception:
                disconnected.append(target_id)

        # Only count sockets this call actually removed (unregister may have beaten us to it).
        removed = [target_id for target_id in disconnected if user_connections.pop(target_id, None) is not None]
        if not removed or user_connections:
            return
        if self._connections.get(user_id) is not user_connections:
            # Already cleaned up by unregister, or a new register created a fresh entry.
            return
        self._connections.pop(user_id, None)

        affected_chat_ids: list[int] = []
        for chat_id, subscribers in list(self._chat_subscribers.items()):
            if user_id in subscribers:
                affected_chat_ids.append(chat_id)
                self.remove_chat_subscription(chat_id=chat_id, user_id=user_id)
        became_offline = await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)
        if became_offline:
            await self._broadcast_presence(
                affected_chat_ids,
                user_id=user_id,
                is_online=False,
                last_seen_at=datetime.now(timezone.utc),
            )

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Return the integer id from a ``chat:<id>`` channel name, or ``None``."""
        prefix = "chat:"
        if not channel.startswith(prefix):
            return None
        raw_id = channel[len(prefix):]
        # isdigit() alone accepts characters such as "²" that int() rejects.
        if not (raw_id.isascii() and raw_id.isdigit()):
            return None
        return int(raw_id)

    async def _persist_last_seen(self, user_id: int) -> None:
        """Best-effort update of the user's ``last_seen`` in its own DB session."""
        try:
            async with AsyncSessionLocal() as db:
                await update_user_last_seen_now(db, user_id=user_id)
                await db.commit()
        except Exception:
            # Must not break disconnect handling, but should not vanish silently either.
            logger.exception("Failed to persist last_seen for user %s", user_id)

    async def _broadcast_presence(
        self,
        chat_ids: list[int],
        *,
        user_id: int,
        is_online: bool,
        last_seen_at: datetime | None,
    ) -> None:
        """Publish ``user_online``/``user_offline`` for ``user_id`` to each chat in ``chat_ids``."""
        event_name = "user_online" if is_online else "user_offline"
        for chat_id in chat_ids:
            payload = {
                "chat_id": chat_id,
                "user_id": user_id,
                "is_online": is_online,
            }
            if last_seen_at is not None:
                payload["last_seen_at"] = last_seen_at.isoformat()
            await self._publish_chat_event(chat_id, event=event_name, payload=payload)


realtime_gateway = RealtimeGateway()