432 lines
14 KiB
Python
432 lines
14 KiB
Python
from sqlalchemy import Select, String, func, or_, select
|
|
from sqlalchemy.orm import aliased
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.chats.models import Chat, ChatBan, ChatInviteLink, ChatMember, ChatMemberRole, ChatNotificationSetting, ChatType, ChatUserSetting
|
|
from app.messages.models import Message, MessageHidden, MessageReceipt
|
|
|
|
|
|
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
    """Add a minimal chat row and flush it.

    Args:
        db: Async session; the caller owns commit/rollback.
        chat_type: Kind of chat to create.
        title: Optional display title.

    Returns:
        The new ``Chat``, flushed (so database-generated values such as the
        id are available) but not committed.
    """
    new_chat = Chat(type=chat_type, title=title)
    db.add(new_chat)
    await db.flush()
    return new_chat
|
|
|
|
|
|
async def create_chat_with_meta(
    db: AsyncSession,
    *,
    chat_type: ChatType,
    title: str | None,
    handle: str | None,
    description: str | None,
    is_public: bool,
    is_saved: bool = False,
) -> Chat:
    """Add a chat row with its descriptive metadata and flush it.

    Args:
        db: Async session; the caller owns commit/rollback.
        chat_type: Kind of chat to create.
        title: Optional display title.
        handle: Optional handle (looked up elsewhere by ``get_chat_by_handle``).
        description: Optional free-text description.
        is_public: Whether the chat is publicly listed.
        is_saved: Whether this is a user's saved-messages chat.

    Returns:
        The new ``Chat``, flushed but not committed.
    """
    columns = {
        "type": chat_type,
        "title": title,
        "handle": handle,
        "description": description,
        "is_public": is_public,
        "is_saved": is_saved,
    }
    created = Chat(**columns)
    db.add(created)
    await db.flush()
    return created
|
|
|
|
|
|
async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
    """Create a membership row linking ``user_id`` to ``chat_id`` with ``role``.

    Returns:
        The new ``ChatMember``, flushed but not committed.
    """
    membership = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
    db.add(membership)
    await db.flush()
    return membership
|
|
|
|
|
|
async def delete_chat_member(db: AsyncSession, member: ChatMember) -> None:
    """Mark ``member`` for deletion in the current session.

    No flush is issued here; the DELETE is sent on the session's next
    flush or commit.
    """
    session = db
    await session.delete(member)
|
|
|
|
|
|
async def count_chat_members(db: AsyncSession, *, chat_id: int) -> int:
    """Return the number of membership rows for ``chat_id``."""
    count_stmt = (
        select(func.count(ChatMember.id))
        .where(ChatMember.chat_id == chat_id)
    )
    total = (await db.execute(count_stmt)).scalar_one()
    return int(total)
|
|
|
|
|
|
def _user_chats_query(user_id: int, query: str | None = None) -> Select[tuple[Chat]]:
    """Build the base SELECT for a user's non-archived chats.

    Chats are joined through the user's membership rows. The user's per-chat
    settings are LEFT JOINed, so chats without a settings row are still
    included and treated as not archived and not pinned.

    Args:
        user_id: User whose chats are listed.
        query: Optional free-text filter, matched case-insensitively as a
            substring of the chat title or of the chat type cast to text.
            Surrounding whitespace is ignored. ``%`` and ``_`` in the input
            are matched literally rather than as LIKE wildcards.

    Returns:
        A ``Select`` ordered pinned-first, then by most recent ``pinned_at``
        (NULLs last), then newest chat id first. No LIMIT is applied.
    """
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .outerjoin(
            ChatUserSetting,
            (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id),
        )
        # A missing settings row yields NULL, which coalesce treats as "not archived".
        .where(ChatUserSetting.archived.is_(None) | ChatUserSetting.archived.is_(False))
        .where(ChatMember.user_id == user_id)
    )
    needle = query.strip() if query else ""
    if needle:
        # autoescape=True escapes LIKE metacharacters in the user's text so a
        # search for "50%" or "a_b" matches those characters literally.
        stmt = stmt.where(
            or_(
                Chat.title.icontains(needle, autoescape=True),
                Chat.type.cast(String).icontains(needle, autoescape=True),
            )
        )
    return stmt.order_by(
        func.coalesce(ChatUserSetting.pinned, False).desc(),
        ChatUserSetting.pinned_at.desc().nullslast(),
        Chat.id.desc(),
    )
|
|
|
|
|
|
async def list_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
    query: str | None = None,
) -> list[Chat]:
    """Return one page of the user's non-archived chats.

    Args:
        db: Async session.
        user_id: User whose chats are listed.
        limit: Maximum number of chats returned.
        before_id: When given, only chats with ``Chat.id < before_id`` are returned.
        query: Optional free-text filter forwarded to ``_user_chats_query``.

    Returns:
        Chats ordered as built by ``_user_chats_query`` (pinned first).

    NOTE(review): the ``before_id`` cursor filters on ``Chat.id``, but the
    ordering is pinned-first. A pinned chat whose id is below the cursor can
    therefore reappear on a later page. Confirm whether callers rely on this.
    """
    stmt = _user_chats_query(user_id, query=query)
    if before_id is not None:
        stmt = stmt.where(Chat.id < before_id)
    page = await db.scalars(stmt.limit(limit))
    return list(page.all())
|
|
|
|
|
|
async def list_archived_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Chat]:
    """Return one page of chats the user has archived.

    The per-user settings row is INNER joined, so only chats that have a
    settings row with ``archived`` true are returned. Ordering matches the
    main chat list: pinned first, most recently pinned first (NULLs last),
    then newest chat id.

    Args:
        db: Async session.
        user_id: User whose archived chats are listed.
        limit: Maximum number of chats returned.
        before_id: When given, only chats with ``Chat.id < before_id`` are returned.

    Returns:
        The matching chats in display order.

    NOTE(review): the ``before_id`` cursor filters on ``Chat.id`` while the
    sort is pinned-first, so a pinned chat below the cursor can repeat on a
    later page.
    """
    setting_for_user = (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id)
    conditions = [ChatMember.user_id == user_id, ChatUserSetting.archived.is_(True)]
    if before_id is not None:
        conditions.append(Chat.id < before_id)
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .join(ChatUserSetting, setting_for_user)
        .where(*conditions)
        .order_by(
            func.coalesce(ChatUserSetting.pinned, False).desc(),
            ChatUserSetting.pinned_at.desc().nullslast(),
            Chat.id.desc(),
        )
        .limit(limit)
    )
    archived = await db.scalars(stmt)
    return list(archived.all())
|
|
|
|
|
|
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
    """Load the chat with primary key ``chat_id``, or ``None`` if absent."""
    lookup = select(Chat).filter_by(id=chat_id)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def get_chat_by_handle(db: AsyncSession, handle: str) -> Chat | None:
    """Load the chat whose ``handle`` equals ``handle`` exactly.

    Returns ``None`` when no chat matches. Raises SQLAlchemy's
    ``MultipleResultsFound`` if more than one does.
    """
    lookup = select(Chat).filter_by(handle=handle)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
    """Return the membership row of ``user_id`` in ``chat_id``, or ``None``."""
    lookup = select(ChatMember).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
    """Return all membership rows of ``chat_id`` in ascending membership-id order."""
    members_stmt = (
        select(ChatMember)
        .where(ChatMember.chat_id == chat_id)
        .order_by(ChatMember.id.asc())
    )
    members = await db.scalars(members_stmt)
    return list(members.all())
|
|
|
|
|
|
async def list_chat_member_user_ids(db: AsyncSession, *, chat_id: int) -> list[int]:
    """Return the user ids of all members of ``chat_id`` (no ordering guaranteed)."""
    ids_stmt = select(ChatMember.user_id).where(ChatMember.chat_id == chat_id)
    user_ids = await db.scalars(ids_stmt)
    return list(user_ids.all())
|
|
|
|
|
|
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
    """Return the ids of every chat ``user_id`` belongs to, ascending."""
    ids_stmt = (
        select(ChatMember.chat_id)
        .where(ChatMember.user_id == user_id)
        .order_by(ChatMember.chat_id.asc())
    )
    chat_ids = await db.scalars(ids_stmt)
    return list(chat_ids.all())
|
|
|
|
|
|
async def find_saved_chat_for_user(db: AsyncSession, *, user_id: int) -> Chat | None:
    """Return the user's saved-messages chat, or ``None`` if they have none.

    A chat qualifies when it has ``is_saved`` true and ``user_id`` is a member.
    If several qualify, the one with the lowest id is returned. Without an
    ORDER BY, ``LIMIT 1`` would let the database pick an arbitrary row, so
    repeated calls could disagree.
    """
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .where(ChatMember.user_id == user_id, Chat.is_saved.is_(True))
        .order_by(Chat.id.asc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()
|
|
|
|
|
|
async def discover_public_chats(
    db: AsyncSession,
    *,
    user_id: int,
    query: str | None = None,
    limit: int = 30,
) -> list[tuple[Chat, bool]]:
    """Find public groups/channels and flag the ones the user already belongs to.

    Saved chats are never returned. An optional ``query`` is matched
    case-insensitively as a substring of title, handle or description.
    ``%`` and ``_`` in the input match literally, not as LIKE wildcards.

    Args:
        db: Async session.
        user_id: User whose membership is checked for each result.
        query: Optional free-text filter; surrounding whitespace is ignored.
        limit: Maximum number of chats returned.

    Returns:
        ``(chat, is_member)`` pairs, newest chat id first. Empty when nothing
        matches.
    """
    stmt = select(Chat).where(
        Chat.is_public.is_(True),
        Chat.type.in_([ChatType.GROUP, ChatType.CHANNEL]),
        Chat.is_saved.is_(False),
    )
    needle = query.strip() if query else ""
    if needle:
        # autoescape=True stops user-supplied "%"/"_" from acting as wildcards.
        stmt = stmt.where(
            or_(
                Chat.title.icontains(needle, autoescape=True),
                Chat.handle.icontains(needle, autoescape=True),
                Chat.description.icontains(needle, autoescape=True),
            )
        )
    stmt = stmt.order_by(Chat.id.desc()).limit(limit)
    chats = list((await db.execute(stmt)).scalars().all())
    if not chats:
        return []
    # One extra query resolves membership for the whole page at once.
    membership_stmt = select(ChatMember.chat_id).where(
        ChatMember.user_id == user_id,
        ChatMember.chat_id.in_([chat.id for chat in chats]),
    )
    member_chat_ids = set((await db.execute(membership_stmt)).scalars().all())
    return [(chat, chat.id in member_chat_ids) for chat in chats]
|
|
|
|
|
|
async def find_private_chat_between_users(db: AsyncSession, *, user_a_id: int, user_b_id: int) -> Chat | None:
    """Return a PRIVATE chat in which both users are members, or ``None``.

    ChatMember is joined twice under separate aliases, one per user.

    NOTE(review): nothing forbids both aliases from matching the same
    membership row. If ``user_a_id == user_b_id``, any private chat of that
    user matches. There is also no ORDER BY, so with several matches the row
    picked by ``LIMIT 1`` is arbitrary.
    """
    member_a = aliased(ChatMember)
    member_b = aliased(ChatMember)
    stmt = select(Chat)
    stmt = stmt.join(member_a, member_a.chat_id == Chat.id)
    stmt = stmt.join(member_b, member_b.chat_id == Chat.id)
    stmt = stmt.where(Chat.type == ChatType.PRIVATE)
    stmt = stmt.where(member_a.user_id == user_a_id)
    stmt = stmt.where(member_b.user_id == user_b_id)
    match = await db.execute(stmt.limit(1))
    return match.scalar_one_or_none()
|
|
|
|
|
|
async def get_private_counterpart_user_id(db: AsyncSession, *, chat_id: int, user_id: int) -> int | None:
    """Return the id of some member of ``chat_id`` other than ``user_id``.

    Returns ``None`` if ``user_id`` is the only member. The chat type is not
    checked. With several other members, which one is returned is unspecified
    (``LIMIT 1`` without ORDER BY).
    """
    other_member = (
        select(ChatMember.user_id)
        .where(ChatMember.chat_id == chat_id)
        .where(ChatMember.user_id != user_id)
        .limit(1)
    )
    row = await db.execute(other_member)
    return row.scalar_one_or_none()
|
|
|
|
|
|
async def get_unread_count_for_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> int:
    """Count messages in ``chat_id`` that ``user_id`` has not read yet.

    A message counts when all of these hold:
    - it was not sent by the user;
    - it is not hidden for the user (no matching ``MessageHidden`` row);
    - its id is greater than the user's ``last_read_message_id`` receipt
      (0 when no receipt exists).

    If ``sender_id`` can be NULL, such messages are also excluded, because
    ``NULL != user_id`` is not true in SQL.
    """
    read_cursor = (
        select(MessageReceipt.last_read_message_id)
        .where(MessageReceipt.chat_id == chat_id, MessageReceipt.user_id == user_id)
        .limit(1)
        .scalar_subquery()
    )
    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)
    unread_filters = (
        Message.chat_id == chat_id,
        Message.sender_id != user_id,
        MessageHidden.id.is_(None),  # anti-join: keep only messages with no hidden row
        Message.id > func.coalesce(read_cursor, 0),
    )
    count_stmt = select(func.count(Message.id)).outerjoin(MessageHidden, hidden_for_user).where(*unread_filters)
    counted = (await db.execute(count_stmt)).scalar_one()
    return int(counted or 0)
|
|
|
|
|
|
async def get_unread_mentions_count_for_chat(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    username: str | None,
) -> int:
    """Count unread messages in ``chat_id`` whose text contains ``@username``.

    Unread uses the same rules as ``get_unread_count_for_chat``: not sent by
    the user, not hidden for the user, and id above the user's read receipt.
    Matching is case-insensitive.

    Args:
        db: Async session.
        chat_id: Chat to inspect.
        user_id: Reader whose receipt and hidden messages apply.
        username: Reader's username. Blank or ``None`` yields 0 without querying.

    Returns:
        Number of matching unread messages.

    NOTE(review): this is a substring match, so ``@ann`` also counts messages
    mentioning ``@anna``. Confirm whether a word-boundary match is wanted.
    """
    normalized_username = (username or "").strip().lower()
    if not normalized_username:
        return 0
    last_read_subquery = (
        select(MessageReceipt.last_read_message_id)
        .where(MessageReceipt.chat_id == chat_id, MessageReceipt.user_id == user_id)
        .limit(1)
        .scalar_subquery()
    )
    mention = "@" + normalized_username
    stmt = (
        select(func.count(Message.id))
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            Message.chat_id == chat_id,
            Message.sender_id != user_id,
            MessageHidden.id.is_(None),
            Message.id > func.coalesce(last_read_subquery, 0),
            Message.text.is_not(None),
            # autoescape=True: "_" (common in usernames) and "%" must match
            # literally. As a raw LIKE pattern, "@john_doe" would also match
            # "@johnxdoe".
            func.lower(Message.text).contains(mention, autoescape=True),
        )
    )
    result = await db.execute(stmt)
    return int(result.scalar_one() or 0)
|
|
|
|
|
|
async def get_last_visible_message_for_user(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
) -> Message | None:
    """Return the highest-id message in ``chat_id`` not hidden for ``user_id``.

    Returns ``None`` when the chat has no visible messages.
    """
    hidden_for_user = (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id)
    newest_visible = (
        select(Message)
        .outerjoin(MessageHidden, hidden_for_user)
        .where(Message.chat_id == chat_id)
        .where(MessageHidden.id.is_(None))
        .order_by(Message.id.desc())
        .limit(1)
    )
    row = await db.execute(newest_visible)
    return row.scalar_one_or_none()
|
|
|
|
|
|
async def get_chat_notification_setting(
    db: AsyncSession, *, chat_id: int, user_id: int
) -> ChatNotificationSetting | None:
    """Return the notification settings row for (``chat_id``, ``user_id``), if any."""
    lookup = select(ChatNotificationSetting).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def upsert_chat_notification_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    muted: bool,
) -> ChatNotificationSetting:
    """Set the mute flag for a user in a chat, creating the row if missing.

    Returns:
        The updated or newly created setting, flushed but not committed.
    """
    record = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    if not record:
        record = ChatNotificationSetting(chat_id=chat_id, user_id=user_id, muted=muted)
        db.add(record)
    else:
        record.muted = muted
    await db.flush()
    return record
|
|
|
|
|
|
async def is_chat_muted_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return True only if a notification setting exists and has ``muted`` set."""
    setting = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    if not setting:
        return False
    return bool(setting.muted)
|
|
|
|
|
|
async def get_chat_user_setting(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatUserSetting | None:
    """Return the per-user chat settings row (archive/pin state), or ``None``."""
    lookup = select(ChatUserSetting).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def upsert_chat_archived_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    archived: bool,
) -> ChatUserSetting:
    """Set the archived flag for a user in a chat, creating the row if missing.

    A newly created row only receives ``archived``; other columns keep their
    model/database defaults.

    Returns:
        The updated or newly created setting, flushed but not committed.
    """
    record = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    if not record:
        record = ChatUserSetting(chat_id=chat_id, user_id=user_id, archived=archived)
        db.add(record)
    else:
        record.archived = archived
    await db.flush()
    return record
|
|
|
|
|
|
async def upsert_chat_pinned_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    pinned: bool,
) -> ChatUserSetting:
    """Pin or unpin a chat for a user, creating the settings row if needed.

    Pinning sets ``pinned_at`` to the database's ``now()``, so re-pinning an
    already pinned chat bumps its timestamp. Unpinning clears ``pinned_at``.
    A newly created row is explicitly not archived.

    Returns:
        The updated or newly created setting, flushed but not committed, with
        ``pinned_at`` loaded.
    """
    # Server-side timestamp when pinning; NULL when unpinning.
    pinned_at = func.now() if pinned else None
    setting = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    if setting:
        setting.pinned = pinned
        setting.pinned_at = pinned_at
    else:
        setting = ChatUserSetting(
            chat_id=chat_id,
            user_id=user_id,
            archived=False,
            pinned=pinned,
            pinned_at=pinned_at,
        )
        db.add(setting)
    await db.flush()
    if pinned:
        # SQLAlchemy expires an attribute that was assigned a SQL expression
        # after the flush. With AsyncSession, a later plain attribute access
        # would need implicit IO and fail, so load the value explicitly here.
        await db.refresh(setting, attribute_names=["pinned_at"])
    return setting
|
|
|
|
|
|
async def create_chat_invite_link(
    db: AsyncSession,
    *,
    chat_id: int,
    creator_user_id: int,
    token: str,
) -> ChatInviteLink:
    """Create an active invite link for ``chat_id`` identified by ``token``.

    Returns:
        The new link, flushed but not committed.
    """
    invite = ChatInviteLink(
        chat_id=chat_id,
        creator_user_id=creator_user_id,
        token=token,
        is_active=True,
    )
    db.add(invite)
    await db.flush()
    return invite
|
|
|
|
|
|
async def get_active_chat_invite_by_token(db: AsyncSession, *, token: str) -> ChatInviteLink | None:
    """Return an active invite link with this ``token``, or ``None``."""
    active_invite = (
        select(ChatInviteLink)
        .where(ChatInviteLink.token == token)
        .where(ChatInviteLink.is_active.is_(True))
        .limit(1)
    )
    found = await db.execute(active_invite)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def get_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatBan | None:
    """Return the ban record for ``user_id`` in ``chat_id``, or ``None``."""
    lookup = select(ChatBan).filter_by(chat_id=chat_id, user_id=user_id)
    found = await db.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
|
|
async def is_user_banned_in_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return True if a ban row exists for ``user_id`` in ``chat_id``.

    Only the ban id is selected, so no full ORM object is loaded.
    """
    probe = (
        select(ChatBan.id)
        .where(ChatBan.chat_id == chat_id)
        .where(ChatBan.user_id == user_id)
        .limit(1)
    )
    ban_id = (await db.execute(probe)).scalar_one_or_none()
    return ban_id is not None
|
|
|
|
|
|
async def upsert_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int, banned_by_user_id: int) -> ChatBan:
    """Ban ``user_id`` from ``chat_id``, or update who issued an existing ban.

    Returns:
        The updated or newly created ban, flushed but not committed.
    """
    ban_record = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if not ban_record:
        ban_record = ChatBan(chat_id=chat_id, user_id=user_id, banned_by_user_id=banned_by_user_id)
        db.add(ban_record)
    else:
        ban_record.banned_by_user_id = banned_by_user_id
    await db.flush()
    return ban_record
|
|
|
|
|
|
async def remove_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
    """Delete the ban of ``user_id`` in ``chat_id`` if one exists.

    This is a no-op when there is no ban. No flush is issued; the DELETE is
    sent on the session's next flush or commit.
    """
    existing_ban = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if not existing_ban:
        return
    await db.delete(existing_ban)
|