Files
Messenger/app/messages/repository.py
benya c6e8b779b0
All checks were successful
CI / test (push) Successful in 21s
feat(threads): add basic message thread API and web thread panel
2026-03-08 13:37:53 +03:00

255 lines
7.5 KiB
Python

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.models import ChatMember
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReaction, MessageReceipt, MessageType
async def create_message(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    reply_to_message_id: int | None,
    forwarded_from_message_id: int | None,
    message_type: MessageType,
    text: str | None,
) -> Message:
    """Create a ``Message`` row, add it to the session and flush it.

    No commit is issued here; the caller owns the transaction.

    Args:
        db: Session the new row is added to.
        chat_id: Chat the message belongs to.
        sender_id: User who sent the message.
        reply_to_message_id: Message this one replies to, if any.
        forwarded_from_message_id: Message this one was forwarded from, if any.
        message_type: Stored in the model's ``type`` column.
        text: Message text, or ``None``.

    Returns:
        The newly added ``Message`` instance, after the flush.
    """
    column_values = {
        "chat_id": chat_id,
        "sender_id": sender_id,
        "reply_to_message_id": reply_to_message_id,
        "forwarded_from_message_id": forwarded_from_message_id,
        # The public parameter is ``message_type``; the model column is ``type``.
        "type": message_type,
        "text": text,
    }
    new_message = Message(**column_values)
    db.add(new_message)
    await db.flush()
    return new_message
async def get_message_by_client_message_id(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    client_message_id: str,
) -> Message | None:
    """Look up the message recorded for a client-supplied idempotency key.

    The lookup joins ``MessageIdempotencyKey`` to ``Message`` and matches on
    the (chat, sender, client message id) triple.

    Returns:
        The matching ``Message``, or ``None`` when no key row matches.
    """
    key = MessageIdempotencyKey
    stmt = (
        select(Message)
        .join(key, key.message_id == Message.id)
        .where(key.chat_id == chat_id)
        .where(key.sender_id == sender_id)
        .where(key.client_message_id == client_message_id)
        .limit(1)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def create_message_idempotency_key(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    client_message_id: str,
    message_id: int,
) -> MessageIdempotencyKey:
    """Record an idempotency key that points at an existing message.

    The row is added to the session and flushed; no commit is issued here.

    Returns:
        The newly added ``MessageIdempotencyKey`` instance.
    """
    column_values = {
        "chat_id": chat_id,
        "sender_id": sender_id,
        "client_message_id": client_message_id,
        "message_id": message_id,
    }
    idempotency_key = MessageIdempotencyKey(**column_values)
    db.add(idempotency_key)
    await db.flush()
    return idempotency_key
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
    """Return the message with the given primary id, or ``None`` if absent."""
    stmt = select(Message).where(Message.id == message_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def list_chat_messages(
    db: AsyncSession,
    chat_id: int,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Message]:
    """List a chat's messages, newest first, skipping ones hidden by the user.

    Args:
        db: Session used to run the query.
        chat_id: Chat whose messages are listed.
        user_id: Messages with a ``MessageHidden`` row for this user are excluded.
        limit: Maximum number of rows returned (passed to the query as-is).
        before_id: When given, only messages with a strictly smaller id are
            returned.

    Returns:
        Messages ordered by descending id.
    """
    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)

    criteria = [
        Message.chat_id == chat_id,
        # Anti-join: keep only rows with no matching MessageHidden entry.
        MessageHidden.id.is_(None),
    ]
    if before_id is not None:
        criteria.append(Message.id < before_id)

    stmt = (
        select(Message)
        .outerjoin(MessageHidden, hidden_for_user)
        .where(*criteria)
        .order_by(Message.id.desc())
        .limit(limit)
    )
    rows = await db.execute(stmt)
    found = rows.scalars().all()
    return list(found)
async def search_messages(
    db: AsyncSession,
    *,
    user_id: int,
    query: str,
    chat_id: int | None = None,
    limit: int = 50,
) -> list[Message]:
    """Search message text, case-insensitively, in chats the user belongs to.

    The stripped ``query`` is matched as a literal substring: the LIKE
    metacharacters ``%`` and ``_`` in the user's input are escaped, so a search
    for ``"100%"`` or ``"_"`` matches only messages containing those characters
    rather than acting as a wildcard.

    Args:
        db: Session used to run the query.
        user_id: Only chats with a ``ChatMember`` row for this user are
            searched, and messages this user has hidden are excluded.
        query: Text to look for; surrounding whitespace is ignored.
        chat_id: When given, restrict the search to this chat.
        limit: Maximum number of rows returned (passed to the query as-is).

    Returns:
        Matching messages ordered by descending id.
    """
    # Escape the escape character first so the escapes added for "%" and "_"
    # are not themselves doubled. "/" is used rather than a backslash to avoid
    # dialect-specific backslash handling in string literals.
    escape_char = "/"
    term = query.strip()
    for special in (escape_char, "%", "_"):
        term = term.replace(special, escape_char + special)

    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)

    criteria = [
        ChatMember.user_id == user_id,
        Message.text.is_not(None),
        Message.text.ilike(f"%{term}%", escape=escape_char),
        # Anti-join: keep only rows with no matching MessageHidden entry.
        MessageHidden.id.is_(None),
    ]
    if chat_id is not None:
        criteria.append(Message.chat_id == chat_id)

    stmt = (
        select(Message)
        .join(ChatMember, ChatMember.chat_id == Message.chat_id)
        .outerjoin(MessageHidden, hidden_for_user)
        .where(*criteria)
        .order_by(Message.id.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())
async def list_message_thread(
    db: AsyncSession,
    *,
    root_message_id: int,
    user_id: int,
    limit: int = 100,
) -> list[Message]:
    """Return a root message together with the messages replying to it.

    Only messages whose ``reply_to_message_id`` equals the root id are
    included alongside the root itself, all within the root's chat. Messages
    hidden by ``user_id`` are excluded.

    Args:
        db: Session used to run the queries.
        root_message_id: Id of the thread's root message.
        user_id: Messages with a ``MessageHidden`` row for this user are excluded.
        limit: Requested page size; clamped to the range 1..200.

    Returns:
        Messages ordered by ascending id, or an empty list when the root
        message does not exist.
    """
    root = await get_message_by_id(db, root_message_id)
    if not root:
        return []

    # Clamp the caller's limit into [1, 200].
    page_size = min(max(limit, 1), 200)

    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)
    in_thread = (Message.id == root_message_id) | (Message.reply_to_message_id == root_message_id)

    stmt = (
        select(Message)
        .outerjoin(MessageHidden, hidden_for_user)
        .where(
            Message.chat_id == root.chat_id,
            # Anti-join: keep only rows with no matching MessageHidden entry.
            MessageHidden.id.is_(None),
            in_thread,
        )
        .order_by(Message.id.asc())
        .limit(page_size)
    )
    rows = await db.execute(stmt)
    thread = rows.scalars().all()
    return list(thread)
async def delete_message(db: AsyncSession, message: Message) -> None:
    """Mark ``message`` for deletion in the session.

    No flush or commit is issued here.
    """
    await db.delete(message)
    return None
async def hide_message_for_user(db: AsyncSession, *, message_id: int, user_id: int) -> MessageHidden:
    """Add and flush a ``MessageHidden`` row hiding a message for one user.

    No check for an existing row is made here.

    Returns:
        The newly added ``MessageHidden`` instance.
    """
    hidden_marker = MessageHidden(user_id=user_id, message_id=message_id)
    db.add(hidden_marker)
    await db.flush()
    return hidden_marker
async def get_hidden_message(db: AsyncSession, *, message_id: int, user_id: int) -> MessageHidden | None:
    """Return the ``MessageHidden`` row for this message and user, or ``None``."""
    stmt = (
        select(MessageHidden)
        .where(
            MessageHidden.message_id == message_id,
            MessageHidden.user_id == user_id,
        )
        .limit(1)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def list_chat_message_ids(db: AsyncSession, *, chat_id: int) -> list[int]:
    """Return the ids of every message in the chat (no ordering is applied)."""
    stmt = select(Message.id).where(Message.chat_id == chat_id)
    rows = await db.execute(stmt)
    message_ids = rows.scalars().all()
    return list(message_ids)
async def delete_messages_in_chat(db: AsyncSession, *, chat_id: int) -> None:
    """Delete every message of a chat with a single bulk ``DELETE`` statement.

    No commit is issued here.
    """
    # NOTE(review): a bulk DELETE does not load the rows as ORM objects;
    # confirm that related rows are cleaned up by the schema or by callers.
    stmt = delete(Message).where(Message.chat_id == chat_id)
    await db.execute(stmt)
async def get_message_receipt(db: AsyncSession, *, chat_id: int, user_id: int) -> MessageReceipt | None:
    """Return the receipt row for this (chat, user) pair, or ``None``.

    The query carries no ``LIMIT``, so ``scalar_one_or_none`` raises if more
    than one row matches.
    """
    # NOTE(review): presumably (chat_id, user_id) is unique on MessageReceipt —
    # confirm against the model.
    stmt = (
        select(MessageReceipt)
        .where(MessageReceipt.chat_id == chat_id)
        .where(MessageReceipt.user_id == user_id)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def create_message_receipt(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    last_delivered_message_id: int | None,
    last_read_message_id: int | None,
) -> MessageReceipt:
    """Add and flush a ``MessageReceipt`` row for a user in a chat.

    No commit is issued here.

    Returns:
        The newly added ``MessageReceipt`` instance.
    """
    column_values = {
        "chat_id": chat_id,
        "user_id": user_id,
        "last_delivered_message_id": last_delivered_message_id,
        "last_read_message_id": last_read_message_id,
    }
    new_receipt = MessageReceipt(**column_values)
    db.add(new_receipt)
    await db.flush()
    return new_receipt
async def list_chat_receipts(db: AsyncSession, *, chat_id: int) -> list[MessageReceipt]:
    """Return every receipt row recorded for the chat (no ordering is applied)."""
    stmt = select(MessageReceipt).where(MessageReceipt.chat_id == chat_id)
    rows = await db.execute(stmt)
    receipts = rows.scalars().all()
    return list(receipts)
async def get_message_reaction(db: AsyncSession, *, message_id: int, user_id: int) -> MessageReaction | None:
    """Return this user's reaction row on the message, or ``None`` if absent."""
    same_message = MessageReaction.message_id == message_id
    same_user = MessageReaction.user_id == user_id
    stmt = select(MessageReaction).where(same_message, same_user).limit(1)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def list_message_reactions(db: AsyncSession, *, message_id: int) -> list[tuple[str, int]]:
    """Count the reactions on a message, grouped by emoji.

    Returns:
        ``(emoji, count)`` pairs ordered by descending count, with ties broken
        by ascending emoji.
    """
    stmt = (
        select(MessageReaction.emoji, func.count(MessageReaction.id))
        .where(MessageReaction.message_id == message_id)
        .group_by(MessageReaction.emoji)
        .order_by(
            func.count(MessageReaction.id).desc(),
            MessageReaction.emoji.asc(),
        )
    )
    rows = await db.execute(stmt)

    summary: list[tuple[str, int]] = []
    for emoji, count in rows.all():
        # Coerce to plain Python types for the caller.
        summary.append((str(emoji), int(count)))
    return summary
async def upsert_message_reaction(
    db: AsyncSession,
    *,
    message_id: int,
    user_id: int,
    emoji: str,
) -> tuple[MessageReaction | None, str]:
    """Toggle or set a user's reaction on a message.

    The outcome depends on the user's current reaction row:

    * none exists: a new reaction is added, giving ``(reaction, "added")``;
    * one exists with the same emoji: it is deleted, giving ``(None, "removed")``;
    * one exists with a different emoji: it is updated in place, giving
      ``(reaction, "updated")``.

    The session is flushed in every case; no commit is issued here.
    """
    current = await get_message_reaction(db, message_id=message_id, user_id=user_id)

    if not current:
        created = MessageReaction(message_id=message_id, user_id=user_id, emoji=emoji)
        db.add(created)
        await db.flush()
        return created, "added"

    if current.emoji == emoji:
        # Re-sending the same emoji acts as a toggle and removes the reaction.
        await db.delete(current)
        await db.flush()
        return None, "removed"

    current.emoji = emoji
    await db.flush()
    return current, "updated"