first commit

This commit is contained in:
2026-03-07 21:31:38 +03:00
commit a879ba7b50
68 changed files with 2487 additions and 0 deletions

0
app/realtime/__init__.py Normal file
View File

10
app/realtime/models.py Normal file
View File

@@ -0,0 +1,10 @@
from dataclasses import dataclass
from fastapi import WebSocket
@dataclass(slots=True)
class ConnectionContext:
user_id: int
connection_id: str
websocket: WebSocket

View File

@@ -0,0 +1,48 @@
import json
from collections.abc import Awaitable, Callable
from redis.asyncio import Redis
from app.config.settings import settings
class RedisRealtimeRepository:
def __init__(self) -> None:
self._redis: Redis | None = None
self._pubsub = None
async def connect(self) -> None:
if self._redis:
return
self._redis = Redis.from_url(settings.redis_url, decode_responses=True)
self._pubsub = self._redis.pubsub()
await self._pubsub.psubscribe("chat:*")
async def close(self) -> None:
if self._pubsub:
await self._pubsub.close()
self._pubsub = None
if self._redis:
await self._redis.aclose()
self._redis = None
async def publish_event(self, channel: str, payload: dict) -> None:
if not self._redis:
await self.connect()
assert self._redis is not None
await self._redis.publish(channel, json.dumps(payload))
async def consume(self, handler: Callable[[str, dict], Awaitable[None]]) -> None:
if not self._pubsub:
await self.connect()
assert self._pubsub is not None
while True:
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if not message:
continue
channel = message.get("channel")
data = message.get("data")
if not channel or not isinstance(data, str):
continue
payload = json.loads(data)
await handler(channel, payload)

76
app/realtime/router.py Normal file
View File

@@ -0,0 +1,76 @@
from datetime import datetime, timezone
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
from pydantic import ValidationError
from app.auth.service import get_current_user_for_ws
from app.database.session import AsyncSessionLocal
from app.realtime.schemas import (
ChatEventPayload,
IncomingRealtimeEvent,
MessageStatusPayload,
OutgoingRealtimeEvent,
SendMessagePayload,
)
from app.realtime.service import realtime_gateway
router = APIRouter(prefix="/realtime", tags=["realtime"])
@router.websocket("/ws")
async def websocket_gateway(websocket: WebSocket) -> None:
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
async with AsyncSessionLocal() as db:
try:
user = await get_current_user_for_ws(token, db)
except Exception:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
await websocket.accept()
connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)
try:
while True:
raw_data = await websocket.receive_json()
try:
event = IncomingRealtimeEvent.model_validate(raw_data)
await _dispatch_event(db, user.id, event)
except ValidationError:
await websocket.send_json(
OutgoingRealtimeEvent(
event="error",
payload={"detail": "Invalid event payload"},
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json")
)
except Exception as exc:
await websocket.send_json(
OutgoingRealtimeEvent(
event="error",
payload={"detail": str(exc)},
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json")
)
except WebSocketDisconnect:
await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)
async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
if event.event == "send_message":
payload = SendMessagePayload.model_validate(event.payload)
await realtime_gateway.handle_send_message(db, user_id, payload)
return
if event.event in {"typing_start", "typing_stop"}:
payload = ChatEventPayload.model_validate(event.payload)
await realtime_gateway.handle_typing_event(db, user_id, payload, event.event)
return
if event.event in {"message_read", "message_delivered"}:
payload = MessageStatusPayload.model_validate(event.payload)
await realtime_gateway.handle_message_status(db, user_id, payload, event.event)
return

48
app/realtime/schemas.py Normal file
View File

@@ -0,0 +1,48 @@
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from app.messages.models import MessageType
RealtimeEventName = Literal[
"connect",
"disconnect",
"send_message",
"receive_message",
"typing_start",
"typing_stop",
"message_read",
"message_delivered",
"error",
]
class SendMessagePayload(BaseModel):
chat_id: int
type: MessageType = MessageType.TEXT
text: str | None = Field(default=None, max_length=4096)
temp_id: str | None = None
class ChatEventPayload(BaseModel):
chat_id: int
class MessageStatusPayload(BaseModel):
chat_id: int
message_id: int
class IncomingRealtimeEvent(BaseModel):
event: Literal["send_message", "typing_start", "typing_stop", "message_read", "message_delivered"]
payload: dict[str, Any]
class OutgoingRealtimeEvent(BaseModel):
model_config = ConfigDict(from_attributes=True)
event: RealtimeEventName
payload: dict[str, Any]
timestamp: datetime

178
app/realtime/service.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import contextlib
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead
from app.messages.service import create_chat_message
from app.realtime.models import ConnectionContext
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
class RealtimeGateway:
def __init__(self) -> None:
self._repo = RedisRealtimeRepository()
self._consume_task: asyncio.Task | None = None
self._distributed_enabled = False
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
async def start(self) -> None:
try:
await self._repo.connect()
if not self._consume_task:
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
self._distributed_enabled = True
except Exception:
self._distributed_enabled = False
async def stop(self) -> None:
if self._consume_task:
self._consume_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._consume_task
self._consume_task = None
await self._repo.close()
self._distributed_enabled = False
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
connection_id = str(uuid4())
self._connections[user_id][connection_id] = ConnectionContext(
user_id=user_id,
connection_id=connection_id,
websocket=websocket,
)
for chat_id in user_chat_ids:
self._chat_subscribers[chat_id].add(user_id)
await self._send_user_event(
user_id,
OutgoingRealtimeEvent(
event="connect",
payload={"connection_id": connection_id},
timestamp=datetime.now(timezone.utc),
),
)
return connection_id
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
user_connections = self._connections.get(user_id, {})
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id in user_chat_ids:
subscribers = self._chat_subscribers.get(chat_id)
if not subscribers:
continue
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
message = await create_chat_message(
db,
sender_id=user_id,
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
)
message_data = MessageRead.model_validate(message).model_dump(mode="json")
await self._publish_chat_event(
payload.chat_id,
event="receive_message",
payload={
"chat_id": payload.chat_id,
"message": message_data,
"temp_id": payload.temp_id,
"sender_id": user_id,
},
)
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={"chat_id": payload.chat_id, "user_id": user_id},
)
async def handle_message_status(
self,
db: AsyncSession,
user_id: int,
payload: MessageStatusPayload,
event: str,
) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={
"chat_id": payload.chat_id,
"message_id": payload.message_id,
"user_id": user_id,
},
)
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
return await list_user_chat_ids(db, user_id=user_id)
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
chat_id = self._extract_chat_id(channel)
if chat_id is None:
return
subscribers = self._chat_subscribers.get(chat_id, set())
if not subscribers:
return
event = OutgoingRealtimeEvent(
event=payload.get("event", "error"),
payload=payload.get("payload", {}),
timestamp=datetime.now(timezone.utc),
)
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
event_payload = {
"event": event,
"payload": payload,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if self._distributed_enabled:
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
return
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
user_connections = self._connections.get(user_id, {})
if not user_connections:
return
disconnected: list[str] = []
for connection_id, context in user_connections.items():
try:
await context.websocket.send_json(event.model_dump(mode="json"))
except Exception:
disconnected.append(connection_id)
for connection_id in disconnected:
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id, subscribers in list(self._chat_subscribers.items()):
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
@staticmethod
def _extract_chat_id(channel: str) -> int | None:
if not channel.startswith("chat:"):
return None
chat_id = channel.split(":", maxsplit=1)[1]
if not chat_id.isdigit():
return None
return int(chat_id)
realtime_gateway = RealtimeGateway()