from sqlalchemy import Select, String, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased

from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType

# Escape character passed to ILIKE so user-typed wildcards are matched literally.
_LIKE_ESCAPE = "\\"


def _escape_like(term: str) -> str:
    """Return *term* with LIKE metacharacters (``\\``, ``%``, ``_``) escaped."""
    # The escape character itself must be doubled first, otherwise the
    # backslashes inserted for ``%`` and ``_`` would be escaped again.
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )


async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
    """Add a new ``Chat`` to the session and flush it.

    The session is flushed but not committed here.
    """
    chat = Chat(type=chat_type, title=title)
    db.add(chat)
    await db.flush()
    return chat


async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
    """Add a ``ChatMember`` row linking *user_id* to *chat_id* and flush it.

    The session is flushed but not committed here.
    """
    member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
    db.add(member)
    await db.flush()
    return member


async def delete_chat_member(db: AsyncSession, member: ChatMember) -> None:
    """Mark *member* for deletion in the session.

    Unlike the create helpers, this does not flush.
    """
    await db.delete(member)


async def count_chat_members(db: AsyncSession, *, chat_id: int) -> int:
    """Return the number of ``ChatMember`` rows belonging to *chat_id*."""
    result = await db.execute(select(func.count(ChatMember.id)).where(ChatMember.chat_id == chat_id))
    return int(result.scalar_one())


def _user_chats_query(user_id: int, query: str | None = None) -> Select[tuple[Chat]]:
    """Build the base statement selecting chats that *user_id* is a member of.

    When *query* contains non-whitespace text, results are narrowed to chats
    whose title or type (cast to a string) contains that text,
    case-insensitively. The text is matched literally: ``%`` and ``_`` typed
    by the user are escaped rather than treated as wildcards.

    Results are ordered by ``Chat.id`` descending.
    """
    stmt = select(Chat).join(ChatMember, ChatMember.chat_id == Chat.id).where(ChatMember.user_id == user_id)

    term = query.strip() if query else ""
    if term:
        pattern = f"%{_escape_like(term)}%"
        stmt = stmt.where(
            or_(
                Chat.title.ilike(pattern, escape=_LIKE_ESCAPE),
                Chat.type.cast(String).ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
    return stmt.order_by(Chat.id.desc())


async def list_user_chats(
    db: AsyncSession,
    *,
    user_id: int,
    limit: int = 50,
    before_id: int | None = None,
    query: str | None = None,
) -> list[Chat]:
    """Return up to *limit* chats of *user_id*, highest ``Chat.id`` first.

    Args:
        db: Session used to execute the statement.
        user_id: Only chats this user is a member of are returned.
        limit: Maximum number of chats to return.
        before_id: When given, only chats with ``Chat.id < before_id`` are
            returned (keyset pagination cursor).
        query: Optional search text, see ``_user_chats_query``.
    """
    stmt = _user_chats_query(user_id, query=query)
    if before_id is not None:
        stmt = stmt.where(Chat.id < before_id)
    # Apply the limit last so the statement reads in the order SQL evaluates it.
    result = await db.execute(stmt.limit(limit))
    return list(result.scalars().all())


async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
    """Return the chat with primary key *chat_id*, or ``None`` if absent."""
    result = await db.execute(select(Chat).where(Chat.id == chat_id))
    return result.scalar_one_or_none()


async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
    """Return the membership of *user_id* in *chat_id*, or ``None`` if absent."""
    result = await db.execute(
        select(ChatMember).where(
            ChatMember.chat_id == chat_id,
            ChatMember.user_id == user_id,
        )
    )
    return result.scalar_one_or_none()


async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
    """Return all memberships of *chat_id*, ordered by ``ChatMember.id`` ascending."""
    result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
    return list(result.scalars().all())


async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
    """Return the ids of all chats *user_id* belongs to, in ascending order."""
    result = await db.execute(
        select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
    )
    return list(result.scalars().all())


async def find_private_chat_between_users(db: AsyncSession, *, user_a_id: int, user_b_id: int) -> Chat | None:
    """Return a ``PRIVATE`` chat that has both users as members, or ``None``.

    If several chats match, the one with the lowest ``Chat.id`` is returned so
    the result is deterministic.

    When both ids are equal, only a chat whose sole member is that user
    qualifies.
    """
    cm_a = aliased(ChatMember)
    cm_b = aliased(ChatMember)
    stmt = (
        select(Chat)
        .join(cm_a, cm_a.chat_id == Chat.id)
        .join(cm_b, cm_b.chat_id == Chat.id)
        .where(
            Chat.type == ChatType.PRIVATE,
            cm_a.user_id == user_a_id,
            cm_b.user_id == user_b_id,
        )
    )

    if user_a_id == user_b_id:
        # With identical ids both aliases can bind to the same membership row,
        # so without this guard any private chat of the user would match.
        # NOTE(review): assumes a self-chat is stored with a single membership
        # row per (chat, user) - confirm against the ChatMember constraints.
        other = aliased(ChatMember)
        has_other_member = (
            select(other.id)
            .where(other.chat_id == Chat.id, other.user_id != user_a_id)
            .exists()
        )
        stmt = stmt.where(~has_other_member)

    # Explicit ordering keeps LIMIT 1 deterministic if duplicates exist.
    stmt = stmt.order_by(Chat.id.asc()).limit(1)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()