first commit
This commit is contained in:
0
app/chats/__init__.py
Normal file
0
app/chats/__init__.py
Normal file
50
app/chats/models.py
Normal file
50
app/chats/models.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.messages.models import Message
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
class ChatType(str, Enum):
|
||||
PRIVATE = "private"
|
||||
GROUP = "group"
|
||||
CHANNEL = "channel"
|
||||
|
||||
|
||||
class ChatMemberRole(str, Enum):
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chats"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
type: Mapped[ChatType] = mapped_column(SAEnum(ChatType), nullable=False, index=True)
|
||||
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
members: Mapped[list["ChatMember"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class ChatMember(Base):
|
||||
__tablename__ = "chat_members"
|
||||
__table_args__ = (UniqueConstraint("chat_id", "user_id", name="uq_chat_members_chat_id_user_id"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
role: Mapped[ChatMemberRole] = mapped_column(SAEnum(ChatMemberRole), nullable=False, default=ChatMemberRole.MEMBER)
|
||||
joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
chat: Mapped["Chat"] = relationship(back_populates="members")
|
||||
user: Mapped["User"] = relationship(back_populates="memberships")
|
||||
62
app/chats/repository.py
Normal file
62
app/chats/repository.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType
|
||||
|
||||
|
||||
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
|
||||
chat = Chat(type=chat_type, title=title)
|
||||
db.add(chat)
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
|
||||
member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
|
||||
db.add(member)
|
||||
await db.flush()
|
||||
return member
|
||||
|
||||
|
||||
def _user_chats_query(user_id: int) -> Select[tuple[Chat]]:
|
||||
return (
|
||||
select(Chat)
|
||||
.join(ChatMember, ChatMember.chat_id == Chat.id)
|
||||
.where(ChatMember.user_id == user_id)
|
||||
.order_by(Chat.id.desc())
|
||||
)
|
||||
|
||||
|
||||
async def list_user_chats(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||
query = _user_chats_query(user_id).limit(limit)
|
||||
if before_id is not None:
|
||||
query = query.where(Chat.id < before_id)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
|
||||
result = await db.execute(select(Chat).where(Chat.id == chat_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
result = await db.execute(
|
||||
select(ChatMember).where(
|
||||
ChatMember.chat_id == chat_id,
|
||||
ChatMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
|
||||
result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
|
||||
result = await db.execute(
|
||||
select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
45
app/chats/router.py
Normal file
45
app/chats/router.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.chats.schemas import ChatCreateRequest, ChatDetailRead, ChatRead
|
||||
from app.chats.service import create_chat_for_user, get_chat_for_user, get_chats_for_user
|
||||
from app.database.session import get_db
|
||||
from app.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ChatRead])
|
||||
async def list_chats(
|
||||
limit: int = 50,
|
||||
before_id: int | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[ChatRead]:
|
||||
return await get_chats_for_user(db, user_id=current_user.id, limit=limit, before_id=before_id)
|
||||
|
||||
|
||||
@router.post("", response_model=ChatRead)
|
||||
async def create_chat(
|
||||
payload: ChatCreateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ChatRead:
|
||||
return await create_chat_for_user(db, creator_id=current_user.id, payload=payload)
|
||||
|
||||
|
||||
@router.get("/{chat_id}", response_model=ChatDetailRead)
|
||||
async def get_chat(
|
||||
chat_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ChatDetailRead:
|
||||
chat, members = await get_chat_for_user(db, chat_id=chat_id, user_id=current_user.id)
|
||||
return ChatDetailRead(
|
||||
id=chat.id,
|
||||
type=chat.type,
|
||||
title=chat.title,
|
||||
created_at=chat.created_at,
|
||||
members=members,
|
||||
)
|
||||
33
app/chats/schemas.py
Normal file
33
app/chats/schemas.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.chats.models import ChatMemberRole, ChatType
|
||||
|
||||
|
||||
class ChatRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
type: ChatType
|
||||
title: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ChatMemberRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
user_id: int
|
||||
role: ChatMemberRole
|
||||
joined_at: datetime
|
||||
|
||||
|
||||
class ChatDetailRead(ChatRead):
|
||||
members: list[ChatMemberRead]
|
||||
|
||||
|
||||
class ChatCreateRequest(BaseModel):
|
||||
type: ChatType
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
member_ids: list[int] = Field(default_factory=list)
|
||||
62
app/chats/service.py
Normal file
62
app/chats/service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats import repository
|
||||
from app.chats.models import Chat, ChatMemberRole, ChatType
|
||||
from app.chats.schemas import ChatCreateRequest
|
||||
from app.users.repository import get_user_by_id
|
||||
|
||||
|
||||
async def create_chat_for_user(db: AsyncSession, *, creator_id: int, payload: ChatCreateRequest) -> Chat:
|
||||
member_ids = list(dict.fromkeys(payload.member_ids))
|
||||
member_ids = [member_id for member_id in member_ids if member_id != creator_id]
|
||||
|
||||
if payload.type == ChatType.PRIVATE and len(member_ids) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Private chat requires exactly one target user.",
|
||||
)
|
||||
if payload.type in {ChatType.GROUP, ChatType.CHANNEL} and not payload.title:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Group and channel chats require title.",
|
||||
)
|
||||
|
||||
for member_id in member_ids:
|
||||
user = await get_user_by_id(db, member_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {member_id} not found")
|
||||
|
||||
chat = await repository.create_chat(db, chat_type=payload.type, title=payload.title)
|
||||
await repository.add_chat_member(db, chat_id=chat.id, user_id=creator_id, role=ChatMemberRole.OWNER)
|
||||
|
||||
default_role = ChatMemberRole.MEMBER
|
||||
for member_id in member_ids:
|
||||
await repository.add_chat_member(db, chat_id=chat.id, user_id=member_id, role=default_role)
|
||||
|
||||
await db.commit()
|
||||
return chat
|
||||
|
||||
|
||||
async def get_chats_for_user(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||
safe_limit = max(1, min(limit, 100))
|
||||
return await repository.list_user_chats(db, user_id=user_id, limit=safe_limit, before_id=before_id)
|
||||
|
||||
|
||||
async def get_chat_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> tuple[Chat, list]:
|
||||
chat = await repository.get_chat_by_id(db, chat_id)
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
|
||||
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||
if not membership:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||
|
||||
members = await repository.list_chat_members(db, chat_id=chat_id)
|
||||
return chat, members
|
||||
|
||||
|
||||
async def ensure_chat_membership(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
|
||||
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||
if not membership:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||
Reference in New Issue
Block a user