from sqlalchemy import Select, String, func, or_, select
from sqlalchemy.orm import aliased
from sqlalchemy.ext.asyncio import AsyncSession

from app.chats.models import (
    Chat,
    ChatBan,
    ChatInviteLink,
    ChatMember,
    ChatMemberRole,
    ChatNotificationSetting,
    ChatType,
    ChatUserSetting,
)
from app.messages.models import Message, MessageHidden, MessageReceipt

# Escape character for LIKE/ILIKE patterns built from user-supplied text
# (same character SQLAlchemy's own ``autoescape`` option uses).
_LIKE_ESCAPE = "/"


def _contains_pattern(term: str) -> str:
    """Return a LIKE pattern that matches values containing *term* literally.

    ``%`` and ``_`` are LIKE wildcards; without escaping, a search for ``a_b``
    would also match ``axb``. The escape character itself is escaped first so
    it cannot combine with the escapes added afterwards. The pattern must be
    used with ``escape=_LIKE_ESCAPE``.
    """
    escaped = (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
    """Add a chat of *chat_type* with an optional *title* and flush it (no commit)."""
    chat = Chat(type=chat_type, title=title)
    db.add(chat)
    await db.flush()
    return chat


async def create_chat_with_meta(
    db: AsyncSession,
    *,
    chat_type: ChatType,
    title: str | None,
    handle: str | None,
    description: str | None,
    is_public: bool,
    is_saved: bool = False,
) -> Chat:
    """Add a chat with its handle/description/visibility metadata and flush it (no commit)."""
    chat = Chat(
        type=chat_type,
        title=title,
        handle=handle,
        description=description,
        is_public=is_public,
        is_saved=is_saved,
    )
    db.add(chat)
    await db.flush()
    return chat


async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
    """Add *user_id* to *chat_id* with *role* and flush the new membership row."""
    member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
    db.add(member)
    await db.flush()
    return member


async def delete_chat_member(db: AsyncSession, member: ChatMember) -> None:
    """Mark *member* for deletion in the session (no flush is issued here)."""
    await db.delete(member)


async def count_chat_members(db: AsyncSession, *, chat_id: int) -> int:
    """Return the number of membership rows for *chat_id*."""
    result = await db.execute(select(func.count(ChatMember.id)).where(ChatMember.chat_id == chat_id))
    return int(result.scalar_one())


def _user_chats_query(user_id: int, query: str | None = None) -> Select[tuple[Chat]]:
    """Build the SELECT for the chats shown in a user's main chat list.

    Includes chats *user_id* is a member of, excluding chats the user archived
    (a missing ChatUserSetting row counts as not archived) and saved chats.
    When *query* contains non-blank text, only chats whose title or type name
    contains it (case-insensitive, wildcards matched literally) are kept.
    Order: pinned chats first, then by most recent ``pinned_at`` (NULLs last),
    then by descending chat id.
    """
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .outerjoin(
            ChatUserSetting,
            (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id),
        )
        .where(
            ChatMember.user_id == user_id,
            func.coalesce(ChatUserSetting.archived, False).is_(False),
            Chat.is_saved.is_(False),
        )
    )
    term = query.strip() if query else ""
    if term:
        pattern = _contains_pattern(term)
        stmt = stmt.where(
            or_(
                Chat.title.ilike(pattern, escape=_LIKE_ESCAPE),
                Chat.type.cast(String).ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
    return stmt.order_by(
        func.coalesce(ChatUserSetting.pinned, False).desc(),
        ChatUserSetting.pinned_at.desc().nullslast(),
        Chat.id.desc(),
    )


async def list_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
    query: str | None = None,
) -> list[Chat]:
    """Return up to *limit* chats from the user's main list (see ``_user_chats_query``).

    *before_id*, when given, keeps only chats with ``Chat.id < before_id``.
    """
    # NOTE(review): the cursor filters on Chat.id while the ordering puts pinned
    # chats first, so pinned chats with id < before_id reappear on later pages.
    # Confirm this is intended.
    stmt = _user_chats_query(user_id, query=query)
    if before_id is not None:
        stmt = stmt.where(Chat.id < before_id)
    result = await db.execute(stmt.limit(limit))
    return list(result.scalars().all())


async def list_archived_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
) -> list[Chat]:
    """Return up to *limit* chats the user is a member of and has archived.

    Uses the same ordering as the main list; *before_id* keeps only chats with
    ``Chat.id < before_id``.
    """
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .join(
            ChatUserSetting,
            (ChatUserSetting.chat_id == Chat.id) & (ChatUserSetting.user_id == user_id),
        )
        .where(ChatMember.user_id == user_id, ChatUserSetting.archived.is_(True))
        .order_by(
            func.coalesce(ChatUserSetting.pinned, False).desc(),
            ChatUserSetting.pinned_at.desc().nullslast(),
            Chat.id.desc(),
        )
    )
    if before_id is not None:
        stmt = stmt.where(Chat.id < before_id)
    result = await db.execute(stmt.limit(limit))
    return list(result.scalars().all())


async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
    """Return the chat with *chat_id*, or None."""
    result = await db.execute(select(Chat).where(Chat.id == chat_id))
    return result.scalar_one_or_none()


async def get_chat_by_handle(db: AsyncSession, handle: str) -> Chat | None:
    """Return the chat whose handle equals *handle* exactly, or None."""
    result = await db.execute(select(Chat).where(Chat.handle == handle))
    return result.scalar_one_or_none()


async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
    """Return the membership row of *user_id* in *chat_id*, or None."""
    result = await db.execute(
        select(ChatMember).where(
            ChatMember.chat_id == chat_id,
            ChatMember.user_id == user_id,
        )
    )
    return result.scalar_one_or_none()


async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
    """Return all membership rows of *chat_id*, ordered by membership id ascending."""
    result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
    return list(result.scalars().all())


async def list_chat_member_user_ids(db: AsyncSession, *, chat_id: int) -> list[int]:
    """Return the user ids of all members of *chat_id* (no particular order)."""
    result = await db.execute(select(ChatMember.user_id).where(ChatMember.chat_id == chat_id))
    return list(result.scalars().all())


async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
    """Return the ids of every chat *user_id* is a member of, ascending."""
    result = await db.execute(
        select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
    )
    return list(result.scalars().all())


async def find_saved_chat_for_user(db: AsyncSession, *, user_id: int) -> Chat | None:
    """Return the lowest-id saved chat *user_id* is a member of, or None."""
    stmt = (
        select(Chat)
        .join(ChatMember, ChatMember.chat_id == Chat.id)
        .where(ChatMember.user_id == user_id, Chat.is_saved.is_(True))
        .order_by(Chat.id.asc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def discover_public_chats(
    db: AsyncSession,
    *,
    user_id: int,
    query: str | None = None,
    limit: int = 30,
) -> list[tuple[Chat, bool]]:
    """Return public, non-saved group/channel chats paired with a membership flag.

    When *query* contains non-blank text, only chats whose title, handle or
    description contains it (case-insensitive, wildcards matched literally)
    are kept. Results are newest chat id first, at most *limit* rows. Each
    item is ``(chat, user_is_member)``.
    """
    stmt = select(Chat).where(
        Chat.is_public.is_(True),
        Chat.type.in_([ChatType.GROUP, ChatType.CHANNEL]),
        Chat.is_saved.is_(False),
    )
    term = query.strip() if query else ""
    if term:
        pattern = _contains_pattern(term)
        stmt = stmt.where(
            or_(
                Chat.title.ilike(pattern, escape=_LIKE_ESCAPE),
                Chat.handle.ilike(pattern, escape=_LIKE_ESCAPE),
                Chat.description.ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
    stmt = stmt.order_by(Chat.id.desc()).limit(limit)
    chats = list((await db.execute(stmt)).scalars().all())
    if not chats:
        return []

    # One extra query resolves membership for the whole page instead of one per chat.
    chat_ids = [chat.id for chat in chats]
    membership_stmt = select(ChatMember.chat_id).where(
        ChatMember.user_id == user_id,
        ChatMember.chat_id.in_(chat_ids),
    )
    member_chat_ids = set((await db.execute(membership_stmt)).scalars().all())
    return [(chat, chat.id in member_chat_ids) for chat in chats]


async def find_private_chat_between_users(db: AsyncSession, *, user_a_id: int, user_b_id: int) -> Chat | None:
    """Return a PRIVATE chat that has both *user_a_id* and *user_b_id* as members, or None."""
    # NOTE(review): if user_a_id == user_b_id both aliases can match the same
    # membership row, so any private chat containing that user qualifies.
    cm_a = aliased(ChatMember)
    cm_b = aliased(ChatMember)
    stmt = (
        select(Chat)
        .join(cm_a, cm_a.chat_id == Chat.id)
        .join(cm_b, cm_b.chat_id == Chat.id)
        .where(
            Chat.type == ChatType.PRIVATE,
            cm_a.user_id == user_a_id,
            cm_b.user_id == user_b_id,
        )
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def get_private_counterpart_user_id(db: AsyncSession, *, chat_id: int, user_id: int) -> int | None:
    """Return the user id of some member of *chat_id* other than *user_id*, or None.

    The chat type is not checked; for a chat with several other members, an
    arbitrary one is returned.
    """
    stmt = select(ChatMember.user_id).where(ChatMember.chat_id == chat_id, ChatMember.user_id != user_id).limit(1)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


def _last_read_message_id_subquery(*, chat_id: int, user_id: int):
    """Scalar subquery yielding the user's ``last_read_message_id`` in the chat (NULL if no receipt)."""
    return (
        select(MessageReceipt.last_read_message_id)
        .where(MessageReceipt.chat_id == chat_id, MessageReceipt.user_id == user_id)
        .limit(1)
        .scalar_subquery()
    )


async def get_unread_count_for_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> int:
    """Count unread messages in *chat_id* for *user_id*.

    A message counts when it was not sent by the user, is not hidden for the
    user, and has an id greater than the user's last-read message id (0 when
    the user has no receipt).
    """
    # NOTE(review): if Message.sender_id can be NULL, `!=` also excludes those rows.
    last_read = _last_read_message_id_subquery(chat_id=chat_id, user_id=user_id)
    stmt = (
        select(func.count(Message.id))
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            Message.chat_id == chat_id,
            Message.sender_id != user_id,
            MessageHidden.id.is_(None),  # anti-join: no hide row for this user
            Message.id > func.coalesce(last_read, 0),
        )
    )
    result = await db.execute(stmt)
    return int(result.scalar_one() or 0)


async def get_unread_mentions_count_for_chat(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    username: str | None,
) -> int:
    """Count unread messages in *chat_id* whose text contains ``@<username>``.

    The same unread rules as ``get_unread_count_for_chat`` apply. Matching is
    case-insensitive and treats the username literally (``_`` is not a
    wildcard). Returns 0 when *username* is None or blank.
    """
    normalized_username = (username or "").strip().lower()
    if not normalized_username:
        return 0

    # NOTE(review): this is a substring match, so "@bob" also matches inside "@bobby".
    mention_pattern = _contains_pattern(f"@{normalized_username}")
    last_read = _last_read_message_id_subquery(chat_id=chat_id, user_id=user_id)
    stmt = (
        select(func.count(Message.id))
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            Message.chat_id == chat_id,
            Message.sender_id != user_id,
            MessageHidden.id.is_(None),
            Message.id > func.coalesce(last_read, 0),
            Message.text.is_not(None),
            func.lower(Message.text).like(mention_pattern, escape=_LIKE_ESCAPE),
        )
    )
    result = await db.execute(stmt)
    return int(result.scalar_one() or 0)


async def get_last_visible_message_for_user(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
) -> Message | None:
    """Return the highest-id message in *chat_id* not hidden for *user_id*, or None."""
    stmt = (
        select(Message)
        .outerjoin(
            MessageHidden,
            (MessageHidden.message_id == Message.id) & (MessageHidden.user_id == user_id),
        )
        .where(
            Message.chat_id == chat_id,
            MessageHidden.id.is_(None),
        )
        .order_by(Message.id.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def get_chat_notification_setting(
    db: AsyncSession, *, chat_id: int, user_id: int
) -> ChatNotificationSetting | None:
    """Return the notification setting of *user_id* for *chat_id*, or None."""
    result = await db.execute(
        select(ChatNotificationSetting).where(
            ChatNotificationSetting.chat_id == chat_id,
            ChatNotificationSetting.user_id == user_id,
        )
    )
    return result.scalar_one_or_none()


async def upsert_chat_notification_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    muted: bool,
) -> ChatNotificationSetting:
    """Set the *muted* flag for *user_id* in *chat_id*, creating the row if missing; flushes."""
    setting = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    if setting:
        setting.muted = muted
        await db.flush()
        return setting

    setting = ChatNotificationSetting(chat_id=chat_id, user_id=user_id, muted=muted)
    db.add(setting)
    await db.flush()
    return setting


async def is_chat_muted_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return True only if a notification setting exists for the user and is muted."""
    setting = await get_chat_notification_setting(db, chat_id=chat_id, user_id=user_id)
    return bool(setting and setting.muted)


async def get_chat_user_setting(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatUserSetting | None:
    """Return the per-user chat setting (archive/pin state) for *chat_id*, or None."""
    result = await db.execute(
        select(ChatUserSetting).where(ChatUserSetting.chat_id == chat_id, ChatUserSetting.user_id == user_id)
    )
    return result.scalar_one_or_none()


async def upsert_chat_archived_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    archived: bool,
) -> ChatUserSetting:
    """Set the *archived* flag for *user_id* in *chat_id*, creating the row if missing; flushes."""
    setting = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    if setting:
        setting.archived = archived
        await db.flush()
        return setting

    setting = ChatUserSetting(chat_id=chat_id, user_id=user_id, archived=archived)
    db.add(setting)
    await db.flush()
    return setting


async def upsert_chat_pinned_setting(
    db: AsyncSession,
    *,
    chat_id: int,
    user_id: int,
    pinned: bool,
) -> ChatUserSetting:
    """Pin or unpin *chat_id* for *user_id*, creating the setting row if missing.

    Pinning stamps ``pinned_at`` with the database's ``now()`` (re-pinning an
    already pinned chat refreshes the timestamp). Unpinning clears it. A newly
    created row starts unarchived. The returned object has ``pinned_at`` loaded.
    """
    setting = await get_chat_user_setting(db, chat_id=chat_id, user_id=user_id)
    pinned_at = func.now() if pinned else None
    if setting:
        setting.pinned = pinned
        setting.pinned_at = pinned_at
    else:
        setting = ChatUserSetting(
            chat_id=chat_id,
            user_id=user_id,
            archived=False,
            pinned=pinned,
            pinned_at=pinned_at,
        )
        db.add(setting)
    await db.flush()
    if pinned:
        # A SQL expression assigned to an attribute is expired after flush.
        # Reading it later would trigger implicit IO, which AsyncSession cannot
        # do, so load the database-generated value explicitly now.
        await db.refresh(setting, attribute_names=["pinned_at"])
    return setting


async def create_chat_invite_link(
    db: AsyncSession,
    *,
    chat_id: int,
    creator_user_id: int,
    token: str,
) -> ChatInviteLink:
    """Add an active invite link with *token* for *chat_id* and flush it."""
    link = ChatInviteLink(chat_id=chat_id, creator_user_id=creator_user_id, token=token, is_active=True)
    db.add(link)
    await db.flush()
    return link


async def get_active_chat_invite_by_token(db: AsyncSession, *, token: str) -> ChatInviteLink | None:
    """Return an active invite link whose token equals *token*, or None."""
    result = await db.execute(
        select(ChatInviteLink)
        .where(ChatInviteLink.token == token, ChatInviteLink.is_active.is_(True))
        .limit(1)
    )
    return result.scalar_one_or_none()


async def get_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatBan | None:
    """Return the ban of *user_id* in *chat_id*, or None."""
    result = await db.execute(select(ChatBan).where(ChatBan.chat_id == chat_id, ChatBan.user_id == user_id))
    return result.scalar_one_or_none()


async def is_user_banned_in_chat(db: AsyncSession, *, chat_id: int, user_id: int) -> bool:
    """Return True if a ban row exists for *user_id* in *chat_id*."""
    result = await db.execute(select(ChatBan.id).where(ChatBan.chat_id == chat_id, ChatBan.user_id == user_id).limit(1))
    return result.scalar_one_or_none() is not None


async def upsert_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int, banned_by_user_id: int) -> ChatBan:
    """Ban *user_id* in *chat_id*, or update who issued an existing ban; flushes."""
    existing = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if existing:
        existing.banned_by_user_id = banned_by_user_id
        await db.flush()
        return existing

    ban = ChatBan(chat_id=chat_id, user_id=user_id, banned_by_user_id=banned_by_user_id)
    db.add(ban)
    await db.flush()
    return ban


async def remove_chat_ban(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
    """Mark the ban of *user_id* in *chat_id* for deletion, if one exists (no flush)."""
    ban = await get_chat_ban(db, chat_id=chat_id, user_id=user_id)
    if ban:
        await db.delete(ban)


async def list_chat_bans(db: AsyncSession, *, chat_id: int) -> list[ChatBan]:
    """Return all bans in *chat_id*, newest first (ties broken by descending id)."""
    result = await db.execute(
        select(ChatBan).where(ChatBan.chat_id == chat_id).order_by(ChatBan.created_at.desc(), ChatBan.id.desc())
    )
    return list(result.scalars().all())