feat(reactions): add message reactions API and web quick reactions
This commit is contained in:
@@ -2,7 +2,7 @@ from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||
from app.chats.models import Chat, ChatMember
|
||||
from app.email.models import EmailLog
|
||||
from app.media.models import Attachment
|
||||
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReceipt
|
||||
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReaction, MessageReceipt
|
||||
from app.notifications.models import NotificationLog
|
||||
from app.users.models import User
|
||||
|
||||
@@ -14,6 +14,7 @@ __all__ = [
|
||||
"EmailVerificationToken",
|
||||
"Message",
|
||||
"MessageIdempotencyKey",
|
||||
"MessageReaction",
|
||||
"MessageReceipt",
|
||||
"NotificationLog",
|
||||
"PasswordResetToken",
|
||||
|
||||
@@ -93,3 +93,16 @@ class MessageHidden(Base):
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
|
||||
class MessageReaction(Base):
|
||||
__tablename__ = "message_reactions"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("message_id", "user_id", name="uq_message_reactions_message_user"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
emoji: Mapped[str] = mapped_column(String(16), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.models import ChatMember
|
||||
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReceipt, MessageType
|
||||
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReaction, MessageReceipt, MessageType
|
||||
|
||||
|
||||
async def create_message(
|
||||
@@ -178,3 +178,44 @@ async def create_message_receipt(
|
||||
db.add(receipt)
|
||||
await db.flush()
|
||||
return receipt
|
||||
|
||||
|
||||
async def get_message_reaction(db: AsyncSession, *, message_id: int, user_id: int) -> MessageReaction | None:
|
||||
result = await db.execute(
|
||||
select(MessageReaction)
|
||||
.where(MessageReaction.message_id == message_id, MessageReaction.user_id == user_id)
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def list_message_reactions(db: AsyncSession, *, message_id: int) -> list[tuple[str, int]]:
|
||||
result = await db.execute(
|
||||
select(MessageReaction.emoji, func.count(MessageReaction.id))
|
||||
.where(MessageReaction.message_id == message_id)
|
||||
.group_by(MessageReaction.emoji)
|
||||
.order_by(func.count(MessageReaction.id).desc(), MessageReaction.emoji.asc())
|
||||
)
|
||||
return [(str(emoji), int(count)) for emoji, count in result.all()]
|
||||
|
||||
|
||||
async def upsert_message_reaction(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
emoji: str,
|
||||
) -> tuple[MessageReaction | None, str]:
|
||||
existing = await get_message_reaction(db, message_id=message_id, user_id=user_id)
|
||||
if existing and existing.emoji == emoji:
|
||||
await db.delete(existing)
|
||||
await db.flush()
|
||||
return None, "removed"
|
||||
if existing:
|
||||
existing.emoji = emoji
|
||||
await db.flush()
|
||||
return existing, "updated"
|
||||
reaction = MessageReaction(message_id=message_id, user_id=user_id, emoji=emoji)
|
||||
db.add(reaction)
|
||||
await db.flush()
|
||||
return reaction, "added"
|
||||
|
||||
@@ -3,14 +3,24 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.database.session import get_db
|
||||
from app.messages.schemas import MessageCreateRequest, MessageForwardRequest, MessageRead, MessageStatusUpdateRequest, MessageUpdateRequest
|
||||
from app.messages.schemas import (
|
||||
MessageCreateRequest,
|
||||
MessageForwardRequest,
|
||||
MessageReactionRead,
|
||||
MessageReactionToggleRequest,
|
||||
MessageRead,
|
||||
MessageStatusUpdateRequest,
|
||||
MessageUpdateRequest,
|
||||
)
|
||||
from app.messages.service import (
|
||||
create_chat_message,
|
||||
delete_message,
|
||||
delete_message_for_all,
|
||||
forward_message,
|
||||
get_messages,
|
||||
list_message_reactions,
|
||||
search_messages,
|
||||
toggle_message_reaction,
|
||||
update_message,
|
||||
)
|
||||
from app.realtime.schemas import MessageStatusPayload
|
||||
@@ -104,3 +114,22 @@ async def forward_message_endpoint(
|
||||
message = await forward_message(db, source_message_id=message_id, sender_id=current_user.id, payload=payload)
|
||||
await realtime_gateway.publish_message_created(message=message, sender_id=current_user.id)
|
||||
return message
|
||||
|
||||
|
||||
@router.get("/{message_id}/reactions", response_model=list[MessageReactionRead])
|
||||
async def list_reactions_endpoint(
|
||||
message_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[MessageReactionRead]:
|
||||
return await list_message_reactions(db, message_id=message_id, user_id=current_user.id)
|
||||
|
||||
|
||||
@router.post("/{message_id}/reactions/toggle", response_model=list[MessageReactionRead])
|
||||
async def toggle_reaction_endpoint(
|
||||
message_id: int,
|
||||
payload: MessageReactionToggleRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[MessageReactionRead]:
|
||||
return await toggle_message_reaction(db, message_id=message_id, user_id=current_user.id, payload=payload)
|
||||
|
||||
@@ -40,3 +40,17 @@ class MessageStatusUpdateRequest(BaseModel):
|
||||
|
||||
class MessageForwardRequest(BaseModel):
|
||||
target_chat_id: int
|
||||
|
||||
|
||||
class MessageForwardBulkRequest(BaseModel):
|
||||
target_chat_ids: list[int] = Field(min_length=1, max_length=20)
|
||||
|
||||
|
||||
class MessageReactionToggleRequest(BaseModel):
|
||||
emoji: str = Field(min_length=1, max_length=16)
|
||||
|
||||
|
||||
class MessageReactionRead(BaseModel):
|
||||
emoji: str
|
||||
count: int
|
||||
reacted: bool = False
|
||||
|
||||
@@ -8,7 +8,14 @@ from app.chats.service import ensure_chat_membership
|
||||
from app.messages import repository
|
||||
from app.messages.models import Message
|
||||
from app.messages.spam_guard import enforce_message_spam_policy
|
||||
from app.messages.schemas import MessageCreateRequest, MessageForwardRequest, MessageStatusUpdateRequest, MessageUpdateRequest
|
||||
from app.messages.schemas import (
|
||||
MessageCreateRequest,
|
||||
MessageForwardRequest,
|
||||
MessageReactionRead,
|
||||
MessageReactionToggleRequest,
|
||||
MessageStatusUpdateRequest,
|
||||
MessageUpdateRequest,
|
||||
)
|
||||
from app.notifications.service import dispatch_message_notifications
|
||||
from app.users.repository import has_block_relation_between_users
|
||||
from app.users.service import get_user_by_id
|
||||
@@ -267,3 +274,43 @@ async def forward_message(
|
||||
await db.commit()
|
||||
await db.refresh(forwarded)
|
||||
return forwarded
|
||||
|
||||
|
||||
async def list_message_reactions(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
) -> list[MessageReactionRead]:
|
||||
message = await repository.get_message_by_id(db, message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||
counts = await repository.list_message_reactions(db, message_id=message_id)
|
||||
mine = await repository.get_message_reaction(db, message_id=message_id, user_id=user_id)
|
||||
mine_emoji = mine.emoji if mine else None
|
||||
return [
|
||||
MessageReactionRead(emoji=emoji, count=count, reacted=(emoji == mine_emoji))
|
||||
for emoji, count in counts
|
||||
]
|
||||
|
||||
|
||||
async def toggle_message_reaction(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
payload: MessageReactionToggleRequest,
|
||||
) -> list[MessageReactionRead]:
|
||||
message = await repository.get_message_by_id(db, message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||
await repository.upsert_message_reaction(
|
||||
db,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
emoji=payload.emoji.strip(),
|
||||
)
|
||||
await db.commit()
|
||||
return await list_message_reactions(db, message_id=message_id, user_id=user_id)
|
||||
|
||||
Reference in New Issue
Block a user