409 lines
17 KiB
Python
409 lines
17 KiB
Python
import json
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.chats import repository as chats_repository
|
|
from app.chats.models import ChatMemberRole, ChatType
|
|
from app.chats.service import ensure_chat_membership
|
|
from app.media import repository as media_repository
|
|
from app.messages import repository
|
|
from app.messages.models import Message
|
|
from app.messages.spam_guard import enforce_message_spam_policy
|
|
from app.messages.schemas import (
|
|
MessageCreateRequest,
|
|
MessageForwardBulkRequest,
|
|
MessageForwardRequest,
|
|
MessageReactionRead,
|
|
MessageReactionToggleRequest,
|
|
MessageStatusUpdateRequest,
|
|
MessageUpdateRequest,
|
|
)
|
|
from app.notifications.service import dispatch_message_notifications
|
|
from app.users.repository import has_block_relation_between_users
|
|
from app.users.service import get_user_by_id
|
|
|
|
|
|
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
    """Validate and store a new chat message, then dispatch notifications.

    Validation order: chat membership, channel posting rights, private-chat
    block and privacy settings, reply target, idempotent replay lookup,
    non-empty text, spam policy.

    Args:
        db: Async session used for every lookup and for the insert.
        sender_id: ID of the user sending the message.
        payload: Message data: target chat, type, text, optional reply target
            and optional ``client_message_id`` used as an idempotency key.

    Returns:
        The created message, or the already stored message when the same
        ``client_message_id`` was used before by this sender in this chat.

    Raises:
        HTTPException: 404 if the chat does not exist; 403 if the sender is
            not a member, is a plain member posting in a channel, or is
            stopped by block/privacy settings of a private chat; 422 for an
            invalid reply target or an empty text message.
        IntegrityError: If the insert violates a constraint and no message
            with the same ``client_message_id`` can be found afterwards.

    ``ensure_chat_membership`` and ``enforce_message_spam_policy`` may raise
    on their own; their behavior is not visible from this module.
    """
    await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
    chat = await chats_repository.get_chat_by_id(db, payload.chat_id)
    if not chat:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
    membership = await chats_repository.get_chat_member(db, chat_id=payload.chat_id, user_id=sender_id)
    if not membership:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
    if chat.type == ChatType.CHANNEL and membership.role == ChatMemberRole.MEMBER:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only admins can post in channels")

    if chat.type == ChatType.PRIVATE:
        counterpart_id = await chats_repository.get_private_counterpart_user_id(
            db, chat_id=payload.chat_id, user_id=sender_id
        )
        if counterpart_id and await has_block_relation_between_users(
            db, user_a_id=sender_id, user_b_id=counterpart_id
        ):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="Cannot send message due to block settings"
            )
        if counterpart_id:
            counterpart = await get_user_by_id(db, counterpart_id)
            if counterpart and not counterpart.allow_private_messages:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail="User does not accept private messages"
                )

    if payload.reply_to_message_id is not None:
        reply_to = await repository.get_message_by_id(db, payload.reply_to_message_id)
        # A reply must point at an existing message of the same chat.
        if not reply_to or reply_to.chat_id != payload.chat_id:
            raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid reply target")

    if payload.client_message_id:
        # Idempotent replay: the client already sent this message once.
        existing = await repository.get_message_by_client_message_id(
            db,
            chat_id=payload.chat_id,
            sender_id=sender_id,
            client_message_id=payload.client_message_id,
        )
        if existing:
            return existing

    if payload.type.value == "text" and not (payload.text and payload.text.strip()):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
    await enforce_message_spam_policy(user_id=sender_id, chat_id=payload.chat_id, text=payload.text)

    try:
        message = await repository.create_message(
            db,
            chat_id=payload.chat_id,
            sender_id=sender_id,
            reply_to_message_id=payload.reply_to_message_id,
            forwarded_from_message_id=None,
            message_type=payload.type,
            text=payload.text,
        )
        if payload.client_message_id:
            await repository.create_message_idempotency_key(
                db,
                chat_id=payload.chat_id,
                sender_id=sender_id,
                client_message_id=payload.client_message_id,
                message_id=message.id,
            )
        await db.commit()
        await db.refresh(message)
    except IntegrityError:
        await db.rollback()
        if payload.client_message_id:
            # The insert may have lost a race against a request carrying the
            # same idempotency key; if so, return that request's message.
            existing = await repository.get_message_by_client_message_id(
                db,
                chat_id=payload.chat_id,
                sender_id=sender_id,
                client_message_id=payload.client_message_id,
            )
            if existing:
                return existing
        raise

    try:
        await dispatch_message_notifications(db, message)
    except Exception:
        # Notifications should not block message delivery, but the failure is
        # logged so it does not disappear silently. Only plain request values
        # are logged to avoid touching ORM state after a failed dispatch.
        import logging

        logging.getLogger(__name__).exception(
            "Failed to dispatch notifications for a message in chat %s from user %s",
            payload.chat_id,
            sender_id,
        )
    return message
|
|
|
|
|
|
def _parse_waveform(raw: str | bytes | None) -> list[int]:
    """Decode a JSON-encoded waveform into at most 256 levels clamped to 0..31.

    Returns an empty list when *raw* is empty, cannot be parsed as JSON or
    does not decode to a list. Non-numeric items and non-finite numbers are
    skipped.
    """
    if not raw:
        return []
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError, RecursionError):
        return []
    if not isinstance(parsed, list):
        return []

    levels: list[int] = []
    for item in parsed[:256]:
        if not isinstance(item, (int, float)):
            continue
        try:
            level = int(item)
        except (ValueError, OverflowError):
            # json.loads accepts NaN/Infinity, which int() cannot convert.
            continue
        levels.append(max(0, min(31, level)))
    return levels


async def get_messages(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Message]:
    """Return a page of chat messages decorated with waveform and delivery data.

    Args:
        db: Async session used for all queries.
        chat_id: Chat whose messages are listed.
        user_id: Requesting user; passed to ``ensure_chat_membership``.
        limit: Requested page size, clamped to the range 1..100.
        before_id: Optional cursor forwarded to the repository.

    Returns:
        The messages returned by the repository. Two attributes are set on
        the message objects where applicable:

        * ``attachment_waveform`` - parsed waveform of an attachment of the
          message, when one with usable ``waveform_data`` exists.
        * ``delivery_status`` - ``"read"``, ``"delivered"`` or ``"sent"``,
          only on messages sent by ``user_id`` and only when at least one
          other user has a receipt in this chat.
    """
    await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)
    safe_limit = max(1, min(limit, 100))
    messages = await repository.list_chat_messages(db, chat_id, user_id=user_id, limit=safe_limit, before_id=before_id)
    if not messages:
        return messages

    message_ids = [message.id for message in messages]
    attachments = await media_repository.list_attachments_by_message_ids(db, message_ids=message_ids)
    waveform_by_message_id: dict[int, list[int]] = {}
    for attachment in attachments:
        levels = _parse_waveform(attachment.waveform_data)
        if levels:
            waveform_by_message_id[attachment.message_id] = levels

    # Waveforms are attached before the receipt check so they are not lost
    # when nobody else has a receipt yet (that case returns early below).
    for message in messages:
        waveform = waveform_by_message_id.get(message.id)
        if waveform:
            message.attachment_waveform = waveform

    receipts = await repository.list_chat_receipts(db, chat_id=chat_id)
    other_receipts = [receipt for receipt in receipts if receipt.user_id != user_id]
    if not other_receipts:
        return messages

    # A message counts as read/delivered as soon as any other user's receipt
    # has reached it, so only the furthest receipt of each kind matters.
    max_read_id = max((receipt.last_read_message_id or 0) for receipt in other_receipts)
    max_delivered_id = max((receipt.last_delivered_message_id or 0) for receipt in other_receipts)

    for message in messages:
        if message.sender_id != user_id:
            continue
        if max_read_id >= message.id:
            message.delivery_status = "read"
        elif max_delivered_id >= message.id:
            message.delivery_status = "delivered"
        else:
            message.delivery_status = "sent"
    return messages
|
|
|
|
|
|
async def search_messages(
    db: AsyncSession,
    *,
    user_id: int,
    query: str,
    chat_id: int | None = None,
    limit: int = 50,
) -> list[Message]:
    """Search messages for a user, optionally restricted to one chat.

    Args:
        db: Async session handed to the repository.
        user_id: User on whose behalf the search runs.
        query: Search text; surrounding whitespace is ignored.
        chat_id: When given, membership in this chat is checked via
            ``ensure_chat_membership`` and the id is passed to the repository.
        limit: Requested result count, clamped to the range 1..100.

    Returns:
        An empty list when the trimmed query is shorter than two characters,
        otherwise whatever ``repository.search_messages`` returns.
    """
    needle = query.strip()
    if len(needle) < 2:
        # Too short to search; no database access happens in this case.
        return []

    capped_limit = min(max(limit, 1), 100)
    if chat_id is not None:
        await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)

    found = await repository.search_messages(
        db,
        user_id=user_id,
        query=needle,
        chat_id=chat_id,
        limit=capped_limit,
    )
    return found
|
|
|
|
|
|
async def update_message(
    db: AsyncSession,
    *,
    message_id: int,
    user_id: int,
    payload: MessageUpdateRequest,
) -> Message:
    """Replace the text of a message written by the requesting user.

    Args:
        db: Async session used to load and persist the message.
        message_id: ID of the message to edit.
        user_id: Requesting user; must be the sender of the message.
        payload: Update request whose ``text`` becomes the new message text.

    Returns:
        The message, refreshed after the commit.

    Raises:
        HTTPException: 404 if the message does not exist; 403 if the
            requesting user is not its sender.
    """
    target = await repository.get_message_by_id(db, message_id)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")

    # Membership is verified before authorship, as in the other handlers.
    await ensure_chat_membership(db, chat_id=target.chat_id, user_id=user_id)
    is_author = target.sender_id == user_id
    if not is_author:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can edit only your own messages")

    target.text = payload.text
    await db.commit()
    await db.refresh(target)
    return target
|
|
|
|
|
|
async def delete_message(db: AsyncSession, *, message_id: int, user_id: int) -> None:
    """Hide a message for the requesting user only ("delete for me").

    The message row itself is kept; a hidden-message record is created for
    ``user_id`` unless one already exists.

    Args:
        db: Async session used for lookups and for the write.
        message_id: ID of the message to hide.
        user_id: User for whom the message is hidden.

    Raises:
        HTTPException: 404 if the message or its chat does not exist; 403 if
            the user has no membership record in the chat; 422 if the chat is
            a channel that is not flagged ``is_saved``.
    """
    message = await repository.get_message_by_id(db, message_id)
    if not message:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
    await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
    chat = await chats_repository.get_chat_by_id(db, message.chat_id)
    if not chat:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
    membership = await chats_repository.get_chat_member(db, chat_id=message.chat_id, user_id=user_id)
    if not membership:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
    if chat.type == ChatType.CHANNEL and not chat.is_saved:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Channel messages can only be deleted for everyone",
        )

    # Telegram-like default: delete only for current user.
    already_hidden = await repository.get_hidden_message(db, message_id=message.id, user_id=user_id)
    try:
        if not already_hidden:
            await repository.hide_message_for_user(db, message_id=message.id, user_id=user_id)
        # The commit stays inside the try: a constraint violation may only
        # surface when the pending row is flushed, i.e. at commit time.
        await db.commit()
    except IntegrityError:
        # Presumably the hidden-message row already exists (for example from
        # a concurrent request); the outcome is the same, so treat as success.
        await db.rollback()
|
|
|
|
|
|
async def delete_message_for_all(db: AsyncSession, *, message_id: int, user_id: int) -> None:
    """Delete a message for every participant of its chat.

    When the chat is flagged ``is_saved`` the call is delegated to
    ``delete_message`` instead. Otherwise deletion is allowed when the chat is
    private, when the requester sent the message, or when the requester is an
    owner/admin of a group or channel.

    Args:
        db: Async session used for lookups and for the delete.
        message_id: ID of the message to delete.
        user_id: Requesting user.

    Raises:
        HTTPException: 404 if the message or its chat does not exist; 403 if
            the user has no membership record or lacks permission.
    """
    target = await repository.get_message_by_id(db, message_id)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")

    await ensure_chat_membership(db, chat_id=target.chat_id, user_id=user_id)

    chat = await chats_repository.get_chat_by_id(db, target.chat_id)
    if not chat:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")

    member = await chats_repository.get_chat_member(db, chat_id=target.chat_id, user_id=user_id)
    if not member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")

    if chat.is_saved:
        await delete_message(db, message_id=message_id, user_id=user_id)
        return

    # Evaluated left to right, so the role is only inspected when the chat is
    # not private and the requester is not the sender.
    allowed = (
        chat.type == ChatType.PRIVATE
        or target.sender_id == user_id
        or (
            chat.type in {ChatType.GROUP, ChatType.CHANNEL}
            and member.role in {ChatMemberRole.OWNER, ChatMemberRole.ADMIN}
        )
    )
    if not allowed:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions for delete-for-all")

    await repository.delete_message(db, target)
    await db.commit()
|
|
|
|
|
|
async def mark_message_status(
    db: AsyncSession,
    *,
    user_id: int,
    payload: MessageStatusUpdateRequest,
) -> dict[str, int]:
    """Advance the user's delivered/read markers for a chat.

    A ``"message_read"`` status moves both markers, ``"message_delivered"``
    only the delivered one. Existing markers never move backwards.

    Args:
        db: Async session used for lookups and for the write.
        user_id: User whose receipt is created or updated.
        payload: Carries ``chat_id``, ``message_id`` and ``status``.

    Returns:
        A dict with ``chat_id``, ``message_id`` and the resulting
        ``last_delivered_message_id`` / ``last_read_message_id`` (0 when unset).

    Raises:
        HTTPException: 404 if the message does not exist or belongs to a
            different chat than ``payload.chat_id``.
    """
    await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)

    target = await repository.get_message_by_id(db, payload.message_id)
    if not target or target.chat_id != payload.chat_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")

    receipt = await repository.get_message_receipt(db, chat_id=payload.chat_id, user_id=user_id)

    marks_delivered = payload.status in {"message_delivered", "message_read"}
    marks_read = payload.status == "message_read"

    if receipt:
        # Keep the furthest position seen so far.
        if marks_delivered:
            receipt.last_delivered_message_id = max(receipt.last_delivered_message_id or 0, payload.message_id)
        if marks_read:
            receipt.last_read_message_id = max(receipt.last_read_message_id or 0, payload.message_id)
    else:
        receipt = await repository.create_message_receipt(
            db,
            chat_id=payload.chat_id,
            user_id=user_id,
            last_delivered_message_id=payload.message_id if marks_delivered else None,
            last_read_message_id=payload.message_id if marks_read else None,
        )

    await db.commit()
    await db.refresh(receipt)

    result = {
        "chat_id": payload.chat_id,
        "message_id": payload.message_id,
        "last_delivered_message_id": receipt.last_delivered_message_id or 0,
        "last_read_message_id": receipt.last_read_message_id or 0,
    }
    return result
|
|
|
|
|
|
async def forward_message(
    db: AsyncSession,
    *,
    source_message_id: int,
    sender_id: int,
    payload: MessageForwardRequest,
) -> Message:
    """Forward an existing message into another chat.

    The forwarded copy carries the type and text of the source message and
    references it through ``forwarded_from_message_id``.

    Args:
        db: Async session used for lookups and for the insert.
        source_message_id: ID of the message being forwarded.
        sender_id: User performing the forward; must be a member of both the
            source chat and the target chat.
        payload: Carries ``target_chat_id``.

    Returns:
        The newly created message in the target chat.

    Raises:
        HTTPException: 404 if the source message or target chat does not
            exist; 403 if the sender has no membership record in the target
            chat, is a plain member of a target channel, or is stopped by the
            block/privacy settings of a private target chat.
    """
    source = await repository.get_message_by_id(db, source_message_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Source message not found")
    await ensure_chat_membership(db, chat_id=source.chat_id, user_id=sender_id)
    await ensure_chat_membership(db, chat_id=payload.target_chat_id, user_id=sender_id)

    target_chat = await chats_repository.get_chat_by_id(db, payload.target_chat_id)
    if not target_chat:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
    target_membership = await chats_repository.get_chat_member(db, chat_id=payload.target_chat_id, user_id=sender_id)
    if not target_membership:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
    if target_chat.type == ChatType.CHANNEL and target_membership.role == ChatMemberRole.MEMBER:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only admins can post in channels")

    # A forward delivers a message just like a direct send, so private chats
    # get the same block/privacy checks as create_chat_message applies.
    if target_chat.type == ChatType.PRIVATE:
        counterpart_id = await chats_repository.get_private_counterpart_user_id(
            db, chat_id=payload.target_chat_id, user_id=sender_id
        )
        if counterpart_id:
            if await has_block_relation_between_users(db, user_a_id=sender_id, user_b_id=counterpart_id):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail="Cannot send message due to block settings"
                )
            counterpart = await get_user_by_id(db, counterpart_id)
            if counterpart and not counterpart.allow_private_messages:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail="User does not accept private messages"
                )

    forwarded = await repository.create_message(
        db,
        chat_id=payload.target_chat_id,
        sender_id=sender_id,
        reply_to_message_id=None,
        forwarded_from_message_id=source.id,
        message_type=source.type,
        text=source.text,
    )
    await db.commit()
    await db.refresh(forwarded)
    return forwarded
|
|
|
|
|
|
async def forward_message_bulk(
    db: AsyncSession,
    *,
    source_message_id: int,
    sender_id: int,
    payload: MessageForwardBulkRequest,
) -> list[Message]:
    """Forward one message into several chats within a single commit.

    Duplicate target ids are collapsed while keeping their first-seen order.
    A target is skipped silently when the chat does not exist, the sender has
    no membership record, the sender is a plain member of a channel, or the
    block/privacy settings of a private chat forbid delivery.

    Args:
        db: Async session used for lookups and for the inserts.
        source_message_id: ID of the message being forwarded.
        sender_id: User performing the forward.
        payload: Carries ``target_chat_ids``.

    Returns:
        The created messages, one per target that was not skipped (possibly
        an empty list).

    Raises:
        HTTPException: 404 if the source message does not exist; 422 if no
            target chat ids were supplied.

    NOTE(review): ``ensure_chat_membership`` is called per target and
    presumably raises for non-members, which would abort the whole request
    before anything is committed - confirm against its implementation.
    """
    source = await repository.get_message_by_id(db, source_message_id)
    if not source:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Source message not found")
    await ensure_chat_membership(db, chat_id=source.chat_id, user_id=sender_id)

    # dict.fromkeys removes duplicates and preserves the original order.
    target_chat_ids = list(dict.fromkeys(payload.target_chat_ids))
    if not target_chat_ids:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No target chats")

    forwarded_messages: list[Message] = []
    for target_chat_id in target_chat_ids:
        await ensure_chat_membership(db, chat_id=target_chat_id, user_id=sender_id)
        target_chat = await chats_repository.get_chat_by_id(db, target_chat_id)
        if not target_chat:
            continue
        target_membership = await chats_repository.get_chat_member(db, chat_id=target_chat_id, user_id=sender_id)
        if not target_membership:
            continue
        if target_chat.type == ChatType.CHANNEL and target_membership.role == ChatMemberRole.MEMBER:
            continue

        # A forward delivers a message just like a direct send, so private
        # chats get the same block/privacy checks as create_chat_message;
        # forbidden targets are skipped like the other non-permitted ones.
        if target_chat.type == ChatType.PRIVATE:
            counterpart_id = await chats_repository.get_private_counterpart_user_id(
                db, chat_id=target_chat_id, user_id=sender_id
            )
            if counterpart_id:
                if await has_block_relation_between_users(db, user_a_id=sender_id, user_b_id=counterpart_id):
                    continue
                counterpart = await get_user_by_id(db, counterpart_id)
                if counterpart and not counterpart.allow_private_messages:
                    continue

        forwarded = await repository.create_message(
            db,
            chat_id=target_chat_id,
            sender_id=sender_id,
            reply_to_message_id=None,
            forwarded_from_message_id=source.id,
            message_type=source.type,
            text=source.text,
        )
        forwarded_messages.append(forwarded)

    await db.commit()
    for message in forwarded_messages:
        await db.refresh(message)
    return forwarded_messages
|
|
|
|
|
|
async def list_message_reactions(
    db: AsyncSession,
    *,
    message_id: int,
    user_id: int,
) -> list[MessageReactionRead]:
    """List reaction counts for a message, flagging the requester's own emoji.

    Args:
        db: Async session used for the lookups.
        message_id: ID of the message whose reactions are listed.
        user_id: Requesting user; must pass ``ensure_chat_membership`` for the
            message's chat.

    Returns:
        One ``MessageReactionRead`` per ``(emoji, count)`` pair returned by
        the repository; ``reacted`` is true for the emoji the user has set.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    target = await repository.get_message_by_id(db, message_id)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
    await ensure_chat_membership(db, chat_id=target.chat_id, user_id=user_id)

    emoji_counts = await repository.list_message_reactions(db, message_id=message_id)
    own_reaction = await repository.get_message_reaction(db, message_id=message_id, user_id=user_id)
    own_emoji = own_reaction.emoji if own_reaction else None

    reactions: list[MessageReactionRead] = []
    for emoji, count in emoji_counts:
        reactions.append(MessageReactionRead(emoji=emoji, count=count, reacted=(emoji == own_emoji)))
    return reactions
|
|
|
|
|
|
async def toggle_message_reaction(
    db: AsyncSession,
    *,
    message_id: int,
    user_id: int,
    payload: MessageReactionToggleRequest,
) -> list[MessageReactionRead]:
    """Store the user's reaction on a message and return the updated summary.

    The whitespace-trimmed ``payload.emoji`` is handed to
    ``repository.upsert_message_reaction``; any toggle-off behavior is
    whatever that repository function implements (not visible here).

    Args:
        db: Async session used for the lookup and the write.
        message_id: ID of the message being reacted to.
        user_id: Reacting user; must pass ``ensure_chat_membership`` for the
            message's chat.
        payload: Carries the ``emoji``.

    Returns:
        The result of ``list_message_reactions`` after the commit.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    target = await repository.get_message_by_id(db, message_id)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")

    await ensure_chat_membership(db, chat_id=target.chat_id, user_id=user_id)

    emoji = payload.emoji.strip()
    await repository.upsert_message_reaction(db, message_id=message_id, user_id=user_id, emoji=emoji)
    await db.commit()

    summary = await list_message_reactions(db, message_id=message_id, user_id=user_id)
    return summary
|