255 lines
7.5 KiB
Python
255 lines
7.5 KiB
Python
from sqlalchemy import delete, func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.chats.models import ChatMember
|
|
from app.messages.models import Message, MessageHidden, MessageIdempotencyKey, MessageReaction, MessageReceipt, MessageType
|
|
|
|
|
|
async def create_message(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    reply_to_message_id: int | None,
    forwarded_from_message_id: int | None,
    message_type: MessageType,
    text: str | None,
) -> Message:
    """Add a new message to the session and flush it (no commit).

    Args:
        db: Active async session the new row is added to.
        chat_id: Chat the message belongs to.
        sender_id: User who sent the message.
        reply_to_message_id: Id of the message being replied to, if any.
        forwarded_from_message_id: Id of the original message for a forward, if any.
        message_type: Stored in the model's ``type`` attribute.
        text: Message body; may be ``None``.

    Returns:
        The flushed ``Message`` instance. Flushing lets the database populate
        generated columns (such as the primary key) before the caller commits.
    """
    new_message = Message(
        chat_id=chat_id,
        sender_id=sender_id,
        reply_to_message_id=reply_to_message_id,
        forwarded_from_message_id=forwarded_from_message_id,
        type=message_type,
        text=text,
    )
    db.add(new_message)
    await db.flush()
    return new_message
|
|
|
|
|
|
async def get_message_by_client_message_id(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    client_message_id: str,
) -> Message | None:
    """Find a previously created message through its idempotency key.

    The lookup joins ``MessageIdempotencyKey`` to ``Message`` and matches the
    key on (chat_id, sender_id, client_message_id).

    Returns:
        The matching message, or ``None`` when no key matches.
    """
    stmt = select(Message).join(
        MessageIdempotencyKey,
        MessageIdempotencyKey.message_id == Message.id,
    )
    # Successive .where() calls are ANDed together.
    stmt = stmt.where(MessageIdempotencyKey.chat_id == chat_id)
    stmt = stmt.where(MessageIdempotencyKey.sender_id == sender_id)
    stmt = stmt.where(MessageIdempotencyKey.client_message_id == client_message_id)

    found = await db.execute(stmt.limit(1))
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def create_message_idempotency_key(
    db: AsyncSession,
    *,
    chat_id: int,
    sender_id: int,
    client_message_id: str,
    message_id: int,
) -> MessageIdempotencyKey:
    """Record that *client_message_id* from this sender/chat produced *message_id*.

    The new key row is added to the session and flushed; nothing is committed.

    Returns:
        The flushed ``MessageIdempotencyKey`` instance.
    """
    idempotency_key = MessageIdempotencyKey(
        chat_id=chat_id,
        sender_id=sender_id,
        client_message_id=client_message_id,
        message_id=message_id,
    )
    db.add(idempotency_key)
    await db.flush()
    return idempotency_key
|
|
|
|
|
|
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
    """Return the message whose id is *message_id*, or ``None`` if there is none."""
    by_id = select(Message).where(Message.id == message_id)
    return (await db.execute(by_id)).scalar_one_or_none()
|
|
|
|
|
|
async def list_chat_messages(
    db: AsyncSession,
    chat_id: int,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Message]:
    """List messages in a chat that *user_id* has not hidden.

    Args:
        db: Active async session.
        chat_id: Chat whose messages are listed.
        user_id: Viewer; messages with a ``MessageHidden`` row for this user are skipped.
        limit: Maximum number of messages returned.
        before_id: When given, only messages with ``id < before_id`` are returned,
            which allows cursor-style paging.

    Returns:
        Messages ordered by id, descending.
    """
    # LEFT JOIN against this user's hide markers; rows with no marker survive
    # the "MessageHidden.id IS NULL" filter below.
    hidden_by_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)

    filters = [Message.chat_id == chat_id, MessageHidden.id.is_(None)]
    if before_id is not None:
        filters.append(Message.id < before_id)

    stmt = (
        select(Message)
        .outerjoin(MessageHidden, hidden_by_user)
        .where(*filters)
        .order_by(Message.id.desc())
        .limit(limit)
    )
    rows = (await db.execute(stmt)).scalars().all()
    return list(rows)
|
|
|
|
|
|
# Escape character used when turning user search text into a LIKE pattern.
# "/" is used (rather than backslash) to avoid dialect-specific string-quoting issues.
_LIKE_ESCAPE = "/"


def _escape_like_term(term: str) -> str:
    """Return *term* with LIKE/ILIKE metacharacters escaped so it matches literally.

    ``%`` and ``_`` are SQL pattern wildcards. Without escaping, a search for
    ``"50%"`` or ``"foo_bar"`` would also match unrelated text. The escape
    character itself is doubled first, so a user-typed escape character cannot
    cancel the escaping of a following wildcard.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


async def search_messages(
    db: AsyncSession,
    *,
    user_id: int,
    query: str,
    chat_id: int | None = None,
    limit: int = 50,
) -> list[Message]:
    """Case-insensitive substring search over message text visible to a user.

    Only messages in chats where *user_id* has a ``ChatMember`` row are
    considered. Messages the user has hidden (a ``MessageHidden`` row) are
    excluded.

    Args:
        db: Active async session.
        user_id: User performing the search.
        query: Search text. Surrounding whitespace is stripped, and ``%`` / ``_``
            are matched literally rather than as SQL wildcards.
        chat_id: When given, restrict the search to this chat.
        limit: Maximum number of results.

    Returns:
        Matching messages ordered by id, descending.
    """
    # NOTE(review): a blank or whitespace-only query produces the pattern "%%",
    # which matches every message with text; confirm callers reject it.
    pattern = f"%{_escape_like_term(query.strip())}%"
    stmt = (
        select(Message)
        .join(ChatMember, ChatMember.chat_id == Message.chat_id)
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            ChatMember.user_id == user_id,
            Message.text.is_not(None),
            Message.text.ilike(pattern, escape=_LIKE_ESCAPE),
            MessageHidden.id.is_(None),
        )
        .order_by(Message.id.desc())
        .limit(limit)
    )
    if chat_id is not None:
        stmt = stmt.where(Message.chat_id == chat_id)
    result = await db.execute(stmt)
    return list(result.scalars().all())
|
|
|
|
|
|
async def list_message_thread(
    db: AsyncSession,
    *,
    root_message_id: int,
    user_id: int,
    limit: int = 100,
) -> list[Message]:
    """Return a root message plus its direct replies, minus those *user_id* hid.

    Only messages whose ``reply_to_message_id`` equals the root id are included
    alongside the root; replies to replies are not followed. The search is
    limited to the root's chat. This function performs no membership check for
    *user_id*.

    Args:
        db: Active async session.
        root_message_id: Id of the thread's root message.
        user_id: Viewer whose hidden messages are excluded.
        limit: Requested maximum; clamped to the range 1..200.

    Returns:
        Thread messages ordered by id, ascending, or ``[]`` if the root does not exist.
    """
    root = await get_message_by_id(db, root_message_id)
    if not root:
        return []

    capped_limit = min(max(limit, 1), 200)
    belongs_to_thread = (Message.id == root_message_id) | (
        Message.reply_to_message_id == root_message_id
    )
    hidden_by_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)

    stmt = (
        select(Message)
        .outerjoin(MessageHidden, hidden_by_user)
        .where(Message.chat_id == root.chat_id, MessageHidden.id.is_(None), belongs_to_thread)
        .order_by(Message.id.asc())
        .limit(capped_limit)
    )
    thread = (await db.execute(stmt)).scalars().all()
    return list(thread)
|
|
|
|
|
|
async def delete_message(db: AsyncSession, message: Message) -> None:
    """Mark *message* as deleted in the session.

    This function neither flushes nor commits. The DELETE is emitted on the
    session's next flush.
    """
    await db.delete(message)
|
|
|
|
|
|
async def hide_message_for_user(db: AsyncSession, *, message_id: int, user_id: int) -> MessageHidden:
    """Create and flush a ``MessageHidden`` marker for (message_id, user_id).

    Returns:
        The flushed marker row; the caller is responsible for committing.
    """
    marker = MessageHidden(message_id=message_id, user_id=user_id)
    db.add(marker)
    await db.flush()
    return marker
|
|
|
|
|
|
async def get_hidden_message(db: AsyncSession, *, message_id: int, user_id: int) -> MessageHidden | None:
    """Return the user's hide marker for *message_id*, or ``None`` if it is not hidden."""
    lookup = (
        select(MessageHidden)
        .where(MessageHidden.message_id == message_id)
        .where(MessageHidden.user_id == user_id)
        .limit(1)
    )
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def list_chat_message_ids(db: AsyncSession, *, chat_id: int) -> list[int]:
    """Return the ids of all messages in *chat_id*.

    The query has no ORDER BY, so the order of the ids is not guaranteed.
    """
    id_query = select(Message.id).where(Message.chat_id == chat_id)
    ids = (await db.execute(id_query)).scalars().all()
    return list(ids)
|
|
|
|
|
|
async def delete_messages_in_chat(db: AsyncSession, *, chat_id: int) -> None:
    """Delete every message in *chat_id* with a single bulk DELETE statement.

    A bulk DELETE does not load objects, so ORM-level relationship cascades do
    not run. NOTE(review): whether dependent rows (reactions, receipts, hide
    markers, idempotency keys, replies) are removed depends on database
    foreign-key rules not visible here; confirm.
    """
    bulk_delete = delete(Message).where(Message.chat_id == chat_id)
    await db.execute(bulk_delete)
|
|
|
|
|
|
async def get_message_receipt(db: AsyncSession, *, chat_id: int, user_id: int) -> MessageReceipt | None:
    """Fetch the delivery/read receipt row for *user_id* in *chat_id*.

    Returns ``None`` if no row exists. Unlike the other getters here, there is
    no LIMIT, so ``scalar_one_or_none`` raises if more than one row matches.
    """
    receipt_query = (
        select(MessageReceipt)
        .where(MessageReceipt.chat_id == chat_id)
        .where(MessageReceipt.user_id == user_id)
    )
    return (await db.execute(receipt_query)).scalar_one_or_none()
|
|
|
|
|
|
async def create_message_receipt(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    last_delivered_message_id: int | None,
    last_read_message_id: int | None,
) -> MessageReceipt:
    """Create and flush a receipt row tracking a user's delivery/read position in a chat.

    Args:
        db: Active async session; the row is added and flushed, not committed.
        chat_id: Chat the receipt belongs to.
        user_id: User the receipt belongs to.
        last_delivered_message_id: Stored as-is; may be ``None``.
        last_read_message_id: Stored as-is; may be ``None``.

    Returns:
        The flushed ``MessageReceipt`` instance.
    """
    new_receipt = MessageReceipt(
        chat_id=chat_id,
        user_id=user_id,
        last_delivered_message_id=last_delivered_message_id,
        last_read_message_id=last_read_message_id,
    )
    db.add(new_receipt)
    await db.flush()
    return new_receipt
|
|
|
|
|
|
async def list_chat_receipts(db: AsyncSession, *, chat_id: int) -> list[MessageReceipt]:
    """Return every ``MessageReceipt`` row for *chat_id* (unordered)."""
    receipts_query = select(MessageReceipt).where(MessageReceipt.chat_id == chat_id)
    receipts = (await db.execute(receipts_query)).scalars().all()
    return list(receipts)
|
|
|
|
|
|
async def get_message_reaction(db: AsyncSession, *, message_id: int, user_id: int) -> MessageReaction | None:
    """Return *user_id*'s reaction on *message_id*, or ``None`` if they have not reacted."""
    reaction_query = (
        select(MessageReaction)
        .where(MessageReaction.message_id == message_id)
        .where(MessageReaction.user_id == user_id)
        .limit(1)
    )
    found = await db.execute(reaction_query)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def list_message_reactions(db: AsyncSession, *, message_id: int) -> list[tuple[str, int]]:
    """Summarise reactions on a message as ``(emoji, count)`` pairs.

    Pairs are ordered by count, descending. Ties are broken by emoji,
    ascending.
    """
    reaction_count = func.count(MessageReaction.id)
    summary_query = (
        select(MessageReaction.emoji, reaction_count)
        .where(MessageReaction.message_id == message_id)
        .group_by(MessageReaction.emoji)
        .order_by(reaction_count.desc(), MessageReaction.emoji.asc())
    )
    rows = (await db.execute(summary_query)).all()
    # Normalise driver-returned values to plain str/int.
    return [(str(symbol), int(total)) for symbol, total in rows]
|
|
|
|
|
|
async def upsert_message_reaction(
    db: AsyncSession,
    *,
    message_id: int,
    user_id: int,
    emoji: str,
) -> tuple[MessageReaction | None, str]:
    """Add, change, or toggle off a user's reaction on a message.

    - No existing reaction: a new one is created and ``"added"`` is returned.
    - Existing reaction with the same emoji: it is deleted (a toggle-off), and
      ``(None, "removed")`` is returned.
    - Existing reaction with a different emoji: the emoji is replaced and
      ``"updated"`` is returned.

    Each path flushes the session; nothing is committed here.

    Returns:
        A ``(reaction_or_None, action)`` tuple, where action is one of
        ``"added"``, ``"updated"``, or ``"removed"``.
    """
    current = await get_message_reaction(db, message_id=message_id, user_id=user_id)

    if not current:
        created = MessageReaction(message_id=message_id, user_id=user_id, emoji=emoji)
        db.add(created)
        await db.flush()
        return created, "added"

    if current.emoji == emoji:
        # Re-sending the same emoji toggles the reaction off.
        await db.delete(current)
        await db.flush()
        return None, "removed"

    current.emoji = emoji
    await db.flush()
    return current, "updated"
|