first commit
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
34
app/auth/models.py
Normal file
34
app/auth/models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
class EmailVerificationToken(Base):
|
||||
__tablename__ = "email_verification_tokens"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="email_verification_tokens")
|
||||
|
||||
|
||||
class PasswordResetToken(Base):
|
||||
__tablename__ = "password_reset_tokens"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="password_reset_tokens")
|
||||
46
app/auth/repository.py
Normal file
46
app/auth/repository.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||
|
||||
|
||||
async def create_email_verification_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
|
||||
db.add(
|
||||
EmailVerificationToken(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def get_email_verification_token(db: AsyncSession, token: str) -> EmailVerificationToken | None:
|
||||
result = await db.execute(select(EmailVerificationToken).where(EmailVerificationToken.token == token))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def delete_email_verification_tokens_for_user(db: AsyncSession, user_id: int) -> None:
|
||||
await db.execute(delete(EmailVerificationToken).where(EmailVerificationToken.user_id == user_id))
|
||||
|
||||
|
||||
async def create_password_reset_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
|
||||
db.add(
|
||||
PasswordResetToken(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def get_password_reset_token(db: AsyncSession, token: str) -> PasswordResetToken | None:
|
||||
result = await db.execute(select(PasswordResetToken).where(PasswordResetToken.token == token))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def delete_password_reset_tokens_for_user(db: AsyncSession, user_id: int) -> None:
|
||||
await db.execute(delete(PasswordResetToken).where(PasswordResetToken.user_id == user_id))
|
||||
81
app/auth/router.py
Normal file
81
app/auth/router.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.schemas import (
|
||||
AuthUserResponse,
|
||||
LoginRequest,
|
||||
MessageResponse,
|
||||
RegisterRequest,
|
||||
RequestPasswordResetRequest,
|
||||
ResendVerificationRequest,
|
||||
ResetPasswordRequest,
|
||||
TokenResponse,
|
||||
VerifyEmailRequest,
|
||||
)
|
||||
from app.auth.service import (
|
||||
get_current_user,
|
||||
get_email_sender,
|
||||
login_user,
|
||||
register_user,
|
||||
request_password_reset,
|
||||
resend_verification_email,
|
||||
reset_password,
|
||||
verify_email,
|
||||
)
|
||||
from app.database.session import get_db
|
||||
from app.email.service import EmailService
|
||||
from app.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
payload: RegisterRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
email_service: EmailService = Depends(get_email_sender),
|
||||
) -> MessageResponse:
|
||||
await register_user(db, payload, email_service)
|
||||
return MessageResponse(message="Registration successful. Verification email sent.")
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
|
||||
return await login_user(db, payload)
|
||||
|
||||
|
||||
@router.post("/verify-email", response_model=MessageResponse)
|
||||
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
||||
await verify_email(db, payload)
|
||||
return MessageResponse(message="Email verified successfully.")
|
||||
|
||||
|
||||
@router.post("/resend-verification", response_model=MessageResponse)
|
||||
async def resend_verification(
|
||||
payload: ResendVerificationRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
email_service: EmailService = Depends(get_email_sender),
|
||||
) -> MessageResponse:
|
||||
await resend_verification_email(db, payload, email_service)
|
||||
return MessageResponse(message="If the account exists, a verification email was sent.")
|
||||
|
||||
|
||||
@router.post("/request-password-reset", response_model=MessageResponse)
|
||||
async def request_password_reset_endpoint(
|
||||
payload: RequestPasswordResetRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
email_service: EmailService = Depends(get_email_sender),
|
||||
) -> MessageResponse:
|
||||
await request_password_reset(db, payload, email_service)
|
||||
return MessageResponse(message="If the account exists, a reset email was sent.")
|
||||
|
||||
|
||||
@router.post("/reset-password", response_model=MessageResponse)
|
||||
async def reset_password_endpoint(payload: ResetPasswordRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
||||
await reset_password(db, payload)
|
||||
return MessageResponse(message="Password reset successfully.")
|
||||
|
||||
|
||||
@router.get("/me", response_model=AuthUserResponse)
|
||||
async def me(current_user: User = Depends(get_current_user)) -> AuthUserResponse:
|
||||
return current_user
|
||||
53
app/auth/schemas.py
Normal file
53
app/auth/schemas.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
username: str = Field(min_length=3, max_length=50)
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class VerifyEmailRequest(BaseModel):
|
||||
token: str = Field(min_length=16, max_length=512)
|
||||
|
||||
|
||||
class ResendVerificationRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class RequestPasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
token: str = Field(min_length=16, max_length=512)
|
||||
new_password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class AuthUserResponse(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
email: EmailStr
|
||||
username: str
|
||||
avatar_url: str | None = None
|
||||
email_verified: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
195
app/auth/service.py
Normal file
195
app/auth/service.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth import repository as auth_repository
|
||||
from app.auth.schemas import (
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
RequestPasswordResetRequest,
|
||||
ResendVerificationRequest,
|
||||
ResetPasswordRequest,
|
||||
TokenResponse,
|
||||
VerifyEmailRequest,
|
||||
)
|
||||
from app.config.settings import settings
|
||||
from app.database.session import get_db
|
||||
from app.email.service import EmailService, get_email_service
|
||||
from app.users.models import User
|
||||
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
|
||||
from app.utils.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
generate_random_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
|
||||
|
||||
|
||||
async def register_user(
|
||||
db: AsyncSession,
|
||||
payload: RegisterRequest,
|
||||
email_service: EmailService,
|
||||
) -> None:
|
||||
existing_email = await get_user_by_email(db, payload.email)
|
||||
if existing_email:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
|
||||
|
||||
existing_username = await get_user_by_username(db, payload.username)
|
||||
if existing_username:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")
|
||||
|
||||
user = await create_user(
|
||||
db,
|
||||
email=payload.email,
|
||||
username=payload.username,
|
||||
password_hash=hash_password(payload.password),
|
||||
)
|
||||
|
||||
verification_token = generate_random_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
|
||||
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
|
||||
await db.commit()
|
||||
|
||||
await email_service.send_verification_email(payload.email, verification_token)
|
||||
|
||||
|
||||
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
|
||||
record = await auth_repository.get_email_verification_token(db, payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
|
||||
if expires_at < now:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")
|
||||
|
||||
user = await get_user_by_id(db, record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
user.email_verified = True
|
||||
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def resend_verification_email(
|
||||
db: AsyncSession,
|
||||
payload: ResendVerificationRequest,
|
||||
email_service: EmailService,
|
||||
) -> None:
|
||||
user = await get_user_by_email(db, payload.email)
|
||||
if not user or user.email_verified:
|
||||
return
|
||||
|
||||
verification_token = generate_random_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
|
||||
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
|
||||
await db.commit()
|
||||
|
||||
await email_service.send_verification_email(user.email, verification_token)
|
||||
|
||||
|
||||
async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
|
||||
user = await get_user_by_email(db, payload.email)
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
if not user.email_verified:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(str(user.id)),
|
||||
refresh_token=create_refresh_token(str(user.id)),
|
||||
)
|
||||
|
||||
|
||||
async def request_password_reset(
|
||||
db: AsyncSession,
|
||||
payload: RequestPasswordResetRequest,
|
||||
email_service: EmailService,
|
||||
) -> None:
|
||||
user = await get_user_by_email(db, payload.email)
|
||||
if not user:
|
||||
return
|
||||
|
||||
reset_token = generate_random_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.password_reset_token_expire_hours)
|
||||
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
|
||||
await auth_repository.create_password_reset_token(db, user.id, reset_token, expires_at)
|
||||
await db.commit()
|
||||
|
||||
await email_service.send_password_reset_email(user.email, reset_token)
|
||||
|
||||
|
||||
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
|
||||
record = await auth_repository.get_password_reset_token(db, payload.token)
|
||||
if not record:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
|
||||
if expires_at < now:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")
|
||||
|
||||
user = await get_user_by_id(db, record.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
user.password_hash = hash_password(payload.new_password)
|
||||
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
|
||||
credentials_error = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ValueError as exc:
|
||||
raise credentials_error from exc
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise credentials_error
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id or not str(user_id).isdigit():
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_by_id(db, int(user_id))
|
||||
if not user:
|
||||
raise credentials_error
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id or not str(user_id).isdigit():
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
|
||||
|
||||
user = await get_user_by_id(db, int(user_id))
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
def get_email_sender() -> EmailService:
|
||||
return get_email_service()
|
||||
0
app/chats/__init__.py
Normal file
0
app/chats/__init__.py
Normal file
50
app/chats/models.py
Normal file
50
app/chats/models.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.messages.models import Message
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
class ChatType(str, Enum):
|
||||
PRIVATE = "private"
|
||||
GROUP = "group"
|
||||
CHANNEL = "channel"
|
||||
|
||||
|
||||
class ChatMemberRole(str, Enum):
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chats"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
type: Mapped[ChatType] = mapped_column(SAEnum(ChatType), nullable=False, index=True)
|
||||
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
members: Mapped[list["ChatMember"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class ChatMember(Base):
|
||||
__tablename__ = "chat_members"
|
||||
__table_args__ = (UniqueConstraint("chat_id", "user_id", name="uq_chat_members_chat_id_user_id"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
role: Mapped[ChatMemberRole] = mapped_column(SAEnum(ChatMemberRole), nullable=False, default=ChatMemberRole.MEMBER)
|
||||
joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
chat: Mapped["Chat"] = relationship(back_populates="members")
|
||||
user: Mapped["User"] = relationship(back_populates="memberships")
|
||||
62
app/chats/repository.py
Normal file
62
app/chats/repository.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType
|
||||
|
||||
|
||||
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
|
||||
chat = Chat(type=chat_type, title=title)
|
||||
db.add(chat)
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
|
||||
member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
|
||||
db.add(member)
|
||||
await db.flush()
|
||||
return member
|
||||
|
||||
|
||||
def _user_chats_query(user_id: int) -> Select[tuple[Chat]]:
|
||||
return (
|
||||
select(Chat)
|
||||
.join(ChatMember, ChatMember.chat_id == Chat.id)
|
||||
.where(ChatMember.user_id == user_id)
|
||||
.order_by(Chat.id.desc())
|
||||
)
|
||||
|
||||
|
||||
async def list_user_chats(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||
query = _user_chats_query(user_id).limit(limit)
|
||||
if before_id is not None:
|
||||
query = query.where(Chat.id < before_id)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
|
||||
result = await db.execute(select(Chat).where(Chat.id == chat_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
|
||||
result = await db.execute(
|
||||
select(ChatMember).where(
|
||||
ChatMember.chat_id == chat_id,
|
||||
ChatMember.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
|
||||
result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
|
||||
result = await db.execute(
|
||||
select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
45
app/chats/router.py
Normal file
45
app/chats/router.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.chats.schemas import ChatCreateRequest, ChatDetailRead, ChatRead
|
||||
from app.chats.service import create_chat_for_user, get_chat_for_user, get_chats_for_user
|
||||
from app.database.session import get_db
|
||||
from app.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ChatRead])
|
||||
async def list_chats(
|
||||
limit: int = 50,
|
||||
before_id: int | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[ChatRead]:
|
||||
return await get_chats_for_user(db, user_id=current_user.id, limit=limit, before_id=before_id)
|
||||
|
||||
|
||||
@router.post("", response_model=ChatRead)
|
||||
async def create_chat(
|
||||
payload: ChatCreateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ChatRead:
|
||||
return await create_chat_for_user(db, creator_id=current_user.id, payload=payload)
|
||||
|
||||
|
||||
@router.get("/{chat_id}", response_model=ChatDetailRead)
|
||||
async def get_chat(
|
||||
chat_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ChatDetailRead:
|
||||
chat, members = await get_chat_for_user(db, chat_id=chat_id, user_id=current_user.id)
|
||||
return ChatDetailRead(
|
||||
id=chat.id,
|
||||
type=chat.type,
|
||||
title=chat.title,
|
||||
created_at=chat.created_at,
|
||||
members=members,
|
||||
)
|
||||
33
app/chats/schemas.py
Normal file
33
app/chats/schemas.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.chats.models import ChatMemberRole, ChatType
|
||||
|
||||
|
||||
class ChatRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
type: ChatType
|
||||
title: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ChatMemberRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
user_id: int
|
||||
role: ChatMemberRole
|
||||
joined_at: datetime
|
||||
|
||||
|
||||
class ChatDetailRead(ChatRead):
|
||||
members: list[ChatMemberRead]
|
||||
|
||||
|
||||
class ChatCreateRequest(BaseModel):
|
||||
type: ChatType
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
member_ids: list[int] = Field(default_factory=list)
|
||||
62
app/chats/service.py
Normal file
62
app/chats/service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats import repository
|
||||
from app.chats.models import Chat, ChatMemberRole, ChatType
|
||||
from app.chats.schemas import ChatCreateRequest
|
||||
from app.users.repository import get_user_by_id
|
||||
|
||||
|
||||
async def create_chat_for_user(db: AsyncSession, *, creator_id: int, payload: ChatCreateRequest) -> Chat:
|
||||
member_ids = list(dict.fromkeys(payload.member_ids))
|
||||
member_ids = [member_id for member_id in member_ids if member_id != creator_id]
|
||||
|
||||
if payload.type == ChatType.PRIVATE and len(member_ids) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Private chat requires exactly one target user.",
|
||||
)
|
||||
if payload.type in {ChatType.GROUP, ChatType.CHANNEL} and not payload.title:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Group and channel chats require title.",
|
||||
)
|
||||
|
||||
for member_id in member_ids:
|
||||
user = await get_user_by_id(db, member_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {member_id} not found")
|
||||
|
||||
chat = await repository.create_chat(db, chat_type=payload.type, title=payload.title)
|
||||
await repository.add_chat_member(db, chat_id=chat.id, user_id=creator_id, role=ChatMemberRole.OWNER)
|
||||
|
||||
default_role = ChatMemberRole.MEMBER
|
||||
for member_id in member_ids:
|
||||
await repository.add_chat_member(db, chat_id=chat.id, user_id=member_id, role=default_role)
|
||||
|
||||
await db.commit()
|
||||
return chat
|
||||
|
||||
|
||||
async def get_chats_for_user(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||
safe_limit = max(1, min(limit, 100))
|
||||
return await repository.list_user_chats(db, user_id=user_id, limit=safe_limit, before_id=before_id)
|
||||
|
||||
|
||||
async def get_chat_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> tuple[Chat, list]:
|
||||
chat = await repository.get_chat_by_id(db, chat_id)
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
|
||||
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||
if not membership:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||
|
||||
members = await repository.list_chat_members(db, chat_id=chat_id)
|
||||
return chat, members
|
||||
|
||||
|
||||
async def ensure_chat_membership(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
|
||||
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||
if not membership:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||
0
app/config/__init__.py
Normal file
0
app/config/__init__.py
Normal file
41
app/config/settings.py
Normal file
41
app/config/settings.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str = "BenyaMessenger"
|
||||
environment: str = "development"
|
||||
debug: bool = True
|
||||
api_v1_prefix: str = "/api/v1"
|
||||
auto_create_tables: bool = True
|
||||
|
||||
secret_key: str = Field(default="change-me-please-12345", min_length=16)
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_days: int = 30
|
||||
jwt_algorithm: str = "HS256"
|
||||
email_verification_token_expire_hours: int = 24
|
||||
password_reset_token_expire_hours: int = 1
|
||||
|
||||
postgres_dsn: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/messenger"
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
s3_endpoint_url: str = "http://localhost:9000"
|
||||
s3_access_key: str = "minioadmin"
|
||||
s3_secret_key: str = "minioadmin"
|
||||
s3_region: str = "us-east-1"
|
||||
s3_bucket_name: str = "messenger-media"
|
||||
s3_presign_expire_seconds: int = 900
|
||||
max_upload_size_bytes: int = 104857600
|
||||
frontend_base_url: str = "http://localhost:5173"
|
||||
|
||||
smtp_host: str = "localhost"
|
||||
smtp_port: int = 1025
|
||||
smtp_username: str = ""
|
||||
smtp_password: str = ""
|
||||
smtp_use_tls: bool = False
|
||||
smtp_from_email: str = "no-reply@benyamessenger.local"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
15
app/database/base.py
Normal file
15
app/database/base.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||
19
app/database/models.py
Normal file
19
app/database/models.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||
from app.chats.models import Chat, ChatMember
|
||||
from app.email.models import EmailLog
|
||||
from app.media.models import Attachment
|
||||
from app.messages.models import Message
|
||||
from app.notifications.models import NotificationLog
|
||||
from app.users.models import User
|
||||
|
||||
__all__ = [
|
||||
"Attachment",
|
||||
"Chat",
|
||||
"ChatMember",
|
||||
"EmailLog",
|
||||
"EmailVerificationToken",
|
||||
"Message",
|
||||
"NotificationLog",
|
||||
"PasswordResetToken",
|
||||
"User",
|
||||
]
|
||||
25
app/database/session.py
Normal file
25
app/database/session.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.postgres_dsn,
|
||||
echo=settings.debug,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncIterator[AsyncSession]:
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
1
app/email/__init__.py
Normal file
1
app/email/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
16
app/email/models.py
Normal file
16
app/email/models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
|
||||
class EmailLog(Base):
|
||||
__tablename__ = "email_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
recipient: Mapped[str] = mapped_column(String(255), index=True, nullable=False)
|
||||
subject: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
body: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
7
app/email/repository.py
Normal file
7
app/email/repository.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.email.models import EmailLog
|
||||
|
||||
|
||||
async def create_email_log(db: AsyncSession, *, recipient: str, subject: str, body: str) -> None:
|
||||
db.add(EmailLog(recipient=recipient, subject=subject, body=body))
|
||||
3
app/email/router.py
Normal file
3
app/email/router.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(prefix="/email", tags=["email"])
|
||||
7
app/email/schemas.py
Normal file
7
app/email/schemas.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class EmailPayload(BaseModel):
|
||||
recipient: EmailStr
|
||||
subject: str
|
||||
body: str
|
||||
23
app/email/service.py
Normal file
23
app/email/service.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import logging
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailService:
|
||||
async def send_verification_email(self, email: str, token: str) -> None:
|
||||
verify_link = f"{settings.frontend_base_url}/verify-email?token={token}"
|
||||
subject = "Verify your BenyaMessenger account"
|
||||
body = f"Open this link to verify your account: {verify_link}"
|
||||
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
|
||||
|
||||
async def send_password_reset_email(self, email: str, token: str) -> None:
|
||||
reset_link = f"{settings.frontend_base_url}/reset-password?token={token}"
|
||||
subject = "Reset your BenyaMessenger password"
|
||||
body = f"Open this link to reset your password: {reset_link}"
|
||||
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
|
||||
|
||||
|
||||
def get_email_service() -> EmailService:
|
||||
return EmailService()
|
||||
43
app/main.py
Normal file
43
app/main.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.auth.router import router as auth_router
|
||||
from app.chats.router import router as chats_router
|
||||
from app.config.settings import settings
|
||||
from app.database import models # noqa: F401
|
||||
from app.database.base import Base
|
||||
from app.database.session import engine
|
||||
from app.media.router import router as media_router
|
||||
from app.messages.router import router as messages_router
|
||||
from app.notifications.router import router as notifications_router
|
||||
from app.realtime.router import router as realtime_router
|
||||
from app.realtime.service import realtime_gateway
|
||||
from app.users.router import router as users_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI):
|
||||
await realtime_gateway.start()
|
||||
if settings.auto_create_tables:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
await realtime_gateway.stop()
|
||||
|
||||
|
||||
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
app.include_router(auth_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(users_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(chats_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(messages_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(media_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(notifications_router, prefix=settings.api_v1_prefix)
|
||||
app.include_router(realtime_router, prefix=settings.api_v1_prefix)
|
||||
0
app/media/__init__.py
Normal file
0
app/media/__init__.py
Normal file
16
app/media/models.py
Normal file
16
app/media/models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
|
||||
class Attachment(Base):
|
||||
__tablename__ = "attachments"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
file_url: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||
file_type: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
message = relationship("Message", back_populates="attachments")
|
||||
26
app/media/repository.py
Normal file
26
app/media/repository.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.media.models import Attachment
|
||||
|
||||
|
||||
async def create_attachment(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
message_id: int,
|
||||
file_url: str,
|
||||
file_type: str,
|
||||
file_size: int,
|
||||
) -> Attachment:
|
||||
attachment = Attachment(
|
||||
message_id=message_id,
|
||||
file_url=file_url,
|
||||
file_type=file_type,
|
||||
file_size=file_size,
|
||||
)
|
||||
db.add(attachment)
|
||||
await db.flush()
|
||||
return attachment
|
||||
|
||||
|
||||
async def get_attachment_by_id(db: AsyncSession, attachment_id: int) -> Attachment | None:
|
||||
return await db.get(Attachment, attachment_id)
|
||||
27
app/media/router.py
Normal file
27
app/media/router.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.database.session import get_db
|
||||
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
|
||||
from app.media.service import generate_upload_url, store_attachment_metadata
|
||||
from app.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/media", tags=["media"])
|
||||
|
||||
|
||||
@router.post("/upload-url", response_model=UploadUrlResponse)
|
||||
async def create_upload_url(
|
||||
payload: UploadUrlRequest,
|
||||
_current_user: User = Depends(get_current_user),
|
||||
) -> UploadUrlResponse:
|
||||
return await generate_upload_url(payload)
|
||||
|
||||
|
||||
@router.post("/attachments", response_model=AttachmentRead)
|
||||
async def create_attachment_metadata(
|
||||
payload: AttachmentCreateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> AttachmentRead:
|
||||
return await store_attachment_metadata(db, user_id=current_user.id, payload=payload)
|
||||
32
app/media/schemas.py
Normal file
32
app/media/schemas.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UploadUrlRequest(BaseModel):
|
||||
file_name: str = Field(min_length=1, max_length=255)
|
||||
file_type: str = Field(min_length=1, max_length=64)
|
||||
file_size: int = Field(gt=0)
|
||||
|
||||
|
||||
class UploadUrlResponse(BaseModel):
|
||||
upload_url: str
|
||||
file_url: str
|
||||
object_key: str
|
||||
expires_in: int
|
||||
required_headers: dict[str, str]
|
||||
|
||||
|
||||
class AttachmentCreateRequest(BaseModel):
|
||||
message_id: int
|
||||
file_url: str = Field(min_length=1, max_length=1024)
|
||||
file_type: str = Field(min_length=1, max_length=64)
|
||||
file_size: int = Field(gt=0)
|
||||
|
||||
|
||||
class AttachmentRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
message_id: int
|
||||
file_url: str
|
||||
file_type: str
|
||||
file_size: int
|
||||
127
app/media/service.py
Normal file
127
app/media/service.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import re
|
||||
from urllib.parse import quote
|
||||
from uuid import uuid4
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.media import repository
|
||||
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
|
||||
from app.messages.repository import get_message_by_id
|
||||
|
||||
ALLOWED_MIME_TYPES = {
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/webp",
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"audio/mpeg",
|
||||
"audio/ogg",
|
||||
"audio/wav",
|
||||
"application/pdf",
|
||||
"application/zip",
|
||||
"text/plain",
|
||||
}
|
||||
|
||||
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
|
||||
def _sanitize_filename(file_name: str) -> str:
|
||||
sanitized = _SAFE_NAME_RE.sub("_", file_name).strip("._")
|
||||
if not sanitized:
|
||||
sanitized = "file.bin"
|
||||
return sanitized[:120]
|
||||
|
||||
|
||||
def _build_file_url(bucket: str, object_key: str) -> str:
|
||||
base = settings.s3_endpoint_url.rstrip("/")
|
||||
encoded_key = quote(object_key)
|
||||
return f"{base}/{bucket}/{encoded_key}"
|
||||
|
||||
|
||||
def _allowed_file_url_prefix() -> str:
|
||||
return f"{settings.s3_endpoint_url.rstrip('/')}/{settings.s3_bucket_name}/"
|
||||
|
||||
|
||||
def _validate_media(file_type: str, file_size: int) -> None:
|
||||
if file_type not in ALLOWED_MIME_TYPES:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file type")
|
||||
if file_size > settings.max_upload_size_bytes:
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="File size exceeds limit")
|
||||
|
||||
|
||||
def _get_s3_client():
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=settings.s3_endpoint_url,
|
||||
aws_access_key_id=settings.s3_access_key,
|
||||
aws_secret_access_key=settings.s3_secret_key,
|
||||
region_name=settings.s3_region,
|
||||
config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
|
||||
)
|
||||
|
||||
|
||||
async def generate_upload_url(payload: UploadUrlRequest) -> UploadUrlResponse:
|
||||
_validate_media(payload.file_type, payload.file_size)
|
||||
|
||||
file_name = _sanitize_filename(payload.file_name)
|
||||
object_key = f"uploads/{uuid4()}-{file_name}"
|
||||
bucket = settings.s3_bucket_name
|
||||
|
||||
try:
|
||||
s3_client = _get_s3_client()
|
||||
upload_url = s3_client.generate_presigned_url(
|
||||
"put_object",
|
||||
Params={
|
||||
"Bucket": bucket,
|
||||
"Key": object_key,
|
||||
"ContentType": payload.file_type,
|
||||
},
|
||||
ExpiresIn=settings.s3_presign_expire_seconds,
|
||||
HttpMethod="PUT",
|
||||
)
|
||||
except (BotoCoreError, ClientError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Storage service unavailable") from exc
|
||||
|
||||
return UploadUrlResponse(
|
||||
upload_url=upload_url,
|
||||
file_url=_build_file_url(bucket, object_key),
|
||||
object_key=object_key,
|
||||
expires_in=settings.s3_presign_expire_seconds,
|
||||
required_headers={"Content-Type": payload.file_type},
|
||||
)
|
||||
|
||||
|
||||
async def store_attachment_metadata(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: int,
|
||||
payload: AttachmentCreateRequest,
|
||||
) -> AttachmentRead:
|
||||
_validate_media(payload.file_type, payload.file_size)
|
||||
if not payload.file_url.startswith(_allowed_file_url_prefix()):
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file URL")
|
||||
|
||||
message = await get_message_by_id(db, payload.message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
if message.sender_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the message sender can attach files",
|
||||
)
|
||||
|
||||
attachment = await repository.create_attachment(
|
||||
db,
|
||||
message_id=payload.message_id,
|
||||
file_url=payload.file_url,
|
||||
file_type=payload.file_type,
|
||||
file_size=payload.file_size,
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(attachment)
|
||||
return AttachmentRead.model_validate(attachment)
|
||||
0
app/messages/__init__.py
Normal file
0
app/messages/__init__.py
Normal file
44
app/messages/models.py
Normal file
44
app/messages/models.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.chats.models import Chat
|
||||
from app.media.models import Attachment
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
VOICE = "voice"
|
||||
FILE = "file"
|
||||
CIRCLE_VIDEO = "circle_video"
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
sender_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
type: Mapped[MessageType] = mapped_column(SAEnum(MessageType), nullable=False, default=MessageType.TEXT)
|
||||
text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
chat: Mapped["Chat"] = relationship(back_populates="messages")
|
||||
sender: Mapped["User"] = relationship(back_populates="sent_messages")
|
||||
attachments: Mapped[list["Attachment"]] = relationship(back_populates="message", cascade="all, delete-orphan")
|
||||
41
app/messages/repository.py
Normal file
41
app/messages/repository.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.messages.models import Message, MessageType
|
||||
|
||||
|
||||
async def create_message(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
sender_id: int,
|
||||
message_type: MessageType,
|
||||
text: str | None,
|
||||
) -> Message:
|
||||
message = Message(chat_id=chat_id, sender_id=sender_id, type=message_type, text=text)
|
||||
db.add(message)
|
||||
await db.flush()
|
||||
return message
|
||||
|
||||
|
||||
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
|
||||
result = await db.execute(select(Message).where(Message.id == message_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def list_chat_messages(
|
||||
db: AsyncSession,
|
||||
chat_id: int,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_id: int | None = None,
|
||||
) -> list[Message]:
|
||||
query = select(Message).where(Message.chat_id == chat_id)
|
||||
if before_id is not None:
|
||||
query = query.where(Message.id < before_id)
|
||||
result = await db.execute(query.order_by(Message.id.desc()).limit(limit))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_message(db: AsyncSession, message: Message) -> None:
|
||||
await db.delete(message)
|
||||
49
app/messages/router.py
Normal file
49
app/messages/router.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.database.session import get_db
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageUpdateRequest
|
||||
from app.messages.service import create_chat_message, delete_message, get_messages, update_message
|
||||
from app.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
|
||||
|
||||
@router.post("", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_message(
|
||||
payload: MessageCreateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> MessageRead:
|
||||
return await create_chat_message(db, sender_id=current_user.id, payload=payload)
|
||||
|
||||
|
||||
@router.get("/{chat_id}", response_model=list[MessageRead])
|
||||
async def list_messages(
|
||||
chat_id: int,
|
||||
limit: int = 50,
|
||||
before_id: int | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[MessageRead]:
|
||||
return await get_messages(db, chat_id=chat_id, user_id=current_user.id, limit=limit, before_id=before_id)
|
||||
|
||||
|
||||
@router.put("/{message_id}", response_model=MessageRead)
|
||||
async def edit_message(
|
||||
message_id: int,
|
||||
payload: MessageUpdateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> MessageRead:
|
||||
return await update_message(db, message_id=message_id, user_id=current_user.id, payload=payload)
|
||||
|
||||
|
||||
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_message(
|
||||
message_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
await delete_message(db, message_id=message_id, user_id=current_user.id)
|
||||
27
app/messages/schemas.py
Normal file
27
app/messages/schemas.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.messages.models import MessageType
|
||||
|
||||
|
||||
class MessageRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
chat_id: int
|
||||
sender_id: int
|
||||
type: MessageType
|
||||
text: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MessageCreateRequest(BaseModel):
|
||||
chat_id: int
|
||||
type: MessageType = MessageType.TEXT
|
||||
text: str | None = Field(default=None, max_length=4096)
|
||||
|
||||
|
||||
class MessageUpdateRequest(BaseModel):
|
||||
text: str = Field(min_length=1, max_length=4096)
|
||||
67
app/messages/service.py
Normal file
67
app/messages/service.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.service import ensure_chat_membership
|
||||
from app.messages import repository
|
||||
from app.messages.models import Message
|
||||
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
|
||||
|
||||
|
||||
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
|
||||
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
|
||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
|
||||
|
||||
message = await repository.create_message(
|
||||
db,
|
||||
chat_id=payload.chat_id,
|
||||
sender_id=sender_id,
|
||||
message_type=payload.type,
|
||||
text=payload.text,
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(message)
|
||||
return message
|
||||
|
||||
|
||||
async def get_messages(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
chat_id: int,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
before_id: int | None = None,
|
||||
) -> list[Message]:
|
||||
await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)
|
||||
safe_limit = max(1, min(limit, 100))
|
||||
return await repository.list_chat_messages(db, chat_id, limit=safe_limit, before_id=before_id)
|
||||
|
||||
|
||||
async def update_message(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
payload: MessageUpdateRequest,
|
||||
) -> Message:
|
||||
message = await repository.get_message_by_id(db, message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||
if message.sender_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can edit only your own messages")
|
||||
message.text = payload.text
|
||||
await db.commit()
|
||||
await db.refresh(message)
|
||||
return message
|
||||
|
||||
|
||||
async def delete_message(db: AsyncSession, *, message_id: int, user_id: int) -> None:
|
||||
message = await repository.get_message_by_id(db, message_id)
|
||||
if not message:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||
if message.sender_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can delete only your own messages")
|
||||
await repository.delete_message(db, message)
|
||||
await db.commit()
|
||||
0
app/notifications/__init__.py
Normal file
0
app/notifications/__init__.py
Normal file
16
app/notifications/models.py
Normal file
16
app/notifications/models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
|
||||
class NotificationLog(Base):
|
||||
__tablename__ = "notification_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(64), index=True)
|
||||
payload: Mapped[str] = mapped_column(String(1024))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
7
app/notifications/repository.py
Normal file
7
app/notifications/repository.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.notifications.models import NotificationLog
|
||||
|
||||
|
||||
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
|
||||
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))
|
||||
3
app/notifications/router.py
Normal file
3
app/notifications/router.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
7
app/notifications/schemas.py
Normal file
7
app/notifications/schemas.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class NotificationRequest(BaseModel):
|
||||
user_id: int
|
||||
event_type: str
|
||||
payload: dict
|
||||
13
app/notifications/service.py
Normal file
13
app/notifications/service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.notifications.repository import create_notification_log
|
||||
from app.notifications.schemas import NotificationRequest
|
||||
|
||||
|
||||
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
|
||||
await create_notification_log(
|
||||
db,
|
||||
user_id=payload.user_id,
|
||||
event_type=payload.event_type,
|
||||
payload=payload.payload.__repr__(),
|
||||
)
|
||||
0
app/realtime/__init__.py
Normal file
0
app/realtime/__init__.py
Normal file
10
app/realtime/models.py
Normal file
10
app/realtime/models.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConnectionContext:
|
||||
user_id: int
|
||||
connection_id: str
|
||||
websocket: WebSocket
|
||||
48
app/realtime/repository.py
Normal file
48
app/realtime/repository.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class RedisRealtimeRepository:
|
||||
def __init__(self) -> None:
|
||||
self._redis: Redis | None = None
|
||||
self._pubsub = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._redis:
|
||||
return
|
||||
self._redis = Redis.from_url(settings.redis_url, decode_responses=True)
|
||||
self._pubsub = self._redis.pubsub()
|
||||
await self._pubsub.psubscribe("chat:*")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._pubsub:
|
||||
await self._pubsub.close()
|
||||
self._pubsub = None
|
||||
if self._redis:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
async def publish_event(self, channel: str, payload: dict) -> None:
|
||||
if not self._redis:
|
||||
await self.connect()
|
||||
assert self._redis is not None
|
||||
await self._redis.publish(channel, json.dumps(payload))
|
||||
|
||||
async def consume(self, handler: Callable[[str, dict], Awaitable[None]]) -> None:
|
||||
if not self._pubsub:
|
||||
await self.connect()
|
||||
assert self._pubsub is not None
|
||||
while True:
|
||||
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
if not message:
|
||||
continue
|
||||
channel = message.get("channel")
|
||||
data = message.get("data")
|
||||
if not channel or not isinstance(data, str):
|
||||
continue
|
||||
payload = json.loads(data)
|
||||
await handler(channel, payload)
|
||||
76
app/realtime/router.py
Normal file
76
app/realtime/router.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.auth.service import get_current_user_for_ws
|
||||
from app.database.session import AsyncSessionLocal
|
||||
from app.realtime.schemas import (
|
||||
ChatEventPayload,
|
||||
IncomingRealtimeEvent,
|
||||
MessageStatusPayload,
|
||||
OutgoingRealtimeEvent,
|
||||
SendMessagePayload,
|
||||
)
|
||||
from app.realtime.service import realtime_gateway
|
||||
|
||||
router = APIRouter(prefix="/realtime", tags=["realtime"])
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_gateway(websocket: WebSocket) -> None:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
user = await get_current_user_for_ws(token, db)
|
||||
except Exception:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
|
||||
await websocket.accept()
|
||||
connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await websocket.receive_json()
|
||||
try:
|
||||
event = IncomingRealtimeEvent.model_validate(raw_data)
|
||||
await _dispatch_event(db, user.id, event)
|
||||
except ValidationError:
|
||||
await websocket.send_json(
|
||||
OutgoingRealtimeEvent(
|
||||
event="error",
|
||||
payload={"detail": "Invalid event payload"},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
except Exception as exc:
|
||||
await websocket.send_json(
|
||||
OutgoingRealtimeEvent(
|
||||
event="error",
|
||||
payload={"detail": str(exc)},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
).model_dump(mode="json")
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)
|
||||
|
||||
|
||||
async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
|
||||
if event.event == "send_message":
|
||||
payload = SendMessagePayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_send_message(db, user_id, payload)
|
||||
return
|
||||
if event.event in {"typing_start", "typing_stop"}:
|
||||
payload = ChatEventPayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_typing_event(db, user_id, payload, event.event)
|
||||
return
|
||||
if event.event in {"message_read", "message_delivered"}:
|
||||
payload = MessageStatusPayload.model_validate(event.payload)
|
||||
await realtime_gateway.handle_message_status(db, user_id, payload, event.event)
|
||||
return
|
||||
48
app/realtime/schemas.py
Normal file
48
app/realtime/schemas.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.messages.models import MessageType
|
||||
|
||||
|
||||
RealtimeEventName = Literal[
|
||||
"connect",
|
||||
"disconnect",
|
||||
"send_message",
|
||||
"receive_message",
|
||||
"typing_start",
|
||||
"typing_stop",
|
||||
"message_read",
|
||||
"message_delivered",
|
||||
"error",
|
||||
]
|
||||
|
||||
|
||||
class SendMessagePayload(BaseModel):
|
||||
chat_id: int
|
||||
type: MessageType = MessageType.TEXT
|
||||
text: str | None = Field(default=None, max_length=4096)
|
||||
temp_id: str | None = None
|
||||
|
||||
|
||||
class ChatEventPayload(BaseModel):
|
||||
chat_id: int
|
||||
|
||||
|
||||
class MessageStatusPayload(BaseModel):
|
||||
chat_id: int
|
||||
message_id: int
|
||||
|
||||
|
||||
class IncomingRealtimeEvent(BaseModel):
|
||||
event: Literal["send_message", "typing_start", "typing_stop", "message_read", "message_delivered"]
|
||||
payload: dict[str, Any]
|
||||
|
||||
|
||||
class OutgoingRealtimeEvent(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
event: RealtimeEventName
|
||||
payload: dict[str, Any]
|
||||
timestamp: datetime
|
||||
178
app/realtime/service.py
Normal file
178
app/realtime/service.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import WebSocket
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.chats.repository import list_user_chat_ids
|
||||
from app.chats.service import ensure_chat_membership
|
||||
from app.messages.schemas import MessageCreateRequest, MessageRead
|
||||
from app.messages.service import create_chat_message
|
||||
from app.realtime.models import ConnectionContext
|
||||
from app.realtime.repository import RedisRealtimeRepository
|
||||
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
|
||||
|
||||
|
||||
class RealtimeGateway:
|
||||
def __init__(self) -> None:
|
||||
self._repo = RedisRealtimeRepository()
|
||||
self._consume_task: asyncio.Task | None = None
|
||||
self._distributed_enabled = False
|
||||
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
|
||||
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
|
||||
|
||||
async def start(self) -> None:
|
||||
try:
|
||||
await self._repo.connect()
|
||||
if not self._consume_task:
|
||||
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
|
||||
self._distributed_enabled = True
|
||||
except Exception:
|
||||
self._distributed_enabled = False
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._consume_task:
|
||||
self._consume_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._consume_task
|
||||
self._consume_task = None
|
||||
await self._repo.close()
|
||||
self._distributed_enabled = False
|
||||
|
||||
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
|
||||
connection_id = str(uuid4())
|
||||
self._connections[user_id][connection_id] = ConnectionContext(
|
||||
user_id=user_id,
|
||||
connection_id=connection_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
for chat_id in user_chat_ids:
|
||||
self._chat_subscribers[chat_id].add(user_id)
|
||||
await self._send_user_event(
|
||||
user_id,
|
||||
OutgoingRealtimeEvent(
|
||||
event="connect",
|
||||
payload={"connection_id": connection_id},
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
),
|
||||
)
|
||||
return connection_id
|
||||
|
||||
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
|
||||
user_connections = self._connections.get(user_id, {})
|
||||
user_connections.pop(connection_id, None)
|
||||
if not user_connections:
|
||||
self._connections.pop(user_id, None)
|
||||
for chat_id in user_chat_ids:
|
||||
subscribers = self._chat_subscribers.get(chat_id)
|
||||
if not subscribers:
|
||||
continue
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
|
||||
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
|
||||
message = await create_chat_message(
|
||||
db,
|
||||
sender_id=user_id,
|
||||
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
|
||||
)
|
||||
message_data = MessageRead.model_validate(message).model_dump(mode="json")
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event="receive_message",
|
||||
payload={
|
||||
"chat_id": payload.chat_id,
|
||||
"message": message_data,
|
||||
"temp_id": payload.temp_id,
|
||||
"sender_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event=event,
|
||||
payload={"chat_id": payload.chat_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
async def handle_message_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
payload: MessageStatusPayload,
|
||||
event: str,
|
||||
) -> None:
|
||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||
await self._publish_chat_event(
|
||||
payload.chat_id,
|
||||
event=event,
|
||||
payload={
|
||||
"chat_id": payload.chat_id,
|
||||
"message_id": payload.message_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
|
||||
return await list_user_chat_ids(db, user_id=user_id)
|
||||
|
||||
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
||||
chat_id = self._extract_chat_id(channel)
|
||||
if chat_id is None:
|
||||
return
|
||||
subscribers = self._chat_subscribers.get(chat_id, set())
|
||||
if not subscribers:
|
||||
return
|
||||
event = OutgoingRealtimeEvent(
|
||||
event=payload.get("event", "error"),
|
||||
payload=payload.get("payload", {}),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
|
||||
|
||||
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
|
||||
event_payload = {
|
||||
"event": event,
|
||||
"payload": payload,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
if self._distributed_enabled:
|
||||
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
|
||||
return
|
||||
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
|
||||
|
||||
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
|
||||
user_connections = self._connections.get(user_id, {})
|
||||
if not user_connections:
|
||||
return
|
||||
disconnected: list[str] = []
|
||||
for connection_id, context in user_connections.items():
|
||||
try:
|
||||
await context.websocket.send_json(event.model_dump(mode="json"))
|
||||
except Exception:
|
||||
disconnected.append(connection_id)
|
||||
for connection_id in disconnected:
|
||||
user_connections.pop(connection_id, None)
|
||||
if not user_connections:
|
||||
self._connections.pop(user_id, None)
|
||||
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||
subscribers.discard(user_id)
|
||||
if not subscribers:
|
||||
self._chat_subscribers.pop(chat_id, None)
|
||||
|
||||
@staticmethod
|
||||
def _extract_chat_id(channel: str) -> int | None:
|
||||
if not channel.startswith("chat:"):
|
||||
return None
|
||||
chat_id = channel.split(":", maxsplit=1)[1]
|
||||
if not chat_id.isdigit():
|
||||
return None
|
||||
return int(chat_id)
|
||||
|
||||
|
||||
realtime_gateway = RealtimeGateway()
|
||||
0
app/users/__init__.py
Normal file
0
app/users/__init__.py
Normal file
41
app/users/models.py
Normal file
41
app/users/models.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||
from app.chats.models import ChatMember
|
||||
from app.messages.models import Message
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255))
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
email_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
memberships: Mapped[list["ChatMember"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||
sent_messages: Mapped[list["Message"]] = relationship(back_populates="sender")
|
||||
email_verification_tokens: Mapped[list["EmailVerificationToken"]] = relationship(
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
password_reset_tokens: Mapped[list["PasswordResetToken"]] = relationship(
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
26
app/users/repository.py
Normal file
26
app/users/repository.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
async def create_user(db: AsyncSession, *, email: str, username: str, password_hash: str) -> User:
|
||||
user = User(email=email, username=username, password_hash=password_hash, email_verified=False)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||
result = await db.execute(select(User).where(User.username == username))
|
||||
return result.scalar_one_or_none()
|
||||
43
app/users/router.py
Normal file
43
app/users/router.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import get_current_user
|
||||
from app.database.session import get_db
|
||||
from app.users.models import User
|
||||
from app.users.schemas import UserProfileUpdate, UserRead
|
||||
from app.users.service import get_user_by_id, get_user_by_username, update_user_profile
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
async def read_me(current_user: User = Depends(get_current_user)) -> UserRead:
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserRead)
|
||||
async def read_user(user_id: int, db: AsyncSession = Depends(get_db), _current_user: User = Depends(get_current_user)) -> UserRead:
|
||||
user = await get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserRead)
|
||||
async def update_profile(
|
||||
payload: UserProfileUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> UserRead:
|
||||
if payload.username and payload.username != current_user.username:
|
||||
username_owner = await get_user_by_username(db, payload.username)
|
||||
if username_owner:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username already taken")
|
||||
|
||||
updated = await update_user_profile(
|
||||
db,
|
||||
current_user,
|
||||
username=payload.username,
|
||||
avatar_url=payload.avatar_url,
|
||||
)
|
||||
return updated
|
||||
27
app/users/schemas.py
Normal file
27
app/users/schemas.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
username: str = Field(min_length=3, max_length=50)
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class UserRead(UserBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
avatar_url: str | None = None
|
||||
email_verified: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class UserProfileUpdate(BaseModel):
|
||||
username: str | None = Field(default=None, min_length=3, max_length=50)
|
||||
avatar_url: str | None = Field(default=None, max_length=512)
|
||||
32
app/users/service.py
Normal file
32
app/users/service.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.users import repository
|
||||
from app.users.models import User
|
||||
|
||||
|
||||
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||
return await repository.get_user_by_id(db, user_id)
|
||||
|
||||
|
||||
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||
return await repository.get_user_by_email(db, email)
|
||||
|
||||
|
||||
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||
return await repository.get_user_by_username(db, username)
|
||||
|
||||
|
||||
async def update_user_profile(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
*,
|
||||
username: str | None = None,
|
||||
avatar_url: str | None = None,
|
||||
) -> User:
|
||||
if username is not None:
|
||||
user.username = username
|
||||
if avatar_url is not None:
|
||||
user.avatar_url = avatar_url
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
51
app/utils/security.py
Normal file
51
app/utils/security.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from secrets import token_urlsafe
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(password, hashed_password)
|
||||
|
||||
|
||||
def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
payload = {"sub": subject, "type": token_type, "exp": expire}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_access_token(subject: str) -> str:
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="access",
|
||||
expires_delta=timedelta(minutes=settings.access_token_expire_minutes),
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(subject: str) -> str:
|
||||
return _create_token(
|
||||
subject=subject,
|
||||
token_type="refresh",
|
||||
expires_delta=timedelta(days=settings.refresh_token_expire_days),
|
||||
)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
try:
|
||||
return jwt.decode(token, settings.secret_key, algorithms=[settings.jwt_algorithm])
|
||||
except JWTError as exc:
|
||||
raise ValueError("Invalid token") from exc
|
||||
|
||||
|
||||
def generate_random_token() -> str:
|
||||
return token_urlsafe(48)
|
||||
Reference in New Issue
Block a user