Some checks failed
CI / test (push) Failing after 18s
- add websocket events user_online/user_offline - broadcast presence changes on first connect and final disconnect only - apply live presence updates in web chat store and realtime hook - move public discover into unified left search (users + groups/channels) - remove separate Discover Chats dialog/menu entry
262 lines
10 KiB
Python
262 lines
10 KiB
Python
import asyncio
import contextlib
import logging
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4

from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.database.session import AsyncSessionLocal
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
from app.messages.service import create_chat_message, mark_message_status
from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
from app.users.repository import update_user_last_seen_now
|
|
|
|
|
|
class RealtimeGateway:
    """Route realtime chat events to the WebSocket connections held by this process.

    Local state:

    * ``_connections`` maps ``user_id -> {connection_id: ConnectionContext}`` for every
      socket registered on this process.
    * ``_chat_subscribers`` maps ``chat_id -> {user_id, ...}`` and decides which local
      users receive an event published for a chat.

    After a successful :meth:`start`, chat events are published through
    ``RedisRealtimeRepository`` on ``chat:<id>`` channels, and :meth:`_handle_redis_event`
    is the callback handed to the repository's consume loop. Without Redis,
    :meth:`_publish_chat_event` calls that handler directly, so delivery is limited to
    this process.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        self._repo = RedisRealtimeRepository()
        self._consume_task: asyncio.Task | None = None
        # True only after start() connected the repository; selects Redis vs local fan-out.
        self._distributed_enabled = False
        self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
        self._chat_subscribers: dict[int, set[int]] = defaultdict(set)

    async def start(self) -> None:
        """Connect the Redis repository and start consuming its events.

        Best-effort: any failure is logged, and the gateway keeps working in in-process
        mode instead of raising. A consume task that has already finished (for example,
        because it crashed) is replaced by a fresh one.
        """
        try:
            await self._repo.connect()
            if self._consume_task is None or self._consume_task.done():
                self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
                self._consume_task.add_done_callback(self._on_consume_done)
            self._distributed_enabled = True
        except Exception:
            self._logger.warning(
                "Realtime Redis backend unavailable; using in-process delivery only",
                exc_info=True,
            )
            self._distributed_enabled = False

    def _on_consume_done(self, task: asyncio.Task) -> None:
        """Log a consume task that ended with an error rather than a cancellation."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Without this, the failure would stay hidden until stop() awaits the task.
            self._logger.error("Realtime consume task stopped unexpectedly", exc_info=exc)

    async def stop(self) -> None:
        """Cancel the consume task, close the repository and return to in-process mode.

        An error stored in an already-failed consume task is only logged at debug level
        here (``_on_consume_done`` normally reported it already), so ``close()`` always
        runs.
        """
        # Flip first so anything published during shutdown is delivered locally
        # instead of going to a repository that is being closed.
        self._distributed_enabled = False
        task, self._consume_task = self._consume_task, None
        if task is not None:
            task.cancel()
            try:
                with contextlib.suppress(asyncio.CancelledError):
                    await task
            except Exception:
                self._logger.debug("Realtime consume task had failed before shutdown", exc_info=True)
        await self._repo.close()

    async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
        """Track a new socket for ``user_id`` and subscribe the user to ``user_chat_ids``.

        When ``mark_user_online`` reports that the user just came online, broadcasts
        ``user_online`` to those chats. Then sends a ``connect`` acknowledgement carrying
        the new connection id to the new socket.

        Returns:
            The generated connection id; pass it to :meth:`unregister` on disconnect.
        """
        connection_id = str(uuid4())
        self._connections[user_id][connection_id] = ConnectionContext(
            user_id=user_id,
            connection_id=connection_id,
            websocket=websocket,
        )
        for chat_id in user_chat_ids:
            self._chat_subscribers[chat_id].add(user_id)

        if await mark_user_online(user_id):
            await self._broadcast_presence(user_chat_ids, user_id=user_id, is_online=True, last_seen_at=None)

        # The ack is addressed to the new socket only; the user's other sockets
        # already received their own connection ids.
        await self._send_user_event(
            user_id,
            OutgoingRealtimeEvent(
                event="connect",
                payload={"connection_id": connection_id},
                timestamp=datetime.now(timezone.utc),
            ),
            connection_id=connection_id,
        )
        return connection_id

    async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
        """Remove one socket of ``user_id`` and run its presence bookkeeping.

        Chat subscriptions in ``user_chat_ids`` are dropped only when this was the user's
        last local socket. ``mark_user_offline`` and the last-seen write run once per
        socket, and ``user_offline`` is broadcast when ``mark_user_offline`` reports the
        user went offline.

        Idempotent: if the socket was already removed (e.g. pruned after a failed send),
        nothing happens.
        """
        await self._drop_connection(user_id, connection_id, user_chat_ids)

    async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
        """Create a message from a socket payload and publish ``receive_message`` to its chat.

        When creating the message, ``client_message_id`` falls back to ``temp_id``. The
        published event echoes both fields exactly as the client sent them.
        """
        message = await create_chat_message(
            db,
            sender_id=user_id,
            payload=MessageCreateRequest(
                chat_id=payload.chat_id,
                type=payload.type,
                text=payload.text,
                client_message_id=payload.client_message_id or payload.temp_id,
                reply_to_message_id=payload.reply_to_message_id,
            ),
        )
        await self.publish_message_created(
            message=message,
            sender_id=user_id,
            temp_id=payload.temp_id,
            client_message_id=payload.client_message_id,
        )

    async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
        """Publish a typing-style ``event`` for ``payload.chat_id`` on behalf of ``user_id``.

        ``ensure_chat_membership`` runs first; presumably it raises for non-members
        (TODO confirm), which would stop the publish.
        """
        await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={"chat_id": payload.chat_id, "user_id": user_id},
        )

    async def handle_message_status(
        self,
        db: AsyncSession,
        user_id: int,
        payload: MessageStatusPayload,
        event: str,
    ) -> dict[str, int]:
        """Record a delivery/read status (``event`` is used as the status) and publish it.

        Returns:
            The receipt state from ``mark_message_status``. Its
            ``last_delivered_message_id`` and ``last_read_message_id`` keys are also
            included in the published event.
        """
        receipt_state = await mark_message_status(
            db,
            user_id=user_id,
            payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
        )
        await self._publish_chat_event(
            payload.chat_id,
            event=event,
            payload={
                "chat_id": payload.chat_id,
                "message_id": payload.message_id,
                "user_id": user_id,
                "last_delivered_message_id": receipt_state["last_delivered_message_id"],
                "last_read_message_id": receipt_state["last_read_message_id"],
            },
        )
        return receipt_state

    async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the ids of the chats ``user_id`` belongs to (via ``list_user_chat_ids``)."""
        return await list_user_chat_ids(db, user_id=user_id)

    async def _handle_redis_event(self, channel: str, payload: dict) -> None:
        """Deliver an event received on ``channel`` to local subscribers of that chat.

        ``payload`` is presumably the dict built by :meth:`_publish_chat_event` (as
        decoded by the repository in Redis mode). Only its ``event`` and ``payload``
        keys are used, and the event is re-stamped with the receive time. Channels that
        are not ``chat:<id>`` are ignored. Per-user delivery failures are logged and do
        not propagate.
        """
        chat_id = self._extract_chat_id(channel)
        if chat_id is None:
            return
        subscribers = self._chat_subscribers.get(chat_id)
        if not subscribers:
            return
        event = OutgoingRealtimeEvent(
            event=payload.get("event", "error"),
            payload=payload.get("payload", {}),
            timestamp=datetime.now(timezone.utc),
        )
        # Snapshot: failed sends may drop users from the live set while we await.
        recipients = tuple(subscribers)
        results = await asyncio.gather(
            *(self._send_user_event(user_id, event) for user_id in recipients),
            return_exceptions=True,
        )
        for user_id, result in zip(recipients, results):
            if isinstance(result, BaseException):
                self._logger.warning("Failed to deliver %r to user %s", event.event, user_id, exc_info=result)

    async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
        """Publish ``event`` for ``chat_id`` via Redis, or deliver it locally when Redis is off."""
        event_payload = {
            "event": event,
            "payload": payload,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        if self._distributed_enabled:
            await self._repo.publish_event(f"chat:{chat_id}", event_payload)
            return
        await self._handle_redis_event(f"chat:{chat_id}", event_payload)

    async def publish_message_created(
        self,
        *,
        message,
        sender_id: int,
        temp_id: str | None = None,
        client_message_id: str | None = None,
    ) -> None:
        """Publish ``receive_message`` for ``message`` to its chat.

        ``message`` must be accepted by ``MessageRead.model_validate`` and expose
        ``chat_id``. Presumably it is the ORM object returned by ``create_chat_message``;
        TODO confirm.
        """
        message_data = MessageRead.model_validate(message).model_dump(mode="json")
        await self._publish_chat_event(
            message.chat_id,
            event="receive_message",
            payload={
                "chat_id": message.chat_id,
                "message": message_data,
                "temp_id": temp_id,
                "client_message_id": client_message_id,
                "sender_id": sender_id,
            },
        )

    async def _send_user_event(
        self,
        user_id: int,
        event: OutgoingRealtimeEvent,
        *,
        connection_id: str | None = None,
    ) -> None:
        """Send ``event`` to the local sockets of ``user_id``.

        Args:
            user_id: Recipient.
            event: Event to serialize and send.
            connection_id: When given, send only to that socket instead of to all of
                the user's sockets.

        Sockets whose send raises are dropped via :meth:`_drop_connection`, which does
        the same presence bookkeeping as :meth:`unregister`.
        """
        user_connections = self._connections.get(user_id)
        if not user_connections:
            return
        if connection_id is None:
            # Snapshot: the dict can change (register/unregister) while sends are awaited.
            targets = list(user_connections.items())
        elif (context := user_connections.get(connection_id)) is not None:
            targets = [(connection_id, context)]
        else:
            return

        # Serialize once. A serialization error is a bug, not a dead socket, so it is
        # deliberately not caught as a send failure.
        message = event.model_dump(mode="json")
        dead: list[str] = []
        for target_id, context in targets:
            try:
                await context.websocket.send_json(message)
            except Exception:
                dead.append(target_id)
        for target_id in dead:
            await self._drop_connection(user_id, target_id)

    async def _drop_connection(
        self,
        user_id: int,
        connection_id: str,
        chat_ids: list[int] | None = None,
    ) -> None:
        """Forget one socket and run its presence bookkeeping exactly once.

        Args:
            user_id: Owner of the socket.
            connection_id: Id returned by :meth:`register`.
            chat_ids: Chats to unsubscribe from (when this was the user's last local
                socket) and to notify with ``user_offline``. When ``None``, they are
                derived from the local subscriber map.

        No-op when the socket is not (or no longer) registered. This keeps the
        failed-send path and a later :meth:`unregister` for the same socket from both
        calling ``mark_user_offline``.
        """
        user_connections = self._connections.get(user_id)
        if user_connections is None or user_connections.pop(connection_id, None) is None:
            return

        if chat_ids is None:
            chat_ids = [chat_id for chat_id, subscribers in self._chat_subscribers.items() if user_id in subscribers]

        if not user_connections:
            # Last local socket of this user: stop routing chat events to them.
            # While other sockets remain, the subscriptions must stay.
            self._connections.pop(user_id, None)
            for chat_id in chat_ids:
                subscribers = self._chat_subscribers.get(chat_id)
                if subscribers is None:
                    continue
                subscribers.discard(user_id)
                if not subscribers:
                    self._chat_subscribers.pop(chat_id, None)

        became_offline = await mark_user_offline(user_id)
        await self._persist_last_seen(user_id)
        if became_offline:
            await self._broadcast_presence(
                chat_ids,
                user_id=user_id,
                is_online=False,
                last_seen_at=datetime.now(timezone.utc),
            )

    @staticmethod
    def _extract_chat_id(channel: str) -> int | None:
        """Return the integer id from a ``chat:<digits>`` channel name, else ``None``."""
        prefix, separator, raw_id = channel.partition(":")
        if prefix != "chat" or not separator:
            return None
        # isdigit() alone accepts characters such as "²" that int() rejects.
        if not (raw_id.isascii() and raw_id.isdigit()):
            return None
        return int(raw_id)

    async def _persist_last_seen(self, user_id: int) -> None:
        """Write the user's last-seen time in its own session; failures are logged, not raised."""
        try:
            async with AsyncSessionLocal() as db:
                await update_user_last_seen_now(db, user_id=user_id)
                await db.commit()
        except Exception:
            # Best-effort: a failed write must not break disconnect handling.
            self._logger.warning("Failed to persist last_seen for user %s", user_id, exc_info=True)

    async def _broadcast_presence(
        self,
        chat_ids: list[int],
        *,
        user_id: int,
        is_online: bool,
        last_seen_at: datetime | None,
    ) -> None:
        """Publish ``user_online``/``user_offline`` for ``user_id`` to each chat in ``chat_ids``.

        ``last_seen_at`` is included as an ISO 8601 string only when it is given.
        """
        event_name = "user_online" if is_online else "user_offline"
        for chat_id in chat_ids:
            payload = {
                "chat_id": chat_id,
                "user_id": user_id,
                "is_online": is_online,
            }
            if last_seen_at is not None:
                payload["last_seen_at"] = last_seen_at.isoformat()
            await self._publish_chat_event(chat_id, event=event_name, payload=payload)
|
|
|
|
|
|
# Module-level singleton: every importer shares this one instance and therefore
# the same in-memory connection and subscriber registries.
realtime_gateway: RealtimeGateway = RealtimeGateway()
|