feat: add message reliability foundation
All checks were successful
CI / test (push) Successful in 23s
All checks were successful
CI / test (push) Successful in 23s
- implement idempotent message creation via client_message_id - add persistent delivered/read receipts - expose /messages/status and wire websocket receipt events - update web client to send client ids and auto-ack delivered/read
This commit is contained in:
@@ -2,7 +2,7 @@ from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||
from app.chats.models import Chat, ChatMember
|
||||
from app.email.models import EmailLog
|
||||
from app.media.models import Attachment
|
||||
from app.messages.models import Message
|
||||
from app.messages.models import Message, MessageIdempotencyKey, MessageReceipt
|
||||
from app.notifications.models import NotificationLog
|
||||
from app.users.models import User
|
||||
|
||||
@@ -13,6 +13,8 @@ __all__ = [
|
||||
"EmailLog",
|
||||
"EmailVerificationToken",
|
||||
"Message",
|
||||
"MessageIdempotencyKey",
|
||||
"MessageReceipt",
|
||||
"NotificationLog",
|
||||
"PasswordResetToken",
|
||||
"User",
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, Text, func
|
||||
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, String, Text, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
@@ -42,3 +42,34 @@ class Message(Base):
|
||||
chat: Mapped["Chat"] = relationship(back_populates="messages")
|
||||
sender: Mapped["User"] = relationship(back_populates="sent_messages")
|
||||
attachments: Mapped[list["Attachment"]] = relationship(back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class MessageIdempotencyKey(Base):
|
||||
__tablename__ = "message_idempotency_keys"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("chat_id", "sender_id", "client_message_id", name="uq_msg_idem_chat_sender_client"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
sender_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
client_message_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
|
||||
class MessageReceipt(Base):
|
||||
__tablename__ = "message_receipts"
|
||||
__table_args__ = (UniqueConstraint("chat_id", "user_id", name="uq_msg_receipts_chat_user"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
last_delivered_message_id: Mapped[int | None] = mapped_column(nullable=True)
|
||||
last_read_message_id: Mapped[int | None] = mapped_column(nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.messages.models import Message, MessageType
|
||||
from app.messages.models import Message, MessageIdempotencyKey, MessageReceipt, MessageType
|
||||
|
||||
|
||||
async def create_message(
|
||||
@@ -18,6 +18,45 @@ async def create_message(
|
||||
return message
|
||||
|
||||
|
||||
async def get_message_by_client_message_id(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
sender_id: int,
|
||||
client_message_id: str,
|
||||
) -> Message | None:
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.join(MessageIdempotencyKey, MessageIdempotencyKey.message_id == Message.id)
|
||||
.where(
|
||||
MessageIdempotencyKey.chat_id == chat_id,
|
||||
MessageIdempotencyKey.sender_id == sender_id,
|
||||
MessageIdempotencyKey.client_message_id == client_message_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_message_idempotency_key(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
sender_id: int,
|
||||
client_message_id: str,
|
||||
message_id: int,
|
||||
) -> MessageIdempotencyKey:
|
||||
key = MessageIdempotencyKey(
|
||||
chat_id=chat_id,
|
||||
sender_id=sender_id,
|
||||
client_message_id=client_message_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
db.add(key)
|
||||
await db.flush()
|
||||
return key
|
||||
|
||||
|
||||
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
|
||||
result = await db.execute(select(Message).where(Message.id == message_id))
|
||||
return result.scalar_one_or_none()
|
||||
@@ -39,3 +78,32 @@ async def list_chat_messages(
|
||||
|
||||
async def delete_message(db: AsyncSession, message: Message) -> None:
|
||||
await db.delete(message)
|
||||
|
||||
|
||||
async def get_message_receipt(db: AsyncSession, *, chat_id: int, user_id: int) -> MessageReceipt | None:
|
||||
result = await db.execute(
|
||||
select(MessageReceipt).where(
|
||||
MessageReceipt.chat_id == chat_id,
|
||||
MessageReceipt.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_message_receipt(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
last_delivered_message_id: int | None,
|
||||
last_read_message_id: int | None,
|
||||
) -> MessageReceipt:
|
||||
receipt = MessageReceipt(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
last_delivered_message_id=last_delivered_message_id,
|
||||
last_read_message_id=last_read_message_id,
|
||||
)
|
||||
db.add(receipt)
|
||||
await db.flush()
|
||||
return receipt
|
||||
|
||||
@@ -3,8 +3,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.database.session import get_db
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageUpdateRequest
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest, MessageUpdateRequest
|
||||
from app.messages.service import create_chat_message, delete_message, get_messages, update_message
|
||||
from app.realtime.schemas import MessageStatusPayload
|
||||
from app.realtime.service import realtime_gateway
|
||||
from app.users.models import User
|
||||
|
||||
@@ -18,7 +19,11 @@ async def create_message(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> MessageRead:
|
||||
message = await create_chat_message(db, sender_id=current_user.id, payload=payload)
|
||||
await realtime_gateway.publish_message_created(message=message, sender_id=current_user.id)
|
||||
await realtime_gateway.publish_message_created(
|
||||
message=message,
|
||||
sender_id=current_user.id,
|
||||
client_message_id=payload.client_message_id,
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
@@ -50,3 +55,17 @@ async def remove_message(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
await delete_message(db, message_id=message_id, user_id=current_user.id)
|
||||
|
||||
|
||||
@router.post("/status")
|
||||
async def update_status(
|
||||
payload: MessageStatusUpdateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, int]:
|
||||
return await realtime_gateway.handle_message_status(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
payload=MessageStatusPayload(chat_id=payload.chat_id, message_id=payload.message_id),
|
||||
event=payload.status,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -21,7 +22,14 @@ class MessageCreateRequest(BaseModel):
|
||||
chat_id: int
|
||||
type: MessageType = MessageType.TEXT
|
||||
text: str | None = Field(default=None, max_length=4096)
|
||||
client_message_id: str | None = Field(default=None, min_length=8, max_length=64)
|
||||
|
||||
|
||||
class MessageUpdateRequest(BaseModel):
|
||||
text: str = Field(min_length=1, max_length=4096)
|
||||
|
||||
|
||||
class MessageStatusUpdateRequest(BaseModel):
|
||||
chat_id: int
|
||||
message_id: int
|
||||
status: Literal["message_delivered", "message_read"]
|
||||
|
||||
@@ -1,29 +1,60 @@
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.service import ensure_chat_membership
|
||||
from app.messages import repository
|
||||
from app.messages.models import Message
|
||||
from app.messages.spam_guard import enforce_message_spam_policy
|
||||
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
|
||||
from app.messages.schemas import MessageCreateRequest, MessageStatusUpdateRequest, MessageUpdateRequest
|
||||
from app.notifications.service import dispatch_message_notifications
|
||||
|
||||
|
||||
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
|
||||
if payload.client_message_id:
|
||||
existing = await repository.get_message_by_client_message_id(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
client_message_id=payload.client_message_id,
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
|
||||
await enforce_message_spam_policy(user_id=sender_id, chat_id=payload.chat_id, text=payload.text)
|
||||
|
||||
message = await repository.create_message(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
message_type=payload.type,
|
||||
text=payload.text,
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(message)
|
||||
try:
|
||||
message = await repository.create_message(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
message_type=payload.type,
|
||||
text=payload.text,
|
||||
)
|
||||
if payload.client_message_id:
|
||||
await repository.create_message_idempotency_key(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
client_message_id=payload.client_message_id,
|
||||
message_id=message.id,
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(message)
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
if payload.client_message_id:
|
||||
existing = await repository.get_message_by_client_message_id(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
client_message_id=payload.client_message_id,
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
raise
|
||||
try:
|
||||
await dispatch_message_notifications(db, message)
|
||||
except Exception:
|
||||
@@ -73,3 +104,43 @@ async def delete_message(db: AsyncSession, *, message_id: int, user_id: int) ->
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can delete only your own messages")
|
||||
await repository.delete_message(db, message)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def mark_message_status(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: int,
|
||||
payload: MessageStatusUpdateRequest,
|
||||
) -> dict[str, int]:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
message = await repository.get_message_by_id(db, payload.message_id)
|
||||
if not message or message.chat_id != payload.chat_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
|
||||
receipt = await repository.get_message_receipt(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
if not receipt:
|
||||
last_delivered = payload.message_id if payload.status in {"message_delivered", "message_read"} else None
|
||||
last_read = payload.message_id if payload.status == "message_read" else None
|
||||
receipt = await repository.create_message_receipt(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
user_id=user_id,
|
||||
last_delivered_message_id=last_delivered,
|
||||
last_read_message_id=last_read,
|
||||
)
|
||||
else:
|
||||
if payload.status in {"message_delivered", "message_read"}:
|
||||
current = receipt.last_delivered_message_id or 0
|
||||
receipt.last_delivered_message_id = max(current, payload.message_id)
|
||||
if payload.status == "message_read":
|
||||
current_read = receipt.last_read_message_id or 0
|
||||
receipt.last_read_message_id = max(current_read, payload.message_id)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(receipt)
|
||||
return {
|
||||
"chat_id": payload.chat_id,
|
||||
"message_id": payload.message_id,
|
||||
"last_delivered_message_id": receipt.last_delivered_message_id or 0,
|
||||
"last_read_message_id": receipt.last_read_message_id or 0,
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ class SendMessagePayload(BaseModel):
|
||||
type: MessageType = MessageType.TEXT
|
||||
text: str | None = Field(default=None, max_length=4096)
|
||||
temp_id: str | None = None
|
||||
client_message_id: str | None = Field(default=None, min_length=8, max_length=64)
|
||||
|
||||
|
||||
class ChatEventPayload(BaseModel):
|
||||
|
||||
@@ -9,8 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.repository import list_user_chat_ids
|
||||
from app.chats.service import ensure_chat_membership
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead
|
||||
from app.messages.service import create_chat_message
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest
|
||||
from app.messages.service import create_chat_message, mark_message_status
|
||||
from app.realtime.models import ConnectionContext
|
||||
from app.realtime.presence import mark_user_offline, mark_user_online
|
||||
from app.realtime.repository import RedisRealtimeRepository
|
||||
@@ -81,9 +81,19 @@ class RealtimeGateway:
|
||||
message = await create_chat_message(
|
||||
db,
|
||||
sender_id=user_id,
|
||||
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
|
||||
payload=MessageCreateRequest(
|
||||
chat_id=payload.chat_id,
|
||||
type=payload.type,
|
||||
text=payload.text,
|
||||
client_message_id=payload.client_message_id or payload.temp_id,
|
||||
),
|
||||
)
|
||||
await self.publish_message_created(
|
||||
message=message,
|
||||
sender_id=user_id,
|
||||
temp_id=payload.temp_id,
|
||||
client_message_id=payload.client_message_id,
|
||||
)
|
||||
await self.publish_message_created(message=message, sender_id=user_id, temp_id=payload.temp_id)
|
||||
|
||||
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
@@ -99,8 +109,12 @@ class RealtimeGateway:
|
||||
user_id: int,
|
||||
payload: MessageStatusPayload,
|
||||
event: str,
|
||||
) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
) -> dict[str, int]:
|
||||
receipt_state = await mark_message_status(
|
||||
db,
|
||||
user_id=user_id,
|
||||
payload=MessageStatusUpdateRequest(chat_id=payload.chat_id, message_id=payload.message_id, status=event),
|
||||
)
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event=event,
|
||||
@@ -108,8 +122,11 @@ class RealtimeGateway:
|
||||
"chat_id": payload.chat_id,
|
||||
"message_id": payload.message_id,
|
||||
"user_id": user_id,
|
||||
"last_delivered_message_id": receipt_state["last_delivered_message_id"],
|
||||
"last_read_message_id": receipt_state["last_read_message_id"],
|
||||
},
|
||||
)
|
||||
return receipt_state
|
||||
|
||||
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
|
||||
return await list_user_chat_ids(db, user_id=user_id)
|
||||
@@ -139,7 +156,14 @@ class RealtimeGateway:
|
||||
return
|
||||
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
|
||||
|
||||
async def publish_message_created(self, *, message, sender_id: int, temp_id: str | None = None) -> None:
|
||||
async def publish_message_created(
|
||||
self,
|
||||
*,
|
||||
message,
|
||||
sender_id: int,
|
||||
temp_id: str | None = None,
|
||||
client_message_id: str | None = None,
|
||||
) -> None:
|
||||
message_data = MessageRead.model_validate(message).model_dump(mode="json")
|
||||
await self._publish_chat_event(
|
||||
message.chat_id,
|
||||
@@ -148,6 +172,7 @@ class RealtimeGateway:
|
||||
"chat_id": message.chat_id,
|
||||
"message": message_data,
|
||||
"temp_id": temp_id,
|
||||
"client_message_id": client_message_id,
|
||||
"sender_id": sender_id,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user