import logging
from datetime import datetime, timezone

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
from pydantic import ValidationError

from app.auth.service import get_current_user_for_ws
from app.database.session import AsyncSessionLocal
from app.realtime.schemas import (
    ChatEventPayload,
    IncomingRealtimeEvent,
    MessageStatusPayload,
    OutgoingRealtimeEvent,
    SendMessagePayload,
)
from app.realtime.service import realtime_gateway

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/realtime", tags=["realtime"])


async def _send_error(websocket: WebSocket, detail: str) -> None:
    """Send a single ``error`` event carrying ``{"detail": detail}`` to *websocket*."""
    await websocket.send_json(
        OutgoingRealtimeEvent(
            event="error",
            payload={"detail": detail},
            timestamp=datetime.now(timezone.utc),
        ).model_dump(mode="json")
    )


@router.websocket("/ws")
async def websocket_gateway(websocket: WebSocket) -> None:
    """Authenticate a client and serve realtime events over one WebSocket.

    The client must pass its auth token as the ``token`` query parameter.
    A missing or rejected token closes the socket with policy-violation
    code 1008 before it is accepted.

    Once accepted, the connection is registered with ``realtime_gateway``.
    Each incoming JSON frame is validated as an ``IncomingRealtimeEvent``
    and dispatched. Any failure on a single frame is reported back to the
    client as an ``error`` event, and the socket stays open. The connection
    is always unregistered when the loop ends, however it ends.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    async with AsyncSessionLocal() as db:
        try:
            user = await get_current_user_for_ws(token, db)
        except Exception:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        # Read the id once: a later rollback/commit may expire the ORM
        # object, and re-reading an expired attribute would need implicit
        # I/O on an async session.
        user_id = user.id

        user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user_id)
        await websocket.accept()
        connection_id = await realtime_gateway.register(user_id, websocket, user_chat_ids)

        try:
            while True:
                try:
                    raw_data = await websocket.receive_json()
                except ValueError:
                    # Malformed JSON (json.JSONDecodeError) or undecodable
                    # bytes: report it instead of tearing down the socket.
                    await _send_error(websocket, "Invalid event payload")
                    continue

                try:
                    event = IncomingRealtimeEvent.model_validate(raw_data)
                    await _dispatch_event(db, user_id, event)
                except ValidationError:
                    await _send_error(websocket, "Invalid event payload")
                except WebSocketDisconnect:
                    # Must reach the outer handler, not the generic one below.
                    raise
                except Exception as exc:
                    logger.exception("Failed to handle realtime event for user %s", user_id)
                    # NOTE(review): assumes AsyncSessionLocal yields a
                    # SQLAlchemy AsyncSession. Without a rollback, a failed
                    # flush leaves the session unusable for later events.
                    await db.rollback()
                    # NOTE(review): str(exc) may expose internal error text
                    # to the client; consider a generic message for
                    # non-domain errors.
                    await _send_error(websocket, str(exc))
        except WebSocketDisconnect:
            pass
        finally:
            # Runs on disconnect AND on any unexpected error, so a dead
            # connection is never left registered.
            await realtime_gateway.unregister(user_id, connection_id, user_chat_ids)


async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
    """Route a validated event to the matching ``realtime_gateway`` handler.

    The event's ``payload`` is validated against the schema for its kind, so
    a bad payload raises ``pydantic.ValidationError``. Event names not
    listed here are ignored without a reply.
    """
    kind = event.event

    if kind == "send_message":
        await realtime_gateway.handle_send_message(
            db, user_id, SendMessagePayload.model_validate(event.payload)
        )
    elif kind in {"typing_start", "typing_stop"}:
        await realtime_gateway.handle_typing_event(
            db, user_id, ChatEventPayload.model_validate(event.payload), kind
        )
    elif kind in {"message_read", "message_delivered"}:
        await realtime_gateway.handle_message_status(
            db, user_id, MessageStatusPayload.model_validate(event.payload), kind
        )