first commit
This commit is contained in:
0
app/realtime/__init__.py
Normal file
0
app/realtime/__init__.py
Normal file
10
app/realtime/models.py
Normal file
10
app/realtime/models.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConnectionContext:
|
||||
user_id: int
|
||||
connection_id: str
|
||||
websocket: WebSocket
|
||||
48
app/realtime/repository.py
Normal file
48
app/realtime/repository.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class RedisRealtimeRepository:
|
||||
def __init__(self) -> None:
|
||||
self._redis: Redis | None = None
|
||||
self._pubsub = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._redis:
|
||||
return
|
||||
self._redis = Redis.from_url(settings.redis_url, decode_responses=True)
|
||||
self._pubsub = self._redis.pubsub()
|
||||
await self._pubsub.psubscribe("chat:*")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._pubsub:
|
||||
await self._pubsub.close()
|
||||
self._pubsub = None
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
async def publish_event(self, channel: str, payload: dict) -> None:
|
||||
if not self._redis:
|
||||
await self.connect()
|
||||
assert self._redis is not None
|
||||
await self._redis.publish(channel, json.dumps(payload))
|
||||
|
||||
async def consume(self, handler: Callable[[str, dict], Awaitable[None]]) -> None:
|
||||
if not self._pubsub:
|
||||
await self.connect()
|
||||
assert self._pubsub is not None
|
||||
while True:
|
||||
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
if not message:
|
||||
continue
|
||||
channel = message.get("channel")
|
||||
data = message.get("data")
|
||||
if not channel or not isinstance(data, str):
|
||||
continue
|
||||
payload = json.loads(data)
|
||||
await handler(channel, payload)
|
||||
76
app/realtime/router.py
Normal file
76
app/realtime/router.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.auth.service import get_current_user_for_ws
|
||||
from app.database.session import AsyncSessionLocal
|
||||
from app.realtime.schemas import (
|
||||
ChatEventPayload,
|
||||
IncomingRealtimeEvent,
|
||||
MessageStatusPayload,
|
||||
OutgoingRealtimeEvent,
|
||||
SendMessagePayload,
|
||||
)
|
||||
from app.realtime.service import realtime_gateway
|
||||
|
||||
router = APIRouter(prefix="/realtime", tags=["realtime"])
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_gateway(websocket: WebSocket) -> None:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
user = await get_current_user_for_ws(token, db)
|
||||
except Exception:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
|
||||
await websocket.accept()
|
||||
connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await websocket.receive_json()
|
||||
try:
|
||||
event = IncomingRealtimeEvent.model_validate(raw_data)
|
||||
await _dispatch_event(db, user.id, event)
|
||||
except ValidationError:
|
||||
await websocket.send_json(
|
||||
OutgoingRealtimeEvent(
|
||||
event="error",
|
||||
payload={"detail": "Invalid event payload"},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
except Exception as exc:
|
||||
await websocket.send_json(
|
||||
OutgoingRealtimeEvent(
|
||||
event="error",
|
||||
payload={"detail": str(exc)},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)
|
||||
|
||||
|
||||
async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
|
||||
if event.event == "send_message":
|
||||
payload = SendMessagePayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_send_message(db, user_id, payload)
|
||||
return
|
||||
if event.event in {"typing_start", "typing_stop"}:
|
||||
payload = ChatEventPayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_typing_event(db, user_id, payload, event.event)
|
||||
return
|
||||
if event.event in {"message_read", "message_delivered"}:
|
||||
payload = MessageStatusPayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_message_status(db, user_id, payload, event.event)
|
||||
return
|
||||
48
app/realtime/schemas.py
Normal file
48
app/realtime/schemas.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.messages.models import MessageType
|
||||
|
||||
|
||||
RealtimeEventName = Literal[
|
||||
"connect",
|
||||
"disconnect",
|
||||
"send_message",
|
||||
"receive_message",
|
||||
"typing_start",
|
||||
"typing_stop",
|
||||
"message_read",
|
||||
"message_delivered",
|
||||
"error",
|
||||
]
|
||||
|
||||
|
||||
class SendMessagePayload(BaseModel):
|
||||
chat_id: int
|
||||
type: MessageType = MessageType.TEXT
|
||||
text: str | None = Field(default=None, max_length=4096)
|
||||
temp_id: str | None = None
|
||||
|
||||
|
||||
class ChatEventPayload(BaseModel):
|
||||
chat_id: int
|
||||
|
||||
|
||||
class MessageStatusPayload(BaseModel):
|
||||
chat_id: int
|
||||
message_id: int
|
||||
|
||||
|
||||
class IncomingRealtimeEvent(BaseModel):
|
||||
event: Literal["send_message", "typing_start", "typing_stop", "message_read", "message_delivered"]
|
||||
payload: dict[str, Any]
|
||||
|
||||
|
||||
class OutgoingRealtimeEvent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
event: RealtimeEventName
|
||||
payload: dict[str, Any]
|
||||
timestamp: datetime
|
||||
178
app/realtime/service.py
Normal file
178
app/realtime/service.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.repository import list_user_chat_ids
|
||||
from app.chats.service import ensure_chat_membership
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead
|
||||
from app.messages.service import create_chat_message
|
||||
from app.realtime.models import ConnectionContext
|
||||
from app.realtime.repository import RedisRealtimeRepository
|
||||
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
|
||||
|
||||
|
||||
class RealtimeGateway:
|
||||
def __init__(self) -> None:
|
||||
self._repo = RedisRealtimeRepository()
|
||||
self._consume_task: asyncio.Task | None = None
|
||||
self._distributed_enabled = False
|
||||
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
|
||||
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
|
||||
|
||||
async def start(self) -> None:
|
||||
try:
|
||||
await self._repo.connect()
|
||||
if not self._consume_task:
|
||||
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
|
||||
self._distributed_enabled = True
|
||||
except Exception:
|
||||
self._distributed_enabled = False
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._consume_task:
|
||||
self._consume_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._consume_task
|
||||
self._consume_task = None
|
||||
await self._repo.close()
|
||||
self._distributed_enabled = False
|
||||
|
||||
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
|
||||
connection_id = str(uuid4())
|
||||
self._connections[user_id][connection_id] = ConnectionContext(
|
||||
user_id=user_id,
|
||||
connection_id=connection_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
for chat_id in user_chat_ids:
|
||||
self._chat_subscribers[chat_id].add(user_id)
|
||||
await self._send_user_event(
|
||||
user_id,
|
||||
OutgoingRealtimeEvent(
|
||||
event="connect",
|
||||
payload={"connection_id": connection_id},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
),
|
||||
)
|
||||
return connection_id
|
||||
|
||||
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
|
||||
user_connections = self._connections.get(user_id, {})
|
||||
user_connections.pop(connection_id, None)
|
||||
if not user_connections:
|
||||
self._connections.pop(user_id, None)
|
||||
for chat_id in user_chat_ids:
|
||||
subscribers = self._chat_subscribers.get(chat_id)
|
||||
if not subscribers:
|
||||
continue
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
|
||||
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
|
||||
message = await create_chat_message(
|
||||
db,
|
||||
sender_id=user_id,
|
||||
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
|
||||
)
|
||||
message_data = MessageRead.model_validate(message).model_dump(mode="json")
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event="receive_message",
|
||||
payload={
|
||||
"chat_id": payload.chat_id,
|
||||
"message": message_data,
|
||||
"temp_id": payload.temp_id,
|
||||
"sender_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event=event,
|
||||
payload={"chat_id": payload.chat_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
async def handle_message_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
payload: MessageStatusPayload,
|
||||
event: str,
|
||||
) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event=event,
|
||||
payload={
|
||||
"chat_id": payload.chat_id,
|
||||
"message_id": payload.message_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
|
||||
return await list_user_chat_ids(db, user_id=user_id)
|
||||
|
||||
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
||||
chat_id = self._extract_chat_id(channel)
|
||||
if chat_id is None:
|
||||
return
|
||||
subscribers = self._chat_subscribers.get(chat_id, set())
|
||||
if not subscribers:
|
||||
return
|
||||
event = OutgoingRealtimeEvent(
|
||||
event=payload.get("event", "error"),
|
||||
payload=payload.get("payload", {}),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
|
||||
|
||||
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
|
||||
event_payload = {
|
||||
"event": event,
|
||||
"payload": payload,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
if self._distributed_enabled:
|
||||
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
|
||||
return
|
||||
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
|
||||
|
||||
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
|
||||
user_connections = self._connections.get(user_id, {})
|
||||
if not user_connections:
|
||||
return
|
||||
disconnected: list[str] = []
|
||||
for connection_id, context in user_connections.items():
|
||||
try:
|
||||
await context.websocket.send_json(event.model_dump(mode="json"))
|
||||
except Exception:
|
||||
disconnected.append(connection_id)
|
||||
for connection_id in disconnected:
|
||||
user_connections.pop(connection_id, None)
|
||||
if not user_connections:
|
||||
self._connections.pop(user_id, None)
|
||||
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
|
||||
@staticmethod
|
||||
def _extract_chat_id(channel: str) -> int | None:
|
||||
if not channel.startswith("chat:"):
|
||||
return None
|
||||
chat_id = channel.split(":", maxsplit=1)[1]
|
||||
if not chat_id.isdigit():
|
||||
return None
|
||||
return int(chat_id)
|
||||
|
||||
|
||||
realtime_gateway = RealtimeGateway()
|
||||
Reference in New Issue
Block a user