Files
Messenger/app/chats/repository.py
benya db700bcbcd
All checks were successful
CI / test (push) Successful in 26s
moderation: add chat bans for groups/channels with web actions
2026-03-08 14:29:21 +03:00

432 lines
14 KiB
Python

from sqlalchemy import Select, String, func, or_, select
from sqlalchemy.orm import aliased
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.models import Chat, ChatBan, ChatInviteLink, ChatMember, ChatMemberRole, ChatNotificationSetting, ChatType, ChatUserSetting
from app.messages.models import Message, MessageHidden, MessageReceipt
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
    """Insert a bare chat row and return it after flushing.

    The session is flushed (not committed), so database-generated values such
    as the primary key are populated while the transaction stays open.
    """
    new_chat = Chat(type=chat_type, title=title)
    db.add(new_chat)
    await db.flush()
    return new_chat
async def create_chat_with_meta(
    db: AsyncSession,
    *,
    chat_type: ChatType,
    title: str | None,
    handle: str | None,
    description: str | None,
    is_public: bool,
    is_saved: bool = False,
) -> Chat:
    """Insert a chat together with its descriptive metadata and flush it.

    Args:
        db: Active async session; flushed but not committed.
        chat_type: Kind of chat to create.
        title: Display title, or ``None``.
        handle: Handle stored on ``Chat.handle``, or ``None``.
        description: Free-text description, or ``None``.
        is_public: Value for ``Chat.is_public`` (``discover_public_chats``
            only lists chats where this is true).
        is_saved: Value for ``Chat.is_saved`` (looked up by
            ``find_saved_chat_for_user``).

    Returns:
        The newly added ``Chat`` instance.
    """
    attributes = {
        "type": chat_type,
        "title": title,
        "handle": handle,
        "description": description,
        "is_public": is_public,
        "is_saved": is_saved,
    }
    chat = Chat(**attributes)
    db.add(chat)
    await db.flush()
    return chat
async def add_chat_member(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    role: ChatMemberRole,
) -> ChatMember:
    """Add ``user_id`` to ``chat_id`` with ``role`` and flush the new membership row."""
    membership = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
    db.add(membership)
    await db.flush()
    return membership
async def delete_chat_member(db: AsyncSession, member: ChatMember) -> None:
    """Mark ``member`` for deletion in the current session.

    No flush is issued here; the DELETE is emitted on the session's next
    flush or commit.
    """
    await db.delete(member)
async def count_chat_members(db: AsyncSession, *, chat_id: int) -> int:
    """Return the number of membership rows for ``chat_id``."""
    count_stmt = select(func.count(ChatMember.id)).where(ChatMember.chat_id == chat_id)
    total = (await db.execute(count_stmt)).scalar_one()
    return int(total)
def _user_chats_query(user_id: int, query: str | None = None) -> Select[tuple[Chat]]:
    """Build the SELECT for a user's non-archived chats, pinned chats first.

    Args:
        user_id: Member whose chats are listed; also scopes the per-user
            ``ChatUserSetting`` join.
        query: Optional search text. After stripping, it is matched
            case-insensitively as a literal substring of the chat title or of
            the chat type cast to a string. Blank/``None`` disables filtering.

    Returns:
        A statement ordered by pinned flag (pinned first), then ``pinned_at``
        descending with NULLs last, then chat id descending.
    """
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        # LEFT JOIN: chats the user never configured have no settings row and
        # are still listed (archived/pinned treated as false via coalesce).
        .outerjoin(
            ChatUserSetting,
            (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id),
        )
        .where(ChatMember.user_id == user_id, func.coalesce(ChatUserSetting.archived, False).is_(False))
    )
    needle = query.strip() if query else ""
    if needle:
        # Escape LIKE metacharacters so input such as "a_b" or "100%" matches
        # literally instead of acting as wildcards. "/" is the escape char.
        escaped = needle.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                Chat.title.ilike(pattern, escape="/"),
                Chat.type.cast(String).ilike(pattern, escape="/"),
            )
        )
    return stmt.order_by(
        func.coalesce(ChatUserSetting.pinned, False).desc(),
        ChatUserSetting.pinned_at.desc().nullslast(),
        Chat.id.desc(),
    )
async def list_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
    query: str | None = None,
) -> list[Chat]:
    """Return up to ``limit`` non-archived chats of ``user_id``.

    Args:
        db: Active async session.
        user_id: Member whose chats are listed.
        limit: Maximum number of chats returned.
        before_id: When given, only chats with ``Chat.id < before_id`` are
            considered.
        query: Optional search text forwarded to ``_user_chats_query``.
    """
    stmt = _user_chats_query(user_id, query=query)
    if before_id is not None:
        # NOTE(review): the cursor filters on Chat.id while the ordering puts
        # pinned chats first, so pinned chats whose id is below the cursor can
        # reappear on later pages. Confirm the client de-duplicates.
        stmt = stmt.where(Chat.id < before_id)
    rows = await db.execute(stmt.limit(limit))
    return list(rows.scalars().all())
async def list_archived_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Chat]:
    """Return up to ``limit`` chats that ``user_id`` has archived.

    Only chats with a ``ChatUserSetting`` row whose ``archived`` is true are
    included (inner join). Ordering matches the non-archived list: pinned
    first, then ``pinned_at`` descending (NULLs last), then newest id.

    Args:
        before_id: When given, only chats with ``Chat.id < before_id``.
    """
    settings_match = (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id)
    conditions = [ChatMember.user_id == user_id, ChatUserSetting.archived.is_(True)]
    if before_id is not None:
        conditions.append(Chat.id < before_id)
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .join(ChatUserSetting, settings_match)
        .where(*conditions)
        .order_by(
            func.coalesce(ChatUserSetting.pinned, False).desc(),
            ChatUserSetting.pinned_at.desc().nullslast(),
            Chat.id.desc(),
        )
        .limit(limit)
    )
    archived = await db.execute(stmt)
    return list(archived.scalars().all())
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
    """Fetch the chat whose id is ``chat_id``, or ``None`` if absent."""
    by_id = select(Chat).where(Chat.id == chat_id)
    return (await db.execute(by_id)).scalar_one_or_none()
async def get_chat_by_handle(db: AsyncSession, handle: str) -> Chat | None:
    """Fetch the chat whose ``handle`` equals ``handle``, or ``None``."""
    by_handle = select(Chat).filter_by(handle=handle)
    found = await db.execute(by_handle)
    return found.scalar_one_or_none()
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
    """Return the membership of ``user_id`` in ``chat_id``, or ``None`` if not a member."""
    membership_stmt = select(ChatMember).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(membership_stmt)
    return found.scalar_one_or_none()
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
    """Return every membership row of ``chat_id`` in ascending id order."""
    ordered_members = (
        select(ChatMember)
        .where(ChatMember.chat_id == chat_id)
        .order_by(ChatMember.id.asc())
    )
    rows = await db.execute(ordered_members)
    return list(rows.scalars().all())
async def list_chat_member_user_ids(db: AsyncSession, *, chat_id: int) -> list[int]:
    """Return the user ids of all members of ``chat_id``.

    No ORDER BY is applied, so the order is whatever the database returns.
    """
    id_stmt = select(ChatMember.user_id).where(ChatMember.chat_id == chat_id)
    rows = await db.execute(id_stmt)
    return list(rows.scalars().all())
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
    """Return the ids of every chat ``user_id`` belongs to, ascending."""
    chat_ids_stmt = (
        select(ChatMember.chat_id)
        .where(ChatMember.user_id == user_id)
        .order_by(ChatMember.chat_id.asc())
    )
    rows = await db.execute(chat_ids_stmt)
    return list(rows.scalars().all())
async def find_saved_chat_for_user(db: AsyncSession, *, user_id: int) -> Chat | None:
    """Return a chat flagged ``is_saved`` that ``user_id`` is a member of, or ``None``.

    The query is capped at one row with no ORDER BY. If several saved chats
    exist, which one is returned is up to the database.
    """
    saved_stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .where(ChatMember.user_id == user_id)
        .where(Chat.is_saved.is_(True))
        .limit(1)
    )
    found = await db.execute(saved_stmt)
    return found.scalar_one_or_none()
async def discover_public_chats(
    db: AsyncSession,
    *,
    user_id: int,
    query: str | None = None,
    limit: int = 30,
) -> list[tuple[Chat, bool]]:
    """List public, non-saved groups/channels and flag which ones ``user_id`` joined.

    Args:
        db: Active async session.
        user_id: User whose membership is reported for each chat.
        query: Optional search text. After stripping, it is matched
            case-insensitively as a literal substring of title, handle or
            description. Blank/``None`` disables filtering.
        limit: Maximum number of chats returned.

    Returns:
        ``(chat, is_member)`` pairs, newest chat id first. Empty list when
        nothing matches.
    """
    stmt = select(Chat).where(
        Chat.is_public.is_(True),
        Chat.type.in_([ChatType.GROUP, ChatType.CHANNEL]),
        Chat.is_saved.is_(False),
    )
    needle = query.strip() if query else ""
    if needle:
        # Escape LIKE metacharacters so untrusted search text such as "a_b" or
        # "%" matches literally instead of acting as wildcards ("/" escapes).
        escaped = needle.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                Chat.title.ilike(pattern, escape="/"),
                Chat.handle.ilike(pattern, escape="/"),
                Chat.description.ilike(pattern, escape="/"),
            )
        )
    stmt = stmt.order_by(Chat.id.desc()).limit(limit)
    chats = list((await db.execute(stmt)).scalars().all())
    if not chats:
        return []
    # One extra round-trip resolves membership for the whole page at once.
    membership_stmt = select(ChatMember.chat_id).where(
        ChatMember.user_id == user_id,
        ChatMember.chat_id.in_([chat.id for chat in chats]),
    )
    joined_ids = set((await db.execute(membership_stmt)).scalars().all())
    return [(chat, chat.id in joined_ids) for chat in chats]
async def find_private_chat_between_users(db: AsyncSession, *, user_a_id: int, user_b_id: int) -> Chat | None:
    """Return a PRIVATE chat in which both users are members, or ``None``.

    The membership table is joined twice under separate aliases, one per user.
    NOTE(review): if ``user_a_id == user_b_id`` both aliases can match the
    same membership row, so any private chat of that user qualifies. Callers
    presumably pass distinct ids; verify.
    """
    member_a = aliased(ChatMember)
    member_b = aliased(ChatMember)
    pair_stmt = (
        select(Chat)
        .join(member_a, member_a.chat_id == Chat.id)
        .join(member_b, member_b.chat_id == Chat.id)
        .where(Chat.type == ChatType.PRIVATE)
        .where(member_a.user_id == user_a_id)
        .where(member_b.user_id == user_b_id)
        .limit(1)
    )
    found = await db.execute(pair_stmt)
    return found.scalar_one_or_none()
async def get_private_counterpart_user_id(db: AsyncSession, *, chat_id: int, user_id: int) -> int | None:
    """Return the user id of some other member of ``chat_id``, or ``None``.

    The query is limited to one row with no ORDER BY. It is intended for
    two-member chats; with more members the id returned is arbitrary.
    """
    other_member = (
        select(ChatMember.user_id)
        .where(ChatMember.chat_id == chat_id)
        .where(ChatMember.user_id != user_id)
        .limit(1)
    )
    found = await db.execute(other_member)
    return found.scalar_one_or_none()
async def get_unread_count_for_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> int:
    """Count messages in ``chat_id`` that ``user_id`` has not read yet.

    A message counts when all of these hold:
    - it was sent by someone other than ``user_id``;
    - it is not hidden for ``user_id``;
    - its id is above the user's read cursor (``MessageReceipt.last_read_message_id``,
      or 0 when there is no receipt).

    Rows with a NULL ``sender_id`` (if the schema allows them) are excluded,
    because ``!=`` against NULL is not true in SQL.
    """
    read_cursor = (
        select(MessageReceipt.last_read_message_id)
        .where(MessageReceipt.chat_id == chat_id)
        .where(MessageReceipt.user_id == user_id)
        .limit(1)
        .scalar_subquery()
    )
    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)
    unread_stmt = (
        select(func.count(Message.id))
        .outerjoin(MessageHidden, hidden_for_user)
        .where(
            Message.chat_id == chat_id,
            Message.sender_id != user_id,
            MessageHidden.id.is_(None),  # anti-join: drop messages the user hid
            Message.id > func.coalesce(read_cursor, 0),
        )
    )
    unread = (await db.execute(unread_stmt)).scalar_one()
    return int(unread or 0)
async def get_unread_mentions_count_for_chat(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    username: str | None,
) -> int:
    """Count unread messages in ``chat_id`` whose text mentions ``@username``.

    A message counts when all of these hold:
    - it was sent by someone other than ``user_id``;
    - it is not hidden for ``user_id``;
    - its id is above the user's read cursor (``MessageReceipt.last_read_message_id``,
      or 0 without a receipt);
    - its lower-cased text contains ``@`` followed by the stripped,
      lower-cased username as a literal substring.

    Returns:
        The count, or 0 when ``username`` is ``None`` or blank.
    """
    normalized_username = (username or "").strip().lower()
    if not normalized_username:
        return 0
    last_read_subquery = (
        select(MessageReceipt.last_read_message_id)
        .where(MessageReceipt.chat_id == chat_id, MessageReceipt.user_id == user_id)
        .limit(1)
        .scalar_subquery()
    )
    # Escape LIKE metacharacters. Usernames commonly contain "_", which LIKE
    # would otherwise treat as "any single character" (so "@a_b" would also
    # count "@axb"). "/" is used as the escape character.
    escaped_username = (
        normalized_username.replace("/", "//").replace("%", "/%").replace("_", "/_")
    )
    mention_like = f"%@{escaped_username}%"
    # NOTE(review): still a substring test, so "@bob" also counts inside
    # "@bobby"; a word-boundary match would need a dialect-specific regex.
    stmt = (
        select(func.count(Message.id))
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            Message.chat_id == chat_id,
            Message.sender_id != user_id,
            MessageHidden.id.is_(None),
            Message.id > func.coalesce(last_read_subquery, 0),
            Message.text.is_not(None),
            func.lower(Message.text).like(mention_like, escape="/"),
        )
    )
    result = await db.execute(stmt)
    return int(result.scalar_one() or 0)
async def get_last_visible_message_for_user(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
) -> Message | None:
    """Return the highest-id message in ``chat_id`` that ``user_id`` has not hidden.

    Returns ``None`` when no message in the chat is visible to that user.
    """
    hidden_by_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)
    newest_visible = (
        select(Message)
        .outerjoin(MessageHidden, hidden_by_user)
        .where(Message.chat_id == chat_id)
        .where(MessageHidden.id.is_(None))  # keep only rows without a hide marker
        .order_by(Message.id.desc())
        .limit(1)
    )
    found = await db.execute(newest_visible)
    return found.scalar_one_or_none()
async def get_chat_notification_setting(
    db: AsyncSession, *, chat_id: int, user_id: int
) -> ChatNotificationSetting | None:
    """Return the notification settings row for (``chat_id``, ``user_id``), or ``None``."""
    setting_stmt = select(ChatNotificationSetting).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(setting_stmt)
    return found.scalar_one_or_none()
async def upsert_chat_notification_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    muted: bool,
) -> ChatNotificationSetting:
    """Set the ``muted`` flag for a user in a chat, creating the row if needed.

    The session is flushed in both the update and the insert case.
    """
    setting = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    if setting is None:
        setting = ChatNotificationSetting(chat_id=chat_id, user_id=user_id, muted=muted)
        db.add(setting)
    else:
        setting.muted = muted
    await db.flush()
    return setting
async def is_chat_muted_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return whether ``user_id`` muted ``chat_id``; ``False`` when no setting row exists."""
    setting = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    if setting is None:
        return False
    return bool(setting.muted)
async def get_chat_user_setting(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatUserSetting | None:
    """Return the per-user chat settings row (archive/pin state), or ``None``."""
    setting_stmt = select(ChatUserSetting).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(setting_stmt)
    return found.scalar_one_or_none()
async def upsert_chat_archived_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    archived: bool,
) -> ChatUserSetting:
    """Set the ``archived`` flag for a user in a chat, creating the row if needed.

    A newly created row only receives ``archived``; other columns keep the
    model's defaults. The session is flushed in both cases.
    """
    setting = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    if setting is None:
        setting = ChatUserSetting(chat_id=chat_id, user_id=user_id, archived=archived)
        db.add(setting)
    else:
        setting.archived = archived
    await db.flush()
    return setting
async def upsert_chat_pinned_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    pinned: bool,
) -> ChatUserSetting:
    """Pin or unpin ``chat_id`` for ``user_id``, creating the settings row if needed.

    When pinning, ``pinned_at`` is set to the database's ``now()``. When
    unpinning, it is cleared to ``None``. A newly created row starts with
    ``archived=False``.

    Returns:
        The flushed settings row. When ``pinned`` is true, its ``pinned_at``
        is already loaded, so it can be read without further IO.
    """
    pinned_at = func.now() if pinned else None
    setting = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    if setting is None:
        setting = ChatUserSetting(
            chat_id=chat_id,
            user_id=user_id,
            archived=False,
            pinned=pinned,
            pinned_at=pinned_at,
        )
        db.add(setting)
    else:
        setting.pinned = pinned
        setting.pinned_at = pinned_at
    await db.flush()
    if pinned:
        # pinned_at was assigned a SQL expression, which the ORM expires after
        # the flush. With AsyncSession, an implicit reload on attribute access
        # raises MissingGreenlet, so load the server-computed value explicitly.
        await db.refresh(setting, attribute_names=["pinned_at"])
    return setting
async def create_chat_invite_link(
    db: AsyncSession,
    *,
    chat_id: int,
    creator_user_id: int,
    token: str,
) -> ChatInviteLink:
    """Create an active invite link for ``chat_id`` and flush it.

    Args:
        creator_user_id: User recorded as the link's creator.
        token: Token value stored on the link; supplied by the caller.
    """
    invite = ChatInviteLink(
        chat_id=chat_id,
        creator_user_id=creator_user_id,
        token=token,
        is_active=True,
    )
    db.add(invite)
    await db.flush()
    return invite
async def get_active_chat_invite_by_token(db: AsyncSession, *, token: str) -> ChatInviteLink | None:
    """Return the active invite link carrying ``token``, or ``None``."""
    invite_stmt = (
        select(ChatInviteLink)
        .where(ChatInviteLink.token == token)
        .where(ChatInviteLink.is_active.is_(True))
        .limit(1)
    )
    found = await db.execute(invite_stmt)
    return found.scalar_one_or_none()
async def get_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatBan | None:
    """Return the ban of ``user_id`` in ``chat_id``, or ``None`` if not banned."""
    ban_stmt = select(ChatBan).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(ban_stmt)
    return found.scalar_one_or_none()
async def is_user_banned_in_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return ``True`` when a ban row exists for ``user_id`` in ``chat_id``.

    Only the ban id is selected, so the full row is never loaded.
    """
    probe = (
        select(ChatBan.id)
        .where(ChatBan.chat_id == chat_id, ChatBan.user_id == user_id)
        .limit(1)
    )
    ban_id = (await db.execute(probe)).scalar_one_or_none()
    return ban_id is not None
async def upsert_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int, banned_by_user_id: int) -> ChatBan:
    """Record that ``user_id`` is banned from ``chat_id``.

    If a ban already exists, only ``banned_by_user_id`` is overwritten.
    Otherwise a new ``ChatBan`` row is added. The session is flushed either way.
    """
    ban = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if ban is None:
        ban = ChatBan(chat_id=chat_id, user_id=user_id, banned_by_user_id=banned_by_user_id)
        db.add(ban)
    else:
        ban.banned_by_user_id = banned_by_user_id
    await db.flush()
    return ban
async def remove_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
    """Mark the ban of ``user_id`` in ``chat_id`` for deletion, if one exists.

    No flush is issued; the DELETE is emitted on the session's next flush or
    commit. Removing a ban that does not exist is a no-op.
    """
    ban = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if ban is None:
        return
    await db.delete(ban)