first commit

This commit is contained in:
2026-03-07 21:31:38 +03:00
commit a879ba7b50
68 changed files with 2487 additions and 0 deletions

0
app/__init__.py Normal file
View File

0
app/auth/__init__.py Normal file
View File

34
app/auth/models.py Normal file
View File

@@ -0,0 +1,34 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.base import Base
if TYPE_CHECKING:
from app.users.models import User
class EmailVerificationToken(Base):
__tablename__ = "email_verification_tokens"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
user: Mapped["User"] = relationship(back_populates="email_verification_tokens")
class PasswordResetToken(Base):
__tablename__ = "password_reset_tokens"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
user: Mapped["User"] = relationship(back_populates="password_reset_tokens")

46
app/auth/repository.py Normal file
View File

@@ -0,0 +1,46 @@
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.models import EmailVerificationToken, PasswordResetToken
async def create_email_verification_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
db.add(
EmailVerificationToken(
user_id=user_id,
token=token,
expires_at=expires_at,
created_at=datetime.now(timezone.utc),
)
)
async def get_email_verification_token(db: AsyncSession, token: str) -> EmailVerificationToken | None:
result = await db.execute(select(EmailVerificationToken).where(EmailVerificationToken.token == token))
return result.scalar_one_or_none()
async def delete_email_verification_tokens_for_user(db: AsyncSession, user_id: int) -> None:
await db.execute(delete(EmailVerificationToken).where(EmailVerificationToken.user_id == user_id))
async def create_password_reset_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
db.add(
PasswordResetToken(
user_id=user_id,
token=token,
expires_at=expires_at,
created_at=datetime.now(timezone.utc),
)
)
async def get_password_reset_token(db: AsyncSession, token: str) -> PasswordResetToken | None:
result = await db.execute(select(PasswordResetToken).where(PasswordResetToken.token == token))
return result.scalar_one_or_none()
async def delete_password_reset_tokens_for_user(db: AsyncSession, user_id: int) -> None:
await db.execute(delete(PasswordResetToken).where(PasswordResetToken.user_id == user_id))

81
app/auth/router.py Normal file
View File

@@ -0,0 +1,81 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.schemas import (
AuthUserResponse,
LoginRequest,
MessageResponse,
RegisterRequest,
RequestPasswordResetRequest,
ResendVerificationRequest,
ResetPasswordRequest,
TokenResponse,
VerifyEmailRequest,
)
from app.auth.service import (
get_current_user,
get_email_sender,
login_user,
register_user,
request_password_reset,
resend_verification_email,
reset_password,
verify_email,
)
from app.database.session import get_db
from app.email.service import EmailService
from app.users.models import User
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
async def register(
payload: RegisterRequest,
db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse:
await register_user(db, payload, email_service)
return MessageResponse(message="Registration successful. Verification email sent.")
@router.post("/login", response_model=TokenResponse)
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
return await login_user(db, payload)
@router.post("/verify-email", response_model=MessageResponse)
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
await verify_email(db, payload)
return MessageResponse(message="Email verified successfully.")
@router.post("/resend-verification", response_model=MessageResponse)
async def resend_verification(
payload: ResendVerificationRequest,
db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse:
await resend_verification_email(db, payload, email_service)
return MessageResponse(message="If the account exists, a verification email was sent.")
@router.post("/request-password-reset", response_model=MessageResponse)
async def request_password_reset_endpoint(
payload: RequestPasswordResetRequest,
db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse:
await request_password_reset(db, payload, email_service)
return MessageResponse(message="If the account exists, a reset email was sent.")
@router.post("/reset-password", response_model=MessageResponse)
async def reset_password_endpoint(payload: ResetPasswordRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
await reset_password(db, payload)
return MessageResponse(message="Password reset successfully.")
@router.get("/me", response_model=AuthUserResponse)
async def me(current_user: User = Depends(get_current_user)) -> AuthUserResponse:
return current_user

53
app/auth/schemas.py Normal file
View File

@@ -0,0 +1,53 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict, EmailStr, Field
class RegisterRequest(BaseModel):
email: EmailStr
username: str = Field(min_length=3, max_length=50)
password: str = Field(min_length=8, max_length=128)
class LoginRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=8, max_length=128)
class VerifyEmailRequest(BaseModel):
token: str = Field(min_length=16, max_length=512)
class ResendVerificationRequest(BaseModel):
email: EmailStr
class RequestPasswordResetRequest(BaseModel):
email: EmailStr
class ResetPasswordRequest(BaseModel):
token: str = Field(min_length=16, max_length=512)
new_password: str = Field(min_length=8, max_length=128)
class MessageResponse(BaseModel):
message: str
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
class AuthUserResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
email: EmailStr
username: str
avatar_url: str | None = None
email_verified: bool
created_at: datetime
updated_at: datetime

195
app/auth/service.py Normal file
View File

@@ -0,0 +1,195 @@
from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import repository as auth_repository
from app.auth.schemas import (
LoginRequest,
RegisterRequest,
RequestPasswordResetRequest,
ResendVerificationRequest,
ResetPasswordRequest,
TokenResponse,
VerifyEmailRequest,
)
from app.config.settings import settings
from app.database.session import get_db
from app.email.service import EmailService, get_email_service
from app.users.models import User
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
from app.utils.security import (
create_access_token,
create_refresh_token,
decode_token,
generate_random_token,
hash_password,
verify_password,
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
async def register_user(
db: AsyncSession,
payload: RegisterRequest,
email_service: EmailService,
) -> None:
existing_email = await get_user_by_email(db, payload.email)
if existing_email:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
existing_username = await get_user_by_username(db, payload.username)
if existing_username:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")
user = await create_user(
db,
email=payload.email,
username=payload.username,
password_hash=hash_password(payload.password),
)
verification_token = generate_random_token()
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
await db.commit()
await email_service.send_verification_email(payload.email, verification_token)
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
record = await auth_repository.get_email_verification_token(db, payload.token)
if not record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")
now = datetime.now(timezone.utc)
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
if expires_at < now:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")
user = await get_user_by_id(db, record.user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user.email_verified = True
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
await db.commit()
async def resend_verification_email(
db: AsyncSession,
payload: ResendVerificationRequest,
email_service: EmailService,
) -> None:
user = await get_user_by_email(db, payload.email)
if not user or user.email_verified:
return
verification_token = generate_random_token()
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
await db.commit()
await email_service.send_verification_email(user.email, verification_token)
async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
user = await get_user_by_email(db, payload.email)
if not user or not verify_password(payload.password, user.password_hash):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
if not user.email_verified:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
return TokenResponse(
access_token=create_access_token(str(user.id)),
refresh_token=create_refresh_token(str(user.id)),
)
async def request_password_reset(
db: AsyncSession,
payload: RequestPasswordResetRequest,
email_service: EmailService,
) -> None:
user = await get_user_by_email(db, payload.email)
if not user:
return
reset_token = generate_random_token()
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.password_reset_token_expire_hours)
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
await auth_repository.create_password_reset_token(db, user.id, reset_token, expires_at)
await db.commit()
await email_service.send_password_reset_email(user.email, reset_token)
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
record = await auth_repository.get_password_reset_token(db, payload.token)
if not record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")
now = datetime.now(timezone.utc)
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
if expires_at < now:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")
user = await get_user_by_id(db, record.user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user.password_hash = hash_password(payload.new_password)
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
await db.commit()
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
credentials_error = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = decode_token(token)
except ValueError as exc:
raise credentials_error from exc
if payload.get("type") != "access":
raise credentials_error
user_id = payload.get("sub")
if not user_id or not str(user_id).isdigit():
raise credentials_error
user = await get_user_by_id(db, int(user_id))
if not user:
raise credentials_error
return user
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
try:
payload = decode_token(token)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
if payload.get("type") != "access":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
user_id = payload.get("sub")
if not user_id or not str(user_id).isdigit():
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
user = await get_user_by_id(db, int(user_id))
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
return user
def get_email_sender() -> EmailService:
return get_email_service()

0
app/chats/__init__.py Normal file
View File

50
app/chats/models.py Normal file
View File

@@ -0,0 +1,50 @@
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.base import Base
if TYPE_CHECKING:
from app.messages.models import Message
from app.users.models import User
class ChatType(str, Enum):
PRIVATE = "private"
GROUP = "group"
CHANNEL = "channel"
class ChatMemberRole(str, Enum):
OWNER = "owner"
ADMIN = "admin"
MEMBER = "member"
class Chat(Base):
__tablename__ = "chats"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
type: Mapped[ChatType] = mapped_column(SAEnum(ChatType), nullable=False, index=True)
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
members: Mapped[list["ChatMember"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
class ChatMember(Base):
__tablename__ = "chat_members"
__table_args__ = (UniqueConstraint("chat_id", "user_id", name="uq_chat_members_chat_id_user_id"),)
id: Mapped[int] = mapped_column(primary_key=True, index=True)
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
role: Mapped[ChatMemberRole] = mapped_column(SAEnum(ChatMemberRole), nullable=False, default=ChatMemberRole.MEMBER)
joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
chat: Mapped["Chat"] = relationship(back_populates="members")
user: Mapped["User"] = relationship(back_populates="memberships")

62
app/chats/repository.py Normal file
View File

@@ -0,0 +1,62 @@
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
chat = Chat(type=chat_type, title=title)
db.add(chat)
await db.flush()
return chat
async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
db.add(member)
await db.flush()
return member
def _user_chats_query(user_id: int) -> Select[tuple[Chat]]:
return (
select(Chat)
.join(ChatMember, ChatMember.chat_id == Chat.id)
.where(ChatMember.user_id == user_id)
.order_by(Chat.id.desc())
)
async def list_user_chats(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
query = _user_chats_query(user_id).limit(limit)
if before_id is not None:
query = query.where(Chat.id < before_id)
result = await db.execute(query)
return list(result.scalars().all())
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
result = await db.execute(select(Chat).where(Chat.id == chat_id))
return result.scalar_one_or_none()
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
result = await db.execute(
select(ChatMember).where(
ChatMember.chat_id == chat_id,
ChatMember.user_id == user_id,
)
)
return result.scalar_one_or_none()
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
return list(result.scalars().all())
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
result = await db.execute(
select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
)
return list(result.scalars().all())

45
app/chats/router.py Normal file
View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.service import get_current_user
from app.chats.schemas import ChatCreateRequest, ChatDetailRead, ChatRead
from app.chats.service import create_chat_for_user, get_chat_for_user, get_chats_for_user
from app.database.session import get_db
from app.users.models import User
router = APIRouter(prefix="/chats", tags=["chats"])
@router.get("", response_model=list[ChatRead])
async def list_chats(
limit: int = 50,
before_id: int | None = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> list[ChatRead]:
return await get_chats_for_user(db, user_id=current_user.id, limit=limit, before_id=before_id)
@router.post("", response_model=ChatRead)
async def create_chat(
payload: ChatCreateRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> ChatRead:
return await create_chat_for_user(db, creator_id=current_user.id, payload=payload)
@router.get("/{chat_id}", response_model=ChatDetailRead)
async def get_chat(
chat_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> ChatDetailRead:
chat, members = await get_chat_for_user(db, chat_id=chat_id, user_id=current_user.id)
return ChatDetailRead(
id=chat.id,
type=chat.type,
title=chat.title,
created_at=chat.created_at,
members=members,
)

33
app/chats/schemas.py Normal file
View File

@@ -0,0 +1,33 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.chats.models import ChatMemberRole, ChatType
class ChatRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
type: ChatType
title: str | None = None
created_at: datetime
class ChatMemberRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
user_id: int
role: ChatMemberRole
joined_at: datetime
class ChatDetailRead(ChatRead):
members: list[ChatMemberRead]
class ChatCreateRequest(BaseModel):
type: ChatType
title: str | None = Field(default=None, max_length=255)
member_ids: list[int] = Field(default_factory=list)

62
app/chats/service.py Normal file
View File

@@ -0,0 +1,62 @@
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats import repository
from app.chats.models import Chat, ChatMemberRole, ChatType
from app.chats.schemas import ChatCreateRequest
from app.users.repository import get_user_by_id
async def create_chat_for_user(db: AsyncSession, *, creator_id: int, payload: ChatCreateRequest) -> Chat:
member_ids = list(dict.fromkeys(payload.member_ids))
member_ids = [member_id for member_id in member_ids if member_id != creator_id]
if payload.type == ChatType.PRIVATE and len(member_ids) != 1:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Private chat requires exactly one target user.",
)
if payload.type in {ChatType.GROUP, ChatType.CHANNEL} and not payload.title:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Group and channel chats require title.",
)
for member_id in member_ids:
user = await get_user_by_id(db, member_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {member_id} not found")
chat = await repository.create_chat(db, chat_type=payload.type, title=payload.title)
await repository.add_chat_member(db, chat_id=chat.id, user_id=creator_id, role=ChatMemberRole.OWNER)
default_role = ChatMemberRole.MEMBER
for member_id in member_ids:
await repository.add_chat_member(db, chat_id=chat.id, user_id=member_id, role=default_role)
await db.commit()
return chat
async def get_chats_for_user(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
safe_limit = max(1, min(limit, 100))
return await repository.list_user_chats(db, user_id=user_id, limit=safe_limit, before_id=before_id)
async def get_chat_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> tuple[Chat, list]:
chat = await repository.get_chat_by_id(db, chat_id)
if not chat:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
if not membership:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
members = await repository.list_chat_members(db, chat_id=chat_id)
return chat, members
async def ensure_chat_membership(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
if not membership:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")

0
app/config/__init__.py Normal file
View File

41
app/config/settings.py Normal file
View File

@@ -0,0 +1,41 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
app_name: str = "BenyaMessenger"
environment: str = "development"
debug: bool = True
api_v1_prefix: str = "/api/v1"
auto_create_tables: bool = True
secret_key: str = Field(default="change-me-please-12345", min_length=16)
access_token_expire_minutes: int = 30
refresh_token_expire_days: int = 30
jwt_algorithm: str = "HS256"
email_verification_token_expire_hours: int = 24
password_reset_token_expire_hours: int = 1
postgres_dsn: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/messenger"
redis_url: str = "redis://localhost:6379/0"
s3_endpoint_url: str = "http://localhost:9000"
s3_access_key: str = "minioadmin"
s3_secret_key: str = "minioadmin"
s3_region: str = "us-east-1"
s3_bucket_name: str = "messenger-media"
s3_presign_expire_seconds: int = 900
max_upload_size_bytes: int = 104857600
frontend_base_url: str = "http://localhost:5173"
smtp_host: str = "localhost"
smtp_port: int = 1025
smtp_username: str = ""
smtp_password: str = ""
smtp_use_tls: bool = False
smtp_from_email: str = "no-reply@benyamessenger.local"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
settings = Settings()

0
app/database/__init__.py Normal file
View File

15
app/database/base.py Normal file
View File

@@ -0,0 +1,15 @@
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase):
metadata = MetaData(naming_convention=NAMING_CONVENTION)

19
app/database/models.py Normal file
View File

@@ -0,0 +1,19 @@
from app.auth.models import EmailVerificationToken, PasswordResetToken
from app.chats.models import Chat, ChatMember
from app.email.models import EmailLog
from app.media.models import Attachment
from app.messages.models import Message
from app.notifications.models import NotificationLog
from app.users.models import User
__all__ = [
"Attachment",
"Chat",
"ChatMember",
"EmailLog",
"EmailVerificationToken",
"Message",
"NotificationLog",
"PasswordResetToken",
"User",
]

25
app/database/session.py Normal file
View File

@@ -0,0 +1,25 @@
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config.settings import settings
engine = create_async_engine(
settings.postgres_dsn,
echo=settings.debug,
pool_pre_ping=True,
)
AsyncSessionLocal = async_sessionmaker(
bind=engine,
class_=AsyncSession,
autoflush=False,
autocommit=False,
expire_on_commit=False,
)
async def get_db() -> AsyncIterator[AsyncSession]:
async with AsyncSessionLocal() as session:
yield session

1
app/email/__init__.py Normal file
View File

@@ -0,0 +1 @@

16
app/email/models.py Normal file
View File

@@ -0,0 +1,16 @@
from datetime import datetime
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from app.database.base import Base
class EmailLog(Base):
__tablename__ = "email_logs"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
recipient: Mapped[str] = mapped_column(String(255), index=True, nullable=False)
subject: Mapped[str] = mapped_column(String(255), nullable=False)
body: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)

7
app/email/repository.py Normal file
View File

@@ -0,0 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.email.models import EmailLog
async def create_email_log(db: AsyncSession, *, recipient: str, subject: str, body: str) -> None:
db.add(EmailLog(recipient=recipient, subject=subject, body=body))

3
app/email/router.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter
router = APIRouter(prefix="/email", tags=["email"])

7
app/email/schemas.py Normal file
View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel, EmailStr
class EmailPayload(BaseModel):
recipient: EmailStr
subject: str
body: str

23
app/email/service.py Normal file
View File

@@ -0,0 +1,23 @@
import logging
from app.config.settings import settings
logger = logging.getLogger(__name__)
class EmailService:
async def send_verification_email(self, email: str, token: str) -> None:
verify_link = f"{settings.frontend_base_url}/verify-email?token={token}"
subject = "Verify your BenyaMessenger account"
body = f"Open this link to verify your account: {verify_link}"
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
async def send_password_reset_email(self, email: str, token: str) -> None:
reset_link = f"{settings.frontend_base_url}/reset-password?token={token}"
subject = "Reset your BenyaMessenger password"
body = f"Open this link to reset your password: {reset_link}"
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
def get_email_service() -> EmailService:
return EmailService()

43
app/main.py Normal file
View File

@@ -0,0 +1,43 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.auth.router import router as auth_router
from app.chats.router import router as chats_router
from app.config.settings import settings
from app.database import models # noqa: F401
from app.database.base import Base
from app.database.session import engine
from app.media.router import router as media_router
from app.messages.router import router as messages_router
from app.notifications.router import router as notifications_router
from app.realtime.router import router as realtime_router
from app.realtime.service import realtime_gateway
from app.users.router import router as users_router
@asynccontextmanager
async def lifespan(_app: FastAPI):
await realtime_gateway.start()
if settings.auto_create_tables:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
await realtime_gateway.stop()
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
@app.get("/health", tags=["health"])
async def health() -> dict[str, str]:
return {"status": "ok"}
app.include_router(auth_router, prefix=settings.api_v1_prefix)
app.include_router(users_router, prefix=settings.api_v1_prefix)
app.include_router(chats_router, prefix=settings.api_v1_prefix)
app.include_router(messages_router, prefix=settings.api_v1_prefix)
app.include_router(media_router, prefix=settings.api_v1_prefix)
app.include_router(notifications_router, prefix=settings.api_v1_prefix)
app.include_router(realtime_router, prefix=settings.api_v1_prefix)

0
app/media/__init__.py Normal file
View File

16
app/media/models.py Normal file
View File

@@ -0,0 +1,16 @@
from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.base import Base
class Attachment(Base):
__tablename__ = "attachments"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
file_url: Mapped[str] = mapped_column(String(1024), nullable=False)
file_type: Mapped[str] = mapped_column(String(64), nullable=False)
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
message = relationship("Message", back_populates="attachments")

26
app/media/repository.py Normal file
View File

@@ -0,0 +1,26 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.media.models import Attachment
async def create_attachment(
db: AsyncSession,
*,
message_id: int,
file_url: str,
file_type: str,
file_size: int,
) -> Attachment:
attachment = Attachment(
message_id=message_id,
file_url=file_url,
file_type=file_type,
file_size=file_size,
)
db.add(attachment)
await db.flush()
return attachment
async def get_attachment_by_id(db: AsyncSession, attachment_id: int) -> Attachment | None:
return await db.get(Attachment, attachment_id)

27
app/media/router.py Normal file
View File

@@ -0,0 +1,27 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.service import get_current_user
from app.database.session import get_db
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
from app.media.service import generate_upload_url, store_attachment_metadata
from app.users.models import User
router = APIRouter(prefix="/media", tags=["media"])
@router.post("/upload-url", response_model=UploadUrlResponse)
async def create_upload_url(
payload: UploadUrlRequest,
_current_user: User = Depends(get_current_user),
) -> UploadUrlResponse:
return await generate_upload_url(payload)
@router.post("/attachments", response_model=AttachmentRead)
async def create_attachment_metadata(
payload: AttachmentCreateRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> AttachmentRead:
return await store_attachment_metadata(db, user_id=current_user.id, payload=payload)

32
app/media/schemas.py Normal file
View File

@@ -0,0 +1,32 @@
from pydantic import BaseModel, ConfigDict, Field
class UploadUrlRequest(BaseModel):
file_name: str = Field(min_length=1, max_length=255)
file_type: str = Field(min_length=1, max_length=64)
file_size: int = Field(gt=0)
class UploadUrlResponse(BaseModel):
upload_url: str
file_url: str
object_key: str
expires_in: int
required_headers: dict[str, str]
class AttachmentCreateRequest(BaseModel):
message_id: int
file_url: str = Field(min_length=1, max_length=1024)
file_type: str = Field(min_length=1, max_length=64)
file_size: int = Field(gt=0)
class AttachmentRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
message_id: int
file_url: str
file_type: str
file_size: int

127
app/media/service.py Normal file
View File

@@ -0,0 +1,127 @@
import re
from urllib.parse import quote
from uuid import uuid4
import boto3
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.media import repository
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
from app.messages.repository import get_message_by_id
ALLOWED_MIME_TYPES = {
"image/jpeg",
"image/png",
"image/webp",
"video/mp4",
"video/webm",
"audio/mpeg",
"audio/ogg",
"audio/wav",
"application/pdf",
"application/zip",
"text/plain",
}
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
def _sanitize_filename(file_name: str) -> str:
sanitized = _SAFE_NAME_RE.sub("_", file_name).strip("._")
if not sanitized:
sanitized = "file.bin"
return sanitized[:120]
def _build_file_url(bucket: str, object_key: str) -> str:
base = settings.s3_endpoint_url.rstrip("/")
encoded_key = quote(object_key)
return f"{base}/{bucket}/{encoded_key}"
def _allowed_file_url_prefix() -> str:
return f"{settings.s3_endpoint_url.rstrip('/')}/{settings.s3_bucket_name}/"
def _validate_media(file_type: str, file_size: int) -> None:
if file_type not in ALLOWED_MIME_TYPES:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file type")
if file_size > settings.max_upload_size_bytes:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="File size exceeds limit")
def _get_s3_client():
return boto3.client(
"s3",
endpoint_url=settings.s3_endpoint_url,
aws_access_key_id=settings.s3_access_key,
aws_secret_access_key=settings.s3_secret_key,
region_name=settings.s3_region,
config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
)
async def generate_upload_url(payload: UploadUrlRequest) -> UploadUrlResponse:
_validate_media(payload.file_type, payload.file_size)
file_name = _sanitize_filename(payload.file_name)
object_key = f"uploads/{uuid4()}-{file_name}"
bucket = settings.s3_bucket_name
try:
s3_client = _get_s3_client()
upload_url = s3_client.generate_presigned_url(
"put_object",
Params={
"Bucket": bucket,
"Key": object_key,
"ContentType": payload.file_type,
},
ExpiresIn=settings.s3_presign_expire_seconds,
HttpMethod="PUT",
)
except (BotoCoreError, ClientError) as exc:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Storage service unavailable") from exc
return UploadUrlResponse(
upload_url=upload_url,
file_url=_build_file_url(bucket, object_key),
object_key=object_key,
expires_in=settings.s3_presign_expire_seconds,
required_headers={"Content-Type": payload.file_type},
)
async def store_attachment_metadata(
db: AsyncSession,
*,
user_id: int,
payload: AttachmentCreateRequest,
) -> AttachmentRead:
_validate_media(payload.file_type, payload.file_size)
if not payload.file_url.startswith(_allowed_file_url_prefix()):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file URL")
message = await get_message_by_id(db, payload.message_id)
if not message:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
if message.sender_id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the message sender can attach files",
)
attachment = await repository.create_attachment(
db,
message_id=payload.message_id,
file_url=payload.file_url,
file_type=payload.file_type,
file_size=payload.file_size,
)
await db.commit()
await db.refresh(attachment)
return AttachmentRead.model_validate(attachment)

0
app/messages/__init__.py Normal file
View File

44
app/messages/models.py Normal file
View File

@@ -0,0 +1,44 @@
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.base import Base
if TYPE_CHECKING:
from app.chats.models import Chat
from app.media.models import Attachment
from app.users.models import User
class MessageType(str, Enum):
TEXT = "text"
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
VOICE = "voice"
FILE = "file"
CIRCLE_VIDEO = "circle_video"
class Message(Base):
__tablename__ = "messages"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
sender_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
type: Mapped[MessageType] = mapped_column(SAEnum(MessageType), nullable=False, default=MessageType.TEXT)
text: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
chat: Mapped["Chat"] = relationship(back_populates="messages")
sender: Mapped["User"] = relationship(back_populates="sent_messages")
attachments: Mapped[list["Attachment"]] = relationship(back_populates="message", cascade="all, delete-orphan")

View File

@@ -0,0 +1,41 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.messages.models import Message, MessageType
async def create_message(
db: AsyncSession,
*,
chat_id: int,
sender_id: int,
message_type: MessageType,
text: str | None,
) -> Message:
message = Message(chat_id=chat_id, sender_id=sender_id, type=message_type, text=text)
db.add(message)
await db.flush()
return message
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
result = await db.execute(select(Message).where(Message.id == message_id))
return result.scalar_one_or_none()
async def list_chat_messages(
db: AsyncSession,
chat_id: int,
*,
limit: int = 50,
before_id: int | None = None,
) -> list[Message]:
query = select(Message).where(Message.chat_id == chat_id)
if before_id is not None:
query = query.where(Message.id < before_id)
result = await db.execute(query.order_by(Message.id.desc()).limit(limit))
return list(result.scalars().all())
async def delete_message(db: AsyncSession, message: Message) -> None:
await db.delete(message)

49
app/messages/router.py Normal file
View File

@@ -0,0 +1,49 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.service import get_current_user
from app.database.session import get_db
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageUpdateRequest
from app.messages.service import create_chat_message, delete_message, get_messages, update_message
from app.users.models import User
router = APIRouter(prefix="/messages", tags=["messages"])
@router.post("", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
async def create_message(
payload: MessageCreateRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> MessageRead:
return await create_chat_message(db, sender_id=current_user.id, payload=payload)
@router.get("/{chat_id}", response_model=list[MessageRead])
async def list_messages(
chat_id: int,
limit: int = 50,
before_id: int | None = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> list[MessageRead]:
return await get_messages(db, chat_id=chat_id, user_id=current_user.id, limit=limit, before_id=before_id)
@router.put("/{message_id}", response_model=MessageRead)
async def edit_message(
message_id: int,
payload: MessageUpdateRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> MessageRead:
return await update_message(db, message_id=message_id, user_id=current_user.id, payload=payload)
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_message(
message_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> None:
await delete_message(db, message_id=message_id, user_id=current_user.id)

27
app/messages/schemas.py Normal file
View File

@@ -0,0 +1,27 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.messages.models import MessageType
class MessageRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
chat_id: int
sender_id: int
type: MessageType
text: str | None
created_at: datetime
updated_at: datetime
class MessageCreateRequest(BaseModel):
chat_id: int
type: MessageType = MessageType.TEXT
text: str | None = Field(default=None, max_length=4096)
class MessageUpdateRequest(BaseModel):
text: str = Field(min_length=1, max_length=4096)

67
app/messages/service.py Normal file
View File

@@ -0,0 +1,67 @@
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.service import ensure_chat_membership
from app.messages import repository
from app.messages.models import Message
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
message = await repository.create_message(
db,
chat_id=payload.chat_id,
sender_id=sender_id,
message_type=payload.type,
text=payload.text,
)
await db.commit()
await db.refresh(message)
return message
async def get_messages(
db: AsyncSession,
*,
chat_id: int,
user_id: int,
limit: int = 50,
before_id: int | None = None,
) -> list[Message]:
await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)
safe_limit = max(1, min(limit, 100))
return await repository.list_chat_messages(db, chat_id, limit=safe_limit, before_id=before_id)
async def update_message(
db: AsyncSession,
*,
message_id: int,
user_id: int,
payload: MessageUpdateRequest,
) -> Message:
message = await repository.get_message_by_id(db, message_id)
if not message:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
if message.sender_id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can edit only your own messages")
message.text = payload.text
await db.commit()
await db.refresh(message)
return message
async def delete_message(db: AsyncSession, *, message_id: int, user_id: int) -> None:
message = await repository.get_message_by_id(db, message_id)
if not message:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
if message.sender_id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can delete only your own messages")
await repository.delete_message(db, message)
await db.commit()

View File

View File

@@ -0,0 +1,16 @@
from datetime import datetime
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from app.database.base import Base
class NotificationLog(Base):
__tablename__ = "notification_logs"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
user_id: Mapped[int] = mapped_column(index=True)
event_type: Mapped[str] = mapped_column(String(64), index=True)
payload: Mapped[str] = mapped_column(String(1024))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)

View File

@@ -0,0 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.notifications.models import NotificationLog
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))

View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter
router = APIRouter(prefix="/notifications", tags=["notifications"])

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class NotificationRequest(BaseModel):
user_id: int
event_type: str
payload: dict

View File

@@ -0,0 +1,13 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.notifications.repository import create_notification_log
from app.notifications.schemas import NotificationRequest
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
await create_notification_log(
db,
user_id=payload.user_id,
event_type=payload.event_type,
payload=payload.payload.__repr__(),
)

0
app/realtime/__init__.py Normal file
View File

10
app/realtime/models.py Normal file
View File

@@ -0,0 +1,10 @@
from dataclasses import dataclass
from fastapi import WebSocket
@dataclass(slots=True)
class ConnectionContext:
user_id: int
connection_id: str
websocket: WebSocket

View File

@@ -0,0 +1,48 @@
import json
from collections.abc import Awaitable, Callable
from redis.asyncio import Redis
from app.config.settings import settings
class RedisRealtimeRepository:
def __init__(self) -> None:
self._redis: Redis | None = None
self._pubsub = None
async def connect(self) -> None:
if self._redis:
return
self._redis = Redis.from_url(settings.redis_url, decode_responses=True)
self._pubsub = self._redis.pubsub()
await self._pubsub.psubscribe("chat:*")
async def close(self) -> None:
if self._pubsub:
await self._pubsub.close()
self._pubsub = None
if self._redis:
await self._redis.aclose()
self._redis = None
async def publish_event(self, channel: str, payload: dict) -> None:
if not self._redis:
await self.connect()
assert self._redis is not None
await self._redis.publish(channel, json.dumps(payload))
async def consume(self, handler: Callable[[str, dict], Awaitable[None]]) -> None:
if not self._pubsub:
await self.connect()
assert self._pubsub is not None
while True:
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if not message:
continue
channel = message.get("channel")
data = message.get("data")
if not channel or not isinstance(data, str):
continue
payload = json.loads(data)
await handler(channel, payload)

76
app/realtime/router.py Normal file
View File

@@ -0,0 +1,76 @@
from datetime import datetime, timezone
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
from pydantic import ValidationError
from app.auth.service import get_current_user_for_ws
from app.database.session import AsyncSessionLocal
from app.realtime.schemas import (
ChatEventPayload,
IncomingRealtimeEvent,
MessageStatusPayload,
OutgoingRealtimeEvent,
SendMessagePayload,
)
from app.realtime.service import realtime_gateway
router = APIRouter(prefix="/realtime", tags=["realtime"])
@router.websocket("/ws")
async def websocket_gateway(websocket: WebSocket) -> None:
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
async with AsyncSessionLocal() as db:
try:
user = await get_current_user_for_ws(token, db)
except Exception:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
await websocket.accept()
connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)
try:
while True:
raw_data = await websocket.receive_json()
try:
event = IncomingRealtimeEvent.model_validate(raw_data)
await _dispatch_event(db, user.id, event)
except ValidationError:
await websocket.send_json(
OutgoingRealtimeEvent(
event="error",
payload={"detail": "Invalid event payload"},
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json")
)
except Exception as exc:
await websocket.send_json(
OutgoingRealtimeEvent(
event="error",
payload={"detail": str(exc)},
timestamp=datetime.now(timezone.utc),
).model_dump(mode="json")
)
except WebSocketDisconnect:
await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)
async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
if event.event == "send_message":
payload = SendMessagePayload.model_validate(event.payload)
await realtime_gateway.handle_send_message(db, user_id, payload)
return
if event.event in {"typing_start", "typing_stop"}:
payload = ChatEventPayload.model_validate(event.payload)
await realtime_gateway.handle_typing_event(db, user_id, payload, event.event)
return
if event.event in {"message_read", "message_delivered"}:
payload = MessageStatusPayload.model_validate(event.payload)
await realtime_gateway.handle_message_status(db, user_id, payload, event.event)
return

48
app/realtime/schemas.py Normal file
View File

@@ -0,0 +1,48 @@
from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from app.messages.models import MessageType
RealtimeEventName = Literal[
"connect",
"disconnect",
"send_message",
"receive_message",
"typing_start",
"typing_stop",
"message_read",
"message_delivered",
"error",
]
class SendMessagePayload(BaseModel):
chat_id: int
type: MessageType = MessageType.TEXT
text: str | None = Field(default=None, max_length=4096)
temp_id: str | None = None
class ChatEventPayload(BaseModel):
chat_id: int
class MessageStatusPayload(BaseModel):
chat_id: int
message_id: int
class IncomingRealtimeEvent(BaseModel):
event: Literal["send_message", "typing_start", "typing_stop", "message_read", "message_delivered"]
payload: dict[str, Any]
class OutgoingRealtimeEvent(BaseModel):
model_config = ConfigDict(from_attributes=True)
event: RealtimeEventName
payload: dict[str, Any]
timestamp: datetime

178
app/realtime/service.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import contextlib
from collections import defaultdict
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import WebSocket
from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.repository import list_user_chat_ids
from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead
from app.messages.service import create_chat_message
from app.realtime.models import ConnectionContext
from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
class RealtimeGateway:
def __init__(self) -> None:
self._repo = RedisRealtimeRepository()
self._consume_task: asyncio.Task | None = None
self._distributed_enabled = False
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
async def start(self) -> None:
try:
await self._repo.connect()
if not self._consume_task:
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
self._distributed_enabled = True
except Exception:
self._distributed_enabled = False
async def stop(self) -> None:
if self._consume_task:
self._consume_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._consume_task
self._consume_task = None
await self._repo.close()
self._distributed_enabled = False
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
connection_id = str(uuid4())
self._connections[user_id][connection_id] = ConnectionContext(
user_id=user_id,
connection_id=connection_id,
websocket=websocket,
)
for chat_id in user_chat_ids:
self._chat_subscribers[chat_id].add(user_id)
await self._send_user_event(
user_id,
OutgoingRealtimeEvent(
event="connect",
payload={"connection_id": connection_id},
timestamp=datetime.now(timezone.utc),
),
)
return connection_id
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
user_connections = self._connections.get(user_id, {})
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id in user_chat_ids:
subscribers = self._chat_subscribers.get(chat_id)
if not subscribers:
continue
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
message = await create_chat_message(
db,
sender_id=user_id,
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
)
message_data = MessageRead.model_validate(message).model_dump(mode="json")
await self._publish_chat_event(
payload.chat_id,
event="receive_message",
payload={
"chat_id": payload.chat_id,
"message": message_data,
"temp_id": payload.temp_id,
"sender_id": user_id,
},
)
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={"chat_id": payload.chat_id, "user_id": user_id},
)
async def handle_message_status(
self,
db: AsyncSession,
user_id: int,
payload: MessageStatusPayload,
event: str,
) -> None:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
await self._publish_chat_event(
payload.chat_id,
event=event,
payload={
"chat_id": payload.chat_id,
"message_id": payload.message_id,
"user_id": user_id,
},
)
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
return await list_user_chat_ids(db, user_id=user_id)
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
chat_id = self._extract_chat_id(channel)
if chat_id is None:
return
subscribers = self._chat_subscribers.get(chat_id, set())
if not subscribers:
return
event = OutgoingRealtimeEvent(
event=payload.get("event", "error"),
payload=payload.get("payload", {}),
timestamp=datetime.now(timezone.utc),
)
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
event_payload = {
"event": event,
"payload": payload,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if self._distributed_enabled:
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
return
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
user_connections = self._connections.get(user_id, {})
if not user_connections:
return
disconnected: list[str] = []
for connection_id, context in user_connections.items():
try:
await context.websocket.send_json(event.model_dump(mode="json"))
except Exception:
disconnected.append(connection_id)
for connection_id in disconnected:
user_connections.pop(connection_id, None)
if not user_connections:
self._connections.pop(user_id, None)
for chat_id, subscribers in list(self._chat_subscribers.items()):
subscribers.discard(user_id)
if not subscribers:
self._chat_subscribers.pop(chat_id, None)
@staticmethod
def _extract_chat_id(channel: str) -> int | None:
if not channel.startswith("chat:"):
return None
chat_id = channel.split(":", maxsplit=1)[1]
if not chat_id.isdigit():
return None
return int(chat_id)
realtime_gateway = RealtimeGateway()

0
app/users/__init__.py Normal file
View File

41
app/users/models.py Normal file
View File

@@ -0,0 +1,41 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.base import Base
if TYPE_CHECKING:
from app.auth.models import EmailVerificationToken, PasswordResetToken
from app.chats.models import ChatMember
from app.messages.models import Message
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
password_hash: Mapped[str] = mapped_column(String(255))
avatar_url: Mapped[str | None] = mapped_column(String(512), nullable=True)
email_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
memberships: Mapped[list["ChatMember"]] = relationship(back_populates="user", cascade="all, delete-orphan")
sent_messages: Mapped[list["Message"]] = relationship(back_populates="sender")
email_verification_tokens: Mapped[list["EmailVerificationToken"]] = relationship(
back_populates="user",
cascade="all, delete-orphan",
)
password_reset_tokens: Mapped[list["PasswordResetToken"]] = relationship(
back_populates="user",
cascade="all, delete-orphan",
)

26
app/users/repository.py Normal file
View File

@@ -0,0 +1,26 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.users.models import User
async def create_user(db: AsyncSession, *, email: str, username: str, password_hash: str) -> User:
user = User(email=email, username=username, password_hash=password_hash, email_verified=False)
db.add(user)
await db.flush()
return user
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
result = await db.execute(select(User).where(User.username == username))
return result.scalar_one_or_none()

43
app/users/router.py Normal file
View File

@@ -0,0 +1,43 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.service import get_current_user
from app.database.session import get_db
from app.users.models import User
from app.users.schemas import UserProfileUpdate, UserRead
from app.users.service import get_user_by_id, get_user_by_username, update_user_profile
router = APIRouter(prefix="/users", tags=["users"])
@router.get("/me", response_model=UserRead)
async def read_me(current_user: User = Depends(get_current_user)) -> UserRead:
return current_user
@router.get("/{user_id}", response_model=UserRead)
async def read_user(user_id: int, db: AsyncSession = Depends(get_db), _current_user: User = Depends(get_current_user)) -> UserRead:
user = await get_user_by_id(db, user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
return user
@router.put("/profile", response_model=UserRead)
async def update_profile(
payload: UserProfileUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> UserRead:
if payload.username and payload.username != current_user.username:
username_owner = await get_user_by_username(db, payload.username)
if username_owner:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username already taken")
updated = await update_user_profile(
db,
current_user,
username=payload.username,
avatar_url=payload.avatar_url,
)
return updated

27
app/users/schemas.py Normal file
View File

@@ -0,0 +1,27 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict, EmailStr, Field
class UserBase(BaseModel):
username: str = Field(min_length=3, max_length=50)
email: EmailStr
class UserCreate(UserBase):
password: str = Field(min_length=8, max_length=128)
class UserRead(UserBase):
model_config = ConfigDict(from_attributes=True)
id: int
avatar_url: str | None = None
email_verified: bool
created_at: datetime
updated_at: datetime
class UserProfileUpdate(BaseModel):
username: str | None = Field(default=None, min_length=3, max_length=50)
avatar_url: str | None = Field(default=None, max_length=512)

32
app/users/service.py Normal file
View File

@@ -0,0 +1,32 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.users import repository
from app.users.models import User
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
return await repository.get_user_by_id(db, user_id)
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
return await repository.get_user_by_email(db, email)
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
return await repository.get_user_by_username(db, username)
async def update_user_profile(
db: AsyncSession,
user: User,
*,
username: str | None = None,
avatar_url: str | None = None,
) -> User:
if username is not None:
user.username = username
if avatar_url is not None:
user.avatar_url = avatar_url
await db.commit()
await db.refresh(user)
return user

0
app/utils/__init__.py Normal file
View File

51
app/utils/security.py Normal file
View File

@@ -0,0 +1,51 @@
from datetime import datetime, timedelta, timezone
from secrets import token_urlsafe
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.config.settings import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def verify_password(password: str, hashed_password: str) -> bool:
return pwd_context.verify(password, hashed_password)
def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str:
expire = datetime.now(timezone.utc) + expires_delta
payload = {"sub": subject, "type": token_type, "exp": expire}
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
def create_access_token(subject: str) -> str:
return _create_token(
subject=subject,
token_type="access",
expires_delta=timedelta(minutes=settings.access_token_expire_minutes),
)
def create_refresh_token(subject: str) -> str:
return _create_token(
subject=subject,
token_type="refresh",
expires_delta=timedelta(days=settings.refresh_token_expire_days),
)
def decode_token(token: str) -> dict:
try:
return jwt.decode(token, settings.secret_key, algorithms=[settings.jwt_algorithm])
except JWTError as exc:
raise ValueError("Invalid token") from exc
def generate_random_token() -> str:
return token_urlsafe(48)