first commit

This commit is contained in:
2026-03-07 21:31:38 +03:00
commit a879ba7b50
68 changed files with 2487 additions and 0 deletions

178
app/realtime/service.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import contextlib
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead
from app.messages.service import create_chat_message
from app.realtime.models import ConnectionContext
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
class RealtimeGateway:
def __init__(self) -> None:
self._repo = RedisRealtimeRepository()
self._consume_task: asyncio.Task | None = None
self._distributed_enabled = False
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
async def start(self) -> None:
try:
await self._repo.connect()
if not self._consume_task:
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
self._distributed_enabled = True
except Exception:
self._distributed_enabled = False
async def stop(self) -> None:
if self._consume_task:
self._consume_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._consume_task
self._consume_task = None
await self._repo.close()
self._distributed_enabled = False
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
connection_id = str(uuid4())
self._connections[user_id][connection_id] = ConnectionContext(
user_id=user_id,
connection_id=connection_id,
websocket=websocket,
)
for chat_id in user_chat_ids:
self._chat_subscribers[chat_id].add(user_id)
await self._send_user_event(
user_id,
OutgoingRealtimeEvent(
event="connect",
payload={"connection_id": connection_id},
timestamp=datetime.now(timezone.utc),
),
)
return connection_id
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
user_connections = self._connections.get(user_id, {})
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id in user_chat_ids:
subscribers = self._chat_subscribers.get(chat_id)
if not subscribers:
continue
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
message = await create_chat_message(
db,
sender_id=user_id,
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
)
message_data = MessageRead.model_validate(message).model_dump(mode="json")
await self._publish_chat_event(
payload.chat_id,
event="receive_message",
payload={
"chat_id": payload.chat_id,
"message": message_data,
"temp_id": payload.temp_id,
"sender_id": user_id,
},
)
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={"chat_id": payload.chat_id, "user_id": user_id},
)
async def handle_message_status(
self,
db: AsyncSession,
user_id: int,
payload: MessageStatusPayload,
event: str,
) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={
"chat_id": payload.chat_id,
"message_id": payload.message_id,
"user_id": user_id,
},
)
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
return await list_user_chat_ids(db, user_id=user_id)
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
chat_id = self._extract_chat_id(channel)
if chat_id is None:
return
subscribers = self._chat_subscribers.get(chat_id, set())
if not subscribers:
return
event = OutgoingRealtimeEvent(
event=payload.get("event", "error"),
payload=payload.get("payload", {}),
timestamp=datetime.now(timezone.utc),
)
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
event_payload = {
"event": event,
"payload": payload,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if self._distributed_enabled:
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
return
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
user_connections = self._connections.get(user_id, {})
if not user_connections:
return
disconnected: list[str] = []
for connection_id, context in user_connections.items():
try:
await context.websocket.send_json(event.model_dump(mode="json"))
except Exception:
disconnected.append(connection_id)
for connection_id in disconnected:
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id, subscribers in list(self._chat_subscribers.items()):
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
@staticmethod
def _extract_chat_id(channel: str) -> int | None:
if not channel.startswith("chat:"):
return None
chat_id = channel.split(":", maxsplit=1)[1]
if not chat_id.isdigit():
return None
return int(chat_id)
realtime_gateway = RealtimeGateway()