93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
from datetime import datetime, timezone
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
|
|
from pydantic import ValidationError
|
|
|
|
from app.auth.service import get_current_user_for_ws
|
|
from app.database.session import AsyncSessionLocal
|
|
from app.realtime.schemas import (
|
|
ChatEventPayload,
|
|
IncomingRealtimeEvent,
|
|
MessageStatusPayload,
|
|
OutgoingRealtimeEvent,
|
|
SendMessagePayload,
|
|
)
|
|
from app.realtime.service import realtime_gateway
|
|
|
|
# Router for the realtime (WebSocket) endpoints: every route is mounted under
# the "/realtime" prefix and grouped under the "realtime" OpenAPI tag.
router = APIRouter(
    prefix="/realtime",
    tags=["realtime"],
)
|
|
|
|
|
|
@router.websocket("/ws")
async def websocket_gateway(websocket: WebSocket) -> None:
    """Authenticate a client and serve realtime events over its WebSocket.

    The access token is read from the ``token`` query parameter. A missing
    token, or any exception while resolving it to a user, closes the socket
    with code 1008 (policy violation) before the handshake is accepted.

    After ``accept()`` the connection is registered with ``realtime_gateway``
    together with the user's chat ids. Each incoming JSON frame is then
    validated as an ``IncomingRealtimeEvent``:

    * ``ping`` is answered directly with a ``pong`` event;
    * every other event is passed to ``_dispatch_event``;
    * a frame that is not valid JSON, or that fails schema validation,
      is answered with an ``error`` event ``{"detail": "Invalid event payload"}``;
    * any other exception raised while handling an event is logged and
      answered with an ``error`` event whose detail is ``str(exc)``.

    The database session opened here stays open for the whole lifetime of
    the connection and is shared by every dispatched event. The connection
    is always unregistered when the loop ends, whether the client
    disconnected normally or an unexpected error escaped.
    """
    import logging

    logger = logging.getLogger(__name__)

    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    async with AsyncSessionLocal() as db:
        try:
            user = await get_current_user_for_ws(token, db)
        except Exception:
            # Any failure resolving the token is reported uniformly as a
            # policy violation.
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
        await websocket.accept()
        connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)

        try:
            while True:
                try:
                    raw_data = await websocket.receive_json()
                except ValueError:
                    # json.JSONDecodeError subclasses ValueError: a malformed
                    # frame is a client error, not a reason to drop the socket.
                    await _send_realtime_event(
                        websocket, "error", {"detail": "Invalid event payload"}
                    )
                    continue

                try:
                    event = IncomingRealtimeEvent.model_validate(raw_data)
                    if event.event == "ping":
                        await _send_realtime_event(websocket, "pong", {})
                        continue
                    await _dispatch_event(db, user.id, event)
                except ValidationError:
                    await _send_realtime_event(
                        websocket, "error", {"detail": "Invalid event payload"}
                    )
                except WebSocketDisconnect:
                    # The peer went away mid-handling; don't try to send an
                    # error over a closed socket, let the outer handler exit.
                    raise
                except Exception as exc:
                    logger.exception(
                        "Realtime event handling failed for user %s", user.id
                    )
                    # NOTE(review): str(exc) is sent verbatim to the client and
                    # may expose internal details (e.g. database errors) —
                    # confirm this is intended. If db is a SQLAlchemy session,
                    # a failed flush may also need a rollback before reuse.
                    await _send_realtime_event(
                        websocket, "error", {"detail": str(exc)}
                    )
        except WebSocketDisconnect:
            # Normal client-initiated close.
            pass
        finally:
            # Runs on disconnect AND when an unexpected error escapes the loop
            # (e.g. a failed send), so no stale connection stays registered.
            await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)


async def _send_realtime_event(websocket: WebSocket, event: str, payload: dict) -> None:
    """Send one ``OutgoingRealtimeEvent`` stamped with the current UTC time.

    The event is serialized with ``model_dump(mode="json")`` so datetimes and
    other non-JSON-native values become JSON-safe before ``send_json``.
    """
    await websocket.send_json(
        OutgoingRealtimeEvent(
            event=event,
            payload=payload,
            timestamp=datetime.now(timezone.utc),
        ).model_dump(mode="json")
    )
|
|
|
|
|
|
# Event names routed to realtime_gateway.handle_typing_event.
_TYPING_EVENT_NAMES = frozenset(
    {
        "typing_start",
        "typing_stop",
        "recording_voice_start",
        "recording_voice_stop",
        "recording_video_start",
        "recording_video_stop",
    }
)

# Event names routed to realtime_gateway.handle_message_status.
_STATUS_EVENT_NAMES = frozenset({"message_read", "message_delivered"})


async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
    """Route a non-ping realtime event to the matching gateway handler.

    The raw ``event.payload`` is validated against the payload model for its
    event family before the handler is awaited. A pydantic ``ValidationError``
    from ``model_validate`` propagates to the caller. Event names that match
    no family are ignored without any response.
    """
    name = event.event

    if name == "send_message":
        message = SendMessagePayload.model_validate(event.payload)
        await realtime_gateway.handle_send_message(db, user_id, message)
    elif name in _TYPING_EVENT_NAMES:
        chat_event = ChatEventPayload.model_validate(event.payload)
        await realtime_gateway.handle_typing_event(db, user_id, chat_event, name)
    elif name in _STATUS_EVENT_NAMES:
        status_update = MessageStatusPayload.model_validate(event.payload)
        await realtime_gateway.handle_message_status(db, user_id, status_update, name)
|