diff --git a/alembic/versions/0004_reply_forward_pin.py b/alembic/versions/0004_reply_forward_pin.py new file mode 100644 index 0000000..966d641 --- /dev/null +++ b/alembic/versions/0004_reply_forward_pin.py @@ -0,0 +1,64 @@ +"""reply forward pin + +Revision ID: 0004_reply_forward_pin +Revises: 0003_search_indexes +Create Date: 2026-03-08 03:20:00.000000 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "0004_reply_forward_pin" +down_revision: Union[str, Sequence[str], None] = "0003_search_indexes" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("messages", sa.Column("reply_to_message_id", sa.Integer(), nullable=True)) + op.add_column("messages", sa.Column("forwarded_from_message_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_messages_reply_to_message_id_messages"), + "messages", + "messages", + ["reply_to_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_foreign_key( + op.f("fk_messages_forwarded_from_message_id_messages"), + "messages", + "messages", + ["forwarded_from_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index(op.f("ix_messages_reply_to_message_id"), "messages", ["reply_to_message_id"], unique=False) + op.create_index(op.f("ix_messages_forwarded_from_message_id"), "messages", ["forwarded_from_message_id"], unique=False) + + op.add_column("chats", sa.Column("pinned_message_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + op.f("fk_chats_pinned_message_id_messages"), + "chats", + "messages", + ["pinned_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index(op.f("ix_chats_pinned_message_id"), "chats", ["pinned_message_id"], unique=False) + + +def downgrade() -> None: + op.drop_index(op.f("ix_chats_pinned_message_id"), table_name="chats") + op.drop_constraint(op.f("fk_chats_pinned_message_id_messages"), "chats", type_="foreignkey") + op.drop_column("chats", "pinned_message_id") + + op.drop_index(op.f("ix_messages_forwarded_from_message_id"), table_name="messages") + op.drop_index(op.f("ix_messages_reply_to_message_id"), table_name="messages") + op.drop_constraint(op.f("fk_messages_forwarded_from_message_id_messages"), "messages", type_="foreignkey") + op.drop_constraint(op.f("fk_messages_reply_to_message_id_messages"), "messages", type_="foreignkey") + op.drop_column("messages", "forwarded_from_message_id") + op.drop_column("messages", "reply_to_message_id") diff --git a/app/chats/models.py b/app/chats/models.py index 87dc521..1ce60d6 100644 --- a/app/chats/models.py +++ b/app/chats/models.py @@ -30,6 +30,7 @@ class Chat(Base): id: Mapped[int] = mapped_column(primary_key=True, index=True) type: Mapped[ChatType] = mapped_column(SAEnum(ChatType), nullable=False, index=True) title: Mapped[str | None] = mapped_column(String(255), nullable=True) + pinned_message_id: Mapped[int | None] = mapped_column(ForeignKey("messages.id", ondelete="SET NULL"), nullable=True, index=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) members: Mapped[list["ChatMember"]] = relationship(back_populates="chat", cascade="all, delete-orphan") diff --git a/app/chats/router.py b/app/chats/router.py index c84b372..f445745 100644 --- a/app/chats/router.py +++ b/app/chats/router.py @@ -8,6 +8,7 @@ from app.chats.schemas import ( ChatMemberAddRequest, ChatMemberRead, ChatMemberRoleUpdateRequest, + ChatPinMessageRequest, ChatRead, ChatTitleUpdateRequest, ) @@ -17,6 +18,7 @@ from app.chats.service import ( get_chat_for_user, get_chats_for_user, leave_chat_for_user, + pin_chat_message_for_user, remove_chat_member_for_user, update_chat_member_role_for_user, update_chat_title_for_user, @@ -58,6 +60,7 @@ async def get_chat( id=chat.id, type=chat.type, title=chat.title, + pinned_message_id=chat.pinned_message_id, created_at=chat.created_at, members=members, ) @@ -127,3 +130,13 @@ async def leave_chat( current_user: User = Depends(get_current_user), ) -> None: await leave_chat_for_user(db, chat_id=chat_id, user_id=current_user.id) + + +@router.post("/{chat_id}/pin", response_model=ChatRead) +async def pin_chat_message( + chat_id: int, + payload: ChatPinMessageRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> ChatRead: + return await pin_chat_message_for_user(db, chat_id=chat_id, user_id=current_user.id, payload=payload) diff --git a/app/chats/schemas.py b/app/chats/schemas.py index 833849d..e0b2335 100644 --- a/app/chats/schemas.py +++ b/app/chats/schemas.py @@ -11,6 +11,7 @@ class ChatRead(BaseModel): id: int type: ChatType title: str | None = None + pinned_message_id: int | None = None created_at: datetime @@ -43,3 +44,7 @@ class ChatMemberRoleUpdateRequest(BaseModel): class ChatTitleUpdateRequest(BaseModel): title: str = Field(min_length=1, max_length=255) + + +class ChatPinMessageRequest(BaseModel): + message_id: int | None = None diff --git a/app/chats/service.py b/app/chats/service.py index 44a05e0..a56ef30 100644 --- a/app/chats/service.py +++ b/app/chats/service.py @@ -4,7 +4,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.chats import repository from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType -from app.chats.schemas import ChatCreateRequest, ChatTitleUpdateRequest +from app.chats.schemas import ChatCreateRequest, ChatPinMessageRequest, ChatTitleUpdateRequest +from app.messages.repository import get_message_by_id from app.users.repository import get_user_by_id @@ -211,3 +212,28 @@ async def leave_chat_for_user(db: AsyncSession, *, chat_id: int, user_id: int) - ) await repository.delete_chat_member(db, membership) await db.commit() + + +async def pin_chat_message_for_user( + db: AsyncSession, + *, + chat_id: int, + user_id: int, + payload: ChatPinMessageRequest, +) -> Chat: + chat, membership = await _get_chat_and_membership(db, chat_id=chat_id, user_id=user_id) + _ensure_group_or_channel(chat.type) + _ensure_manage_permission(membership.role) + if payload.message_id is None: + chat.pinned_message_id = None + await db.commit() + await db.refresh(chat) + return chat + + message = await get_message_by_id(db, payload.message_id) + if not message or message.chat_id != chat_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found in chat") + chat.pinned_message_id = message.id + await db.commit() + await db.refresh(chat) + return chat diff --git a/app/messages/models.py b/app/messages/models.py index 94882ab..8ab55bb 100644 --- a/app/messages/models.py +++ b/app/messages/models.py @@ -29,6 +29,16 @@ class Message(Base): id: Mapped[int] = mapped_column(primary_key=True, index=True) chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True) sender_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) + reply_to_message_id: Mapped[int | None] = mapped_column( + ForeignKey("messages.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + forwarded_from_message_id: Mapped[int | None] = mapped_column( + ForeignKey("messages.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) type: Mapped[MessageType] = mapped_column(SAEnum(MessageType), nullable=False, default=MessageType.TEXT) text: Mapped[str | None] = mapped_column(Text, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) diff --git a/app/messages/repository.py b/app/messages/repository.py index 9ff70ec..472cb97 100644 --- a/app/messages/repository.py +++ b/app/messages/repository.py @@ -10,10 +10,19 @@ async def create_message( *, chat_id: int, sender_id: int, + reply_to_message_id: int | None, + forwarded_from_message_id: int | None, message_type: MessageType, text: str | None, ) -> Message: - message = Message(chat_id=chat_id, sender_id=sender_id, type=message_type, text=text) + message = Message( + chat_id=chat_id, + sender_id=sender_id, + reply_to_message_id=reply_to_message_id, + forwarded_from_message_id=forwarded_from_message_id, + type=message_type, + text=text, + ) db.add(message) await db.flush() return message diff --git a/app/messages/router.py b/app/messages/router.py index 7ee7e3e..8382c9c 100644 --- a/app/messages/router.py +++ b/app/messages/router.py @@ -3,8 +3,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.auth.service import get_current_user from app.database.session import get_db -from app.messages.schemas import MessageCreateRequest, MessageRead, MessageStatusUpdateRequest, MessageUpdateRequest -from app.messages.service import create_chat_message, delete_message, get_messages, search_messages, update_message +from app.messages.schemas import MessageCreateRequest, MessageForwardRequest, MessageRead, MessageStatusUpdateRequest, MessageUpdateRequest +from app.messages.service import create_chat_message, delete_message, forward_message, get_messages, search_messages, update_message from app.realtime.schemas import MessageStatusPayload from app.realtime.service import realtime_gateway from app.users.models import User @@ -80,3 +80,15 @@ async def update_status( payload=MessageStatusPayload(chat_id=payload.chat_id, message_id=payload.message_id), event=payload.status, ) + + +@router.post("/{message_id}/forward", response_model=MessageRead, status_code=status.HTTP_201_CREATED) +async def forward_message_endpoint( + message_id: int, + payload: MessageForwardRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> MessageRead: + message = await forward_message(db, source_message_id=message_id, sender_id=current_user.id, payload=payload) + await realtime_gateway.publish_message_created(message=message, sender_id=current_user.id) + return message diff --git a/app/messages/schemas.py b/app/messages/schemas.py index 17e370b..bb296fc 100644 --- a/app/messages/schemas.py +++ b/app/messages/schemas.py @@ -12,6 +12,8 @@ class MessageRead(BaseModel): id: int chat_id: int sender_id: int + reply_to_message_id: int | None + forwarded_from_message_id: int | None type: MessageType text: str | None created_at: datetime @@ -23,6 +25,7 @@ class MessageCreateRequest(BaseModel): type: MessageType = MessageType.TEXT text: str | None = Field(default=None, max_length=4096) client_message_id: str | None = Field(default=None, min_length=8, max_length=64) + reply_to_message_id: int | None = None class MessageUpdateRequest(BaseModel): @@ -33,3 +36,7 @@ class MessageStatusUpdateRequest(BaseModel): chat_id: int message_id: int status: Literal["message_delivered", "message_read"] + + +class MessageForwardRequest(BaseModel): + target_chat_id: int diff --git a/app/messages/service.py b/app/messages/service.py index bbd9ebd..0bec0d0 100644 --- a/app/messages/service.py +++ b/app/messages/service.py @@ -6,12 +6,16 @@ from app.chats.service import ensure_chat_membership from app.messages import repository from app.messages.models import Message from app.messages.spam_guard import enforce_message_spam_policy -from app.messages.schemas import MessageCreateRequest, MessageStatusUpdateRequest, MessageUpdateRequest +from app.messages.schemas import MessageCreateRequest, MessageForwardRequest, MessageStatusUpdateRequest, MessageUpdateRequest from app.notifications.service import dispatch_message_notifications async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message: await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id) + if payload.reply_to_message_id is not None: + reply_to = await repository.get_message_by_id(db, payload.reply_to_message_id) + if not reply_to or reply_to.chat_id != payload.chat_id: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid reply target") if payload.client_message_id: existing = await repository.get_message_by_client_message_id( db, @@ -30,6 +34,8 @@ async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: Mess db, chat_id=payload.chat_id, sender_id=sender_id, + reply_to_message_id=payload.reply_to_message_id, + forwarded_from_message_id=None, message_type=payload.type, text=payload.text, ) @@ -167,3 +173,29 @@ async def mark_message_status( "last_delivered_message_id": receipt.last_delivered_message_id or 0, "last_read_message_id": receipt.last_read_message_id or 0, } + + +async def forward_message( + db: AsyncSession, + *, + source_message_id: int, + sender_id: int, + payload: MessageForwardRequest, +) -> Message: + source = await repository.get_message_by_id(db, source_message_id) + if not source: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Source message not found") + await ensure_chat_membership(db, chat_id=source.chat_id, user_id=sender_id) + await ensure_chat_membership(db, chat_id=payload.target_chat_id, user_id=sender_id) + forwarded = await repository.create_message( + db, + chat_id=payload.target_chat_id, + sender_id=sender_id, + reply_to_message_id=None, + forwarded_from_message_id=source.id, + message_type=source.type, + text=source.text, + ) + await db.commit() + await db.refresh(forwarded) + return forwarded diff --git a/app/messages/spam_guard.py b/app/messages/spam_guard.py index ef91bb5..e44ec0e 100644 --- a/app/messages/spam_guard.py +++ b/app/messages/spam_guard.py @@ -1,16 +1,9 @@ -import hashlib - from fastapi import HTTPException, status from redis.exceptions import RedisError from app.config.settings import settings from app.utils.redis_client import get_redis_client - -def _hash_text(text: str) -> str: - return hashlib.sha256(text.encode("utf-8")).hexdigest() - - async def enforce_message_spam_policy(*, user_id: int, chat_id: int, text: str | None) -> None: redis = get_redis_client() rate_key = f"spam:msg_rate:{user_id}:{chat_id}" @@ -23,15 +16,5 @@ async def enforce_message_spam_policy(*, user_id: int, chat_id: int, text: str | status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Message rate limit exceeded for this chat.", ) - - normalized = (text or "").strip() - if normalized: - dup_key = f"spam:dup:{user_id}:{chat_id}:{_hash_text(normalized)}" - if await redis.exists(dup_key): - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="Duplicate message cooldown is active.", - ) - await redis.set(dup_key, "1", ex=settings.duplicate_message_cooldown_seconds) except RedisError: return diff --git a/app/realtime/schemas.py b/app/realtime/schemas.py index d4f1790..e865830 100644 --- a/app/realtime/schemas.py +++ b/app/realtime/schemas.py @@ -25,6 +25,7 @@ class SendMessagePayload(BaseModel): text: str | None = Field(default=None, max_length=4096) temp_id: str | None = None client_message_id: str | None = Field(default=None, min_length=8, max_length=64) + reply_to_message_id: int | None = None class ChatEventPayload(BaseModel): diff --git a/app/realtime/service.py b/app/realtime/service.py index ce3ed8a..9ce0654 100644 --- a/app/realtime/service.py +++ b/app/realtime/service.py @@ -86,6 +86,7 @@ class RealtimeGateway: type=payload.type, text=payload.text, client_message_id=payload.client_message_id or payload.temp_id, + reply_to_message_id=payload.reply_to_message_id, ), ) await self.publish_message_created( diff --git a/web/src/api/chats.ts b/web/src/api/chats.ts index 75b7d8f..776f00c 100644 --- a/web/src/api/chats.ts +++ b/web/src/api/chats.ts @@ -51,13 +51,15 @@ export async function sendMessageWithClientId( chatId: number, text: string, type: MessageType, - clientMessageId: string + clientMessageId: string, + replyToMessageId?: number ): Promise { const { data } = await http.post("/messages", { chat_id: chatId, text, type, - client_message_id: clientMessageId + client_message_id: clientMessageId, + reply_to_message_id: replyToMessageId }); return data; } @@ -116,3 +118,17 @@ export async function updateMessageStatus( status }); } + +export async function forwardMessage(messageId: number, targetChatId: number): Promise { + const { data } = await http.post(`/messages/${messageId}/forward`, { + target_chat_id: targetChatId + }); + return data; +} + +export async function pinMessage(chatId: number, messageId: number | null): Promise { + const { data } = await http.post(`/chats/${chatId}/pin`, { + message_id: messageId + }); + return data; +} diff --git a/web/src/chat/types.ts b/web/src/chat/types.ts index 5e4e829..fa25de0 100644 --- a/web/src/chat/types.ts +++ b/web/src/chat/types.ts @@ -6,6 +6,7 @@ export interface Chat { id: number; type: ChatType; title: string | null; + pinned_message_id?: number | null; created_at: string; } @@ -15,6 +16,8 @@ export interface Message { sender_id: number; type: MessageType; text: string | null; + reply_to_message_id?: number | null; + forwarded_from_message_id?: number | null; created_at: string; updated_at: string; client_message_id?: string; diff --git a/web/src/components/MessageComposer.tsx b/web/src/components/MessageComposer.tsx index a7c724a..0c61b50 100644 --- a/web/src/components/MessageComposer.tsx +++ b/web/src/components/MessageComposer.tsx @@ -10,6 +10,8 @@ export function MessageComposer() { const addOptimisticMessage = useChatStore((s) => s.addOptimisticMessage); const confirmMessageByClientId = useChatStore((s) => s.confirmMessageByClientId); const removeOptimisticMessage = useChatStore((s) => s.removeOptimisticMessage); + const replyToByChat = useChatStore((s) => s.replyToByChat); + const setReplyToMessage = useChatStore((s) => s.setReplyToMessage); const accessToken = useAuthStore((s) => s.accessToken); const [text, setText] = useState(""); const wsRef = useRef(null); @@ -56,11 +58,13 @@ export function MessageComposer() { } const clientMessageId = makeClientMessageId(); const textValue = text.trim(); + const replyToMessageId = activeChatId ? (replyToByChat[activeChatId]?.id ?? undefined) : undefined; addOptimisticMessage({ chatId: activeChatId, senderId: me.id, type: "text", text: textValue, clientMessageId }); try { - const message = await sendMessageWithClientId(activeChatId, textValue, "text", clientMessageId); + const message = await sendMessageWithClientId(activeChatId, textValue, "text", clientMessageId, replyToMessageId); confirmMessageByClientId(activeChatId, clientMessageId, message); setText(""); + setReplyToMessage(activeChatId, null); const ws = getWs(); ws?.send(JSON.stringify({ event: "typing_stop", payload: { chat_id: activeChatId } })); } catch { @@ -77,6 +81,7 @@ export function MessageComposer() { setUploadProgress(0); setUploadError(null); const clientMessageId = makeClientMessageId(); + const replyToMessageId = activeChatId ? (replyToByChat[activeChatId]?.id ?? undefined) : undefined; try { const upload = await requestUploadUrl(file); await uploadToPresignedUrl(upload.upload_url, upload.required_headers, file, setUploadProgress); @@ -87,9 +92,10 @@ export function MessageComposer() { text: upload.file_url, clientMessageId }); - const message = await sendMessageWithClientId(activeChatId, upload.file_url, messageType, clientMessageId); + const message = await sendMessageWithClientId(activeChatId, upload.file_url, messageType, clientMessageId, replyToMessageId); await attachFile(message.id, upload.file_url, file.type || "application/octet-stream", file.size); confirmMessageByClientId(activeChatId, clientMessageId, message); + setReplyToMessage(activeChatId, null); } catch { removeOptimisticMessage(activeChatId, clientMessageId); setUploadError("Upload failed. Please try again."); @@ -239,6 +245,17 @@ export function MessageComposer() { return (
+ {activeChatId && replyToByChat[activeChatId] ? ( +
+
+

Replying

+

{replyToByChat[activeChatId]?.text || "[media]"}

+
+ +
+ ) : null}