first commit
This commit is contained in:
32
.env.example
Normal file
32
.env.example
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
APP_NAME=BenyaMessenger
|
||||||
|
ENVIRONMENT=development
|
||||||
|
DEBUG=true
|
||||||
|
API_V1_PREFIX=/api/v1
|
||||||
|
AUTO_CREATE_TABLES=true
|
||||||
|
|
||||||
|
SECRET_KEY=change-me-please-with-a-long-random-secret
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
EMAIL_VERIFICATION_TOKEN_EXPIRE_HOURS=24
|
||||||
|
PASSWORD_RESET_TOKEN_EXPIRE_HOURS=1
|
||||||
|
|
||||||
|
POSTGRES_DSN=postgresql+asyncpg://postgres:postgres@localhost:5432/messenger
|
||||||
|
REDIS_URL=redis://localhost:6379/0
|
||||||
|
|
||||||
|
S3_ENDPOINT_URL=http://localhost:9000
|
||||||
|
S3_ACCESS_KEY=minioadmin
|
||||||
|
S3_SECRET_KEY=minioadmin
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_BUCKET_NAME=messenger-media
|
||||||
|
S3_PRESIGN_EXPIRE_SECONDS=900
|
||||||
|
MAX_UPLOAD_SIZE_BYTES=104857600
|
||||||
|
|
||||||
|
FRONTEND_BASE_URL=http://localhost:5173
|
||||||
|
|
||||||
|
SMTP_HOST=localhost
|
||||||
|
SMTP_PORT=1025
|
||||||
|
SMTP_USERNAME=
|
||||||
|
SMTP_PASSWORD=
|
||||||
|
SMTP_USE_TLS=false
|
||||||
|
SMTP_FROM_EMAIL=no-reply@benyamessenger.local
|
||||||
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.venv/
|
||||||
|
.venv312/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.idea/
|
||||||
|
.env
|
||||||
132
ARCHITECTURE.md
Normal file
132
ARCHITECTURE.md
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
# Messaging Platform Architecture
|
||||||
|
|
||||||
|
## 1) Backend Architecture (FastAPI)
|
||||||
|
|
||||||
|
### High-Level Components
|
||||||
|
- API Gateway (`FastAPI` REST + WebSocket entrypoint)
|
||||||
|
- Auth Service (JWT, refresh token, email verification, password reset)
|
||||||
|
- User Service (profile, avatar metadata, user lookup)
|
||||||
|
- Chat Service (private/group/channel lifecycle and membership)
|
||||||
|
- Message Service (create/list/edit/delete messages with pagination)
|
||||||
|
- Realtime Gateway (WebSocket session management, typing/read events)
|
||||||
|
- Media Service (presigned upload URL + attachment metadata)
|
||||||
|
- Notification Service (mention and offline notifications)
|
||||||
|
- Email Service (verification and reset email delivery abstraction)
|
||||||
|
- Worker Layer (`Celery`) for async tasks (email, push, cleanup)
|
||||||
|
|
||||||
|
### Runtime Topology
|
||||||
|
- Stateless API instances behind a load balancer
|
||||||
|
- PostgreSQL for source-of-truth data
|
||||||
|
- Redis for cache, rate-limit counters, and pub/sub fan-out
|
||||||
|
- MinIO (S3-compatible) for media blobs
|
||||||
|
- Celery workers + Redis broker for background jobs
|
||||||
|
|
||||||
|
### Realtime Scaling
|
||||||
|
- Each WebSocket node subscribes to Redis channels by chat/user scope.
|
||||||
|
- Incoming `send_message` persists to PostgreSQL first.
|
||||||
|
- Message event is published to Redis; all nodes broadcast to online members.
|
||||||
|
- Offline recipients are queued for notifications.
|
||||||
|
|
||||||
|
### Security Layers
|
||||||
|
- JWT access + refresh tokens (short-lived access, longer refresh)
|
||||||
|
- `bcrypt` password hashing
|
||||||
|
- Email verification required before login
|
||||||
|
- Tokenized password reset flow
|
||||||
|
- DB indexes on auth and chat/message lookup paths
|
||||||
|
- Ready extension points for rate limiting and anti-spam middleware
|
||||||
|
|
||||||
|
### Suggested Backend Tree
|
||||||
|
```text
|
||||||
|
backend/
|
||||||
|
app/
|
||||||
|
main.py
|
||||||
|
config/
|
||||||
|
settings.py
|
||||||
|
database/
|
||||||
|
base.py
|
||||||
|
session.py
|
||||||
|
models.py
|
||||||
|
auth/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
users/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
chats/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
messages/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
media/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
realtime/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
notifications/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
email/
|
||||||
|
models.py
|
||||||
|
schemas.py
|
||||||
|
repository.py
|
||||||
|
service.py
|
||||||
|
router.py
|
||||||
|
utils/
|
||||||
|
security.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2) Web Client Architecture (React + TypeScript)
|
||||||
|
|
||||||
|
### Layers
|
||||||
|
- `src/api`: typed Axios clients (auth, users, chats, messages, media)
|
||||||
|
- `src/store`: Zustand stores (session, chat list, active chat, realtime state)
|
||||||
|
- `src/chat`: domain logic (message normalization, optimistic updates)
|
||||||
|
- `src/hooks`: composable hooks (`useAuth`, `useChatMessages`, `useWebSocket`)
|
||||||
|
- `src/components`: reusable UI units (message bubble, composer, media picker)
|
||||||
|
- `src/pages`: route-level pages (`Login`, `Register`, `Chats`, `ChatDetail`)
|
||||||
|
|
||||||
|
### Data Flow
|
||||||
|
- HTTP for CRUD and pagination
|
||||||
|
- WebSocket for realtime events (`receive_message`, typing, read receipts)
|
||||||
|
- Optimistic UI on send, rollback on failure
|
||||||
|
- Local cache keyed by chat id + pagination cursor
|
||||||
|
|
||||||
|
## 3) Android Architecture (Kotlin + Compose)
|
||||||
|
|
||||||
|
### Layers (MVVM)
|
||||||
|
- `network`: Retrofit API + WebSocket client
|
||||||
|
- `data`: DTOs + Room entities + DAOs
|
||||||
|
- `repository`: sync strategy between remote and local cache
|
||||||
|
- `viewmodel`: state, intents, side-effects
|
||||||
|
- `ui/screens`: Compose screens (`Auth`, `ChatList`, `ChatRoom`, `Profile`)
|
||||||
|
- `ui/components`: reusable composables
|
||||||
|
|
||||||
|
### Realtime Strategy
|
||||||
|
- WebSocket manager as singleton service (Hilt)
|
||||||
|
- ChatViewModel subscribes to events by selected chat
|
||||||
|
- Persist inbound messages to Room first, then render state from DB flows
|
||||||
|
- Push notifications bridged to deep links into chat screens
|
||||||
|
|
||||||
12
README.md
Normal file
12
README.md
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Benya Messenger
|
||||||
|
|
||||||
|
Backend foundation for a Telegram-like real-time messaging platform.
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
1. Create and activate Python 3.12 virtualenv.
|
||||||
|
2. Install dependencies:
|
||||||
|
pip install -r requirements.txt
|
||||||
|
3. Configure environment from `.env.example`.
|
||||||
|
4. Start API:
|
||||||
|
uvicorn app.main:app --reload --port 8000
|
||||||
38
alembic.ini
Normal file
38
alembic.ini
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
timezone = UTC
|
||||||
|
sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/messenger
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
60
alembic/env.py
Normal file
60
alembic/env.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.database.base import Base
|
||||||
|
from app.database import models # noqa: F401
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
config.set_main_option("sqlalchemy.url", settings.postgres_dsn)
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection: Connection) -> None:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata, compare_type=True)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online() -> None:
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run_migrations_online())
|
||||||
27
alembic/script.py.mako
Normal file
27
alembic/script.py.mako
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
217
alembic/versions/0001_initial_schema.py
Normal file
217
alembic/versions/0001_initial_schema.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""initial schema
|
||||||
|
|
||||||
|
Revision ID: 0001_initial_schema
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-07 22:40:00.000000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "0001_initial_schema"
|
||||||
|
down_revision: Union[str, Sequence[str], None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
chat_type_enum = sa.Enum("PRIVATE", "GROUP", "CHANNEL", name="chattype")
|
||||||
|
chat_role_enum = sa.Enum("OWNER", "ADMIN", "MEMBER", name="chatmemberrole")
|
||||||
|
message_type_enum = sa.Enum(
|
||||||
|
"TEXT",
|
||||||
|
"IMAGE",
|
||||||
|
"VIDEO",
|
||||||
|
"AUDIO",
|
||||||
|
"VOICE",
|
||||||
|
"FILE",
|
||||||
|
"CIRCLE_VIDEO",
|
||||||
|
name="messagetype",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("username", sa.String(length=50), nullable=False),
|
||||||
|
sa.Column("email", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("avatar_url", sa.String(length=512), nullable=True),
|
||||||
|
sa.Column("email_verified", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_users")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||||
|
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||||
|
op.create_index(op.f("ix_users_email_verified"), "users", ["email_verified"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"chats",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("type", chat_type_enum, nullable=False),
|
||||||
|
sa.Column("title", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_chats")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_chats_id"), "chats", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_chats_type"), "chats", ["type"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"chat_members",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chat_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("role", chat_role_enum, nullable=False),
|
||||||
|
sa.Column("joined_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["chat_id"], ["chats.id"], name=op.f("fk_chat_members_chat_id_chats"), ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], name=op.f("fk_chat_members_user_id_users"), ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_chat_members")),
|
||||||
|
sa.UniqueConstraint("chat_id", "user_id", name=op.f("uq_chat_members_chat_id_user_id")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_chat_members_id"), "chat_members", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_chat_members_chat_id"), "chat_members", ["chat_id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_chat_members_user_id"), "chat_members", ["user_id"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"messages",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("chat_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("sender_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("type", message_type_enum, nullable=False),
|
||||||
|
sa.Column("text", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["chat_id"], ["chats.id"], name=op.f("fk_messages_chat_id_chats"), ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["sender_id"], ["users.id"], name=op.f("fk_messages_sender_id_users"), ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_messages_id"), "messages", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_messages_chat_id"), "messages", ["chat_id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_messages_sender_id"), "messages", ["sender_id"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"attachments",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("message_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("file_url", sa.String(length=1024), nullable=False),
|
||||||
|
sa.Column("file_type", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("file_size", sa.Integer(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["message_id"], ["messages.id"], name=op.f("fk_attachments_message_id_messages"), ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_attachments")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_attachments_id"), "attachments", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_attachments_message_id"), "attachments", ["message_id"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"email_verification_tokens",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("token", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"], ["users.id"], name=op.f("fk_email_verification_tokens_user_id_users"), ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_email_verification_tokens")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_email_verification_tokens_id"), "email_verification_tokens", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_email_verification_tokens_user_id"), "email_verification_tokens", ["user_id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_email_verification_tokens_token"), "email_verification_tokens", ["token"], unique=True)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_email_verification_tokens_expires_at"), "email_verification_tokens", ["expires_at"], unique=False
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"password_reset_tokens",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("token", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], name=op.f("fk_password_reset_tokens_user_id_users"), ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_password_reset_tokens")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_password_reset_tokens_id"), "password_reset_tokens", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_password_reset_tokens_user_id"), "password_reset_tokens", ["user_id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_password_reset_tokens_token"), "password_reset_tokens", ["token"], unique=True)
|
||||||
|
op.create_index(op.f("ix_password_reset_tokens_expires_at"), "password_reset_tokens", ["expires_at"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"email_logs",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("recipient", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("subject", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("body", sa.Text(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_email_logs")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_email_logs_id"), "email_logs", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_email_logs_recipient"), "email_logs", ["recipient"], unique=False)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"notification_logs",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("event_type", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("payload", sa.String(length=1024), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_notification_logs")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_notification_logs_id"), "notification_logs", ["id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_notification_logs_user_id"), "notification_logs", ["user_id"], unique=False)
|
||||||
|
op.create_index(op.f("ix_notification_logs_event_type"), "notification_logs", ["event_type"], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f("ix_notification_logs_event_type"), table_name="notification_logs")
|
||||||
|
op.drop_index(op.f("ix_notification_logs_user_id"), table_name="notification_logs")
|
||||||
|
op.drop_index(op.f("ix_notification_logs_id"), table_name="notification_logs")
|
||||||
|
op.drop_table("notification_logs")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_email_logs_recipient"), table_name="email_logs")
|
||||||
|
op.drop_index(op.f("ix_email_logs_id"), table_name="email_logs")
|
||||||
|
op.drop_table("email_logs")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_password_reset_tokens_expires_at"), table_name="password_reset_tokens")
|
||||||
|
op.drop_index(op.f("ix_password_reset_tokens_token"), table_name="password_reset_tokens")
|
||||||
|
op.drop_index(op.f("ix_password_reset_tokens_user_id"), table_name="password_reset_tokens")
|
||||||
|
op.drop_index(op.f("ix_password_reset_tokens_id"), table_name="password_reset_tokens")
|
||||||
|
op.drop_table("password_reset_tokens")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_email_verification_tokens_expires_at"), table_name="email_verification_tokens")
|
||||||
|
op.drop_index(op.f("ix_email_verification_tokens_token"), table_name="email_verification_tokens")
|
||||||
|
op.drop_index(op.f("ix_email_verification_tokens_user_id"), table_name="email_verification_tokens")
|
||||||
|
op.drop_index(op.f("ix_email_verification_tokens_id"), table_name="email_verification_tokens")
|
||||||
|
op.drop_table("email_verification_tokens")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_attachments_message_id"), table_name="attachments")
|
||||||
|
op.drop_index(op.f("ix_attachments_id"), table_name="attachments")
|
||||||
|
op.drop_table("attachments")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_messages_sender_id"), table_name="messages")
|
||||||
|
op.drop_index(op.f("ix_messages_chat_id"), table_name="messages")
|
||||||
|
op.drop_index(op.f("ix_messages_id"), table_name="messages")
|
||||||
|
op.drop_table("messages")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_chat_members_user_id"), table_name="chat_members")
|
||||||
|
op.drop_index(op.f("ix_chat_members_chat_id"), table_name="chat_members")
|
||||||
|
op.drop_index(op.f("ix_chat_members_id"), table_name="chat_members")
|
||||||
|
op.drop_table("chat_members")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_chats_type"), table_name="chats")
|
||||||
|
op.drop_index(op.f("ix_chats_id"), table_name="chats")
|
||||||
|
op.drop_table("chats")
|
||||||
|
|
||||||
|
op.drop_index(op.f("ix_users_email_verified"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_id"), table_name="users")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
message_type_enum.drop(op.get_bind(), checkfirst=False)
|
||||||
|
chat_role_enum.drop(op.get_bind(), checkfirst=False)
|
||||||
|
chat_type_enum.drop(op.get_bind(), checkfirst=False)
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
34
app/auth/models.py
Normal file
34
app/auth/models.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, String
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class EmailVerificationToken(Base):
|
||||||
|
__tablename__ = "email_verification_tokens"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(back_populates="email_verification_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordResetToken(Base):
|
||||||
|
__tablename__ = "password_reset_tokens"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
token: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(back_populates="password_reset_tokens")
|
||||||
46
app/auth/repository.py
Normal file
46
app/auth/repository.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||||
|
|
||||||
|
|
||||||
|
async def create_email_verification_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
|
||||||
|
db.add(
|
||||||
|
EmailVerificationToken(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
expires_at=expires_at,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_email_verification_token(db: AsyncSession, token: str) -> EmailVerificationToken | None:
|
||||||
|
result = await db.execute(select(EmailVerificationToken).where(EmailVerificationToken.token == token))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_email_verification_tokens_for_user(db: AsyncSession, user_id: int) -> None:
|
||||||
|
await db.execute(delete(EmailVerificationToken).where(EmailVerificationToken.user_id == user_id))
|
||||||
|
|
||||||
|
|
||||||
|
async def create_password_reset_token(db: AsyncSession, user_id: int, token: str, expires_at: datetime) -> None:
|
||||||
|
db.add(
|
||||||
|
PasswordResetToken(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
expires_at=expires_at,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_password_reset_token(db: AsyncSession, token: str) -> PasswordResetToken | None:
|
||||||
|
result = await db.execute(select(PasswordResetToken).where(PasswordResetToken.token == token))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_password_reset_tokens_for_user(db: AsyncSession, user_id: int) -> None:
|
||||||
|
await db.execute(delete(PasswordResetToken).where(PasswordResetToken.user_id == user_id))
|
||||||
81
app/auth/router.py
Normal file
81
app/auth/router.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.schemas import (
|
||||||
|
AuthUserResponse,
|
||||||
|
LoginRequest,
|
||||||
|
MessageResponse,
|
||||||
|
RegisterRequest,
|
||||||
|
RequestPasswordResetRequest,
|
||||||
|
ResendVerificationRequest,
|
||||||
|
ResetPasswordRequest,
|
||||||
|
TokenResponse,
|
||||||
|
VerifyEmailRequest,
|
||||||
|
)
|
||||||
|
from app.auth.service import (
|
||||||
|
get_current_user,
|
||||||
|
get_email_sender,
|
||||||
|
login_user,
|
||||||
|
register_user,
|
||||||
|
request_password_reset,
|
||||||
|
resend_verification_email,
|
||||||
|
reset_password,
|
||||||
|
verify_email,
|
||||||
|
)
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.email.service import EmailService
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
payload: RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
|
) -> MessageResponse:
|
||||||
|
await register_user(db, payload, email_service)
|
||||||
|
return MessageResponse(message="Registration successful. Verification email sent.")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
|
||||||
|
return await login_user(db, payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/verify-email", response_model=MessageResponse)
|
||||||
|
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
||||||
|
await verify_email(db, payload)
|
||||||
|
return MessageResponse(message="Email verified successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/resend-verification", response_model=MessageResponse)
|
||||||
|
async def resend_verification(
|
||||||
|
payload: ResendVerificationRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
|
) -> MessageResponse:
|
||||||
|
await resend_verification_email(db, payload, email_service)
|
||||||
|
return MessageResponse(message="If the account exists, a verification email was sent.")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/request-password-reset", response_model=MessageResponse)
|
||||||
|
async def request_password_reset_endpoint(
|
||||||
|
payload: RequestPasswordResetRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
|
) -> MessageResponse:
|
||||||
|
await request_password_reset(db, payload, email_service)
|
||||||
|
return MessageResponse(message="If the account exists, a reset email was sent.")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reset-password", response_model=MessageResponse)
|
||||||
|
async def reset_password_endpoint(payload: ResetPasswordRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
||||||
|
await reset_password(db, payload)
|
||||||
|
return MessageResponse(message="Password reset successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=AuthUserResponse)
|
||||||
|
async def me(current_user: User = Depends(get_current_user)) -> AuthUserResponse:
|
||||||
|
return current_user
|
||||||
53
app/auth/schemas.py
Normal file
53
app/auth/schemas.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
username: str = Field(min_length=3, max_length=50)
|
||||||
|
password: str = Field(min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class VerifyEmailRequest(BaseModel):
|
||||||
|
token: str = Field(min_length=16, max_length=512)
|
||||||
|
|
||||||
|
|
||||||
|
class ResendVerificationRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class RequestPasswordResetRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class ResetPasswordRequest(BaseModel):
|
||||||
|
token: str = Field(min_length=16, max_length=512)
|
||||||
|
new_password: str = Field(min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthUserResponse(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
email: EmailStr
|
||||||
|
username: str
|
||||||
|
avatar_url: str | None = None
|
||||||
|
email_verified: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
195
app/auth/service.py
Normal file
195
app/auth/service.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth import repository as auth_repository
|
||||||
|
from app.auth.schemas import (
|
||||||
|
LoginRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
RequestPasswordResetRequest,
|
||||||
|
ResendVerificationRequest,
|
||||||
|
ResetPasswordRequest,
|
||||||
|
TokenResponse,
|
||||||
|
VerifyEmailRequest,
|
||||||
|
)
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.email.service import EmailService, get_email_service
|
||||||
|
from app.users.models import User
|
||||||
|
from app.users.repository import create_user, get_user_by_email, get_user_by_id, get_user_by_username
|
||||||
|
from app.utils.security import (
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
generate_random_token,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def register_user(
|
||||||
|
db: AsyncSession,
|
||||||
|
payload: RegisterRequest,
|
||||||
|
email_service: EmailService,
|
||||||
|
) -> None:
|
||||||
|
existing_email = await get_user_by_email(db, payload.email)
|
||||||
|
if existing_email:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email is already registered")
|
||||||
|
|
||||||
|
existing_username = await get_user_by_username(db, payload.username)
|
||||||
|
if existing_username:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username is already taken")
|
||||||
|
|
||||||
|
user = await create_user(
|
||||||
|
db,
|
||||||
|
email=payload.email,
|
||||||
|
username=payload.username,
|
||||||
|
password_hash=hash_password(payload.password),
|
||||||
|
)
|
||||||
|
|
||||||
|
verification_token = generate_random_token()
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
|
||||||
|
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||||
|
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await email_service.send_verification_email(payload.email, verification_token)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_email(db: AsyncSession, payload: VerifyEmailRequest) -> None:
|
||||||
|
record = await auth_repository.get_email_verification_token(db, payload.token)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification token")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
if expires_at < now:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification token expired")
|
||||||
|
|
||||||
|
user = await get_user_by_id(db, record.user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||||
|
|
||||||
|
user.email_verified = True
|
||||||
|
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def resend_verification_email(
|
||||||
|
db: AsyncSession,
|
||||||
|
payload: ResendVerificationRequest,
|
||||||
|
email_service: EmailService,
|
||||||
|
) -> None:
|
||||||
|
user = await get_user_by_email(db, payload.email)
|
||||||
|
if not user or user.email_verified:
|
||||||
|
return
|
||||||
|
|
||||||
|
verification_token = generate_random_token()
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.email_verification_token_expire_hours)
|
||||||
|
await auth_repository.delete_email_verification_tokens_for_user(db, user.id)
|
||||||
|
await auth_repository.create_email_verification_token(db, user.id, verification_token, expires_at)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await email_service.send_verification_email(user.email, verification_token)
|
||||||
|
|
||||||
|
|
||||||
|
async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
|
||||||
|
user = await get_user_by_email(db, payload.email)
|
||||||
|
if not user or not verify_password(payload.password, user.password_hash):
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||||
|
|
||||||
|
if not user.email_verified:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=create_access_token(str(user.id)),
|
||||||
|
refresh_token=create_refresh_token(str(user.id)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def request_password_reset(
|
||||||
|
db: AsyncSession,
|
||||||
|
payload: RequestPasswordResetRequest,
|
||||||
|
email_service: EmailService,
|
||||||
|
) -> None:
|
||||||
|
user = await get_user_by_email(db, payload.email)
|
||||||
|
if not user:
|
||||||
|
return
|
||||||
|
|
||||||
|
reset_token = generate_random_token()
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.password_reset_token_expire_hours)
|
||||||
|
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
|
||||||
|
await auth_repository.create_password_reset_token(db, user.id, reset_token, expires_at)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await email_service.send_password_reset_email(user.email, reset_token)
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_password(db: AsyncSession, payload: ResetPasswordRequest) -> None:
|
||||||
|
record = await auth_repository.get_password_reset_token(db, payload.token)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid reset token")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expires_at = record.expires_at if record.expires_at.tzinfo else record.expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
if expires_at < now:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")
|
||||||
|
|
||||||
|
user = await get_user_by_id(db, record.user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||||
|
|
||||||
|
user.password_hash = hash_password(payload.new_password)
|
||||||
|
await auth_repository.delete_password_reset_tokens_for_user(db, user.id)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> User:
|
||||||
|
credentials_error = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = decode_token(token)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise credentials_error from exc
|
||||||
|
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
if not user_id or not str(user_id).isdigit():
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
user = await get_user_by_id(db, int(user_id))
|
||||||
|
if not user:
|
||||||
|
raise credentials_error
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_for_ws(token: str, db: AsyncSession) -> User:
|
||||||
|
try:
|
||||||
|
payload = decode_token(token)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc
|
||||||
|
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
|
||||||
|
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
if not user_id or not str(user_id).isdigit():
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")
|
||||||
|
|
||||||
|
user = await get_user_by_id(db, int(user_id))
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_email_sender() -> EmailService:
|
||||||
|
return get_email_service()
|
||||||
0
app/chats/__init__.py
Normal file
0
app/chats/__init__.py
Normal file
50
app/chats/models.py
Normal file
50
app/chats/models.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, String, UniqueConstraint, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.messages.models import Message
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class ChatType(str, Enum):
|
||||||
|
PRIVATE = "private"
|
||||||
|
GROUP = "group"
|
||||||
|
CHANNEL = "channel"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMemberRole(str, Enum):
|
||||||
|
OWNER = "owner"
|
||||||
|
ADMIN = "admin"
|
||||||
|
MEMBER = "member"
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(Base):
|
||||||
|
__tablename__ = "chats"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
type: Mapped[ChatType] = mapped_column(SAEnum(ChatType), nullable=False, index=True)
|
||||||
|
title: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
|
||||||
|
members: Mapped[list["ChatMember"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||||
|
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMember(Base):
|
||||||
|
__tablename__ = "chat_members"
|
||||||
|
__table_args__ = (UniqueConstraint("chat_id", "user_id", name="uq_chat_members_chat_id_user_id"),)
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
role: Mapped[ChatMemberRole] = mapped_column(SAEnum(ChatMemberRole), nullable=False, default=ChatMemberRole.MEMBER)
|
||||||
|
joined_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
|
||||||
|
chat: Mapped["Chat"] = relationship(back_populates="members")
|
||||||
|
user: Mapped["User"] = relationship(back_populates="memberships")
|
||||||
62
app/chats/repository.py
Normal file
62
app/chats/repository.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from sqlalchemy import Select, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.chats.models import Chat, ChatMember, ChatMemberRole, ChatType
|
||||||
|
|
||||||
|
|
||||||
|
async def create_chat(db: AsyncSession, *, chat_type: ChatType, title: str | None) -> Chat:
|
||||||
|
chat = Chat(type=chat_type, title=title)
|
||||||
|
db.add(chat)
|
||||||
|
await db.flush()
|
||||||
|
return chat
|
||||||
|
|
||||||
|
|
||||||
|
async def add_chat_member(db: AsyncSession, *, chat_id: int, user_id: int, role: ChatMemberRole) -> ChatMember:
|
||||||
|
member = ChatMember(chat_id=chat_id, user_id=user_id, role=role)
|
||||||
|
db.add(member)
|
||||||
|
await db.flush()
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
|
def _user_chats_query(user_id: int) -> Select[tuple[Chat]]:
|
||||||
|
return (
|
||||||
|
select(Chat)
|
||||||
|
.join(ChatMember, ChatMember.chat_id == Chat.id)
|
||||||
|
.where(ChatMember.user_id == user_id)
|
||||||
|
.order_by(Chat.id.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_user_chats(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||||
|
query = _user_chats_query(user_id).limit(limit)
|
||||||
|
if before_id is not None:
|
||||||
|
query = query.where(Chat.id < before_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_by_id(db: AsyncSession, chat_id: int) -> Chat | None:
|
||||||
|
result = await db.execute(select(Chat).where(Chat.id == chat_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_member(db: AsyncSession, *, chat_id: int, user_id: int) -> ChatMember | None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(ChatMember).where(
|
||||||
|
ChatMember.chat_id == chat_id,
|
||||||
|
ChatMember.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def list_chat_members(db: AsyncSession, *, chat_id: int) -> list[ChatMember]:
|
||||||
|
result = await db.execute(select(ChatMember).where(ChatMember.chat_id == chat_id).order_by(ChatMember.id.asc()))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def list_user_chat_ids(db: AsyncSession, *, user_id: int) -> list[int]:
|
||||||
|
result = await db.execute(
|
||||||
|
select(ChatMember.chat_id).where(ChatMember.user_id == user_id).order_by(ChatMember.chat_id.asc())
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
45
app/chats/router.py
Normal file
45
app/chats/router.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user
|
||||||
|
from app.chats.schemas import ChatCreateRequest, ChatDetailRead, ChatRead
|
||||||
|
from app.chats.service import create_chat_for_user, get_chat_for_user, get_chats_for_user
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[ChatRead])
|
||||||
|
async def list_chats(
|
||||||
|
limit: int = 50,
|
||||||
|
before_id: int | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> list[ChatRead]:
|
||||||
|
return await get_chats_for_user(db, user_id=current_user.id, limit=limit, before_id=before_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ChatRead)
|
||||||
|
async def create_chat(
|
||||||
|
payload: ChatCreateRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> ChatRead:
|
||||||
|
return await create_chat_for_user(db, creator_id=current_user.id, payload=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{chat_id}", response_model=ChatDetailRead)
|
||||||
|
async def get_chat(
|
||||||
|
chat_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> ChatDetailRead:
|
||||||
|
chat, members = await get_chat_for_user(db, chat_id=chat_id, user_id=current_user.id)
|
||||||
|
return ChatDetailRead(
|
||||||
|
id=chat.id,
|
||||||
|
type=chat.type,
|
||||||
|
title=chat.title,
|
||||||
|
created_at=chat.created_at,
|
||||||
|
members=members,
|
||||||
|
)
|
||||||
33
app/chats/schemas.py
Normal file
33
app/chats/schemas.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from app.chats.models import ChatMemberRole, ChatType
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRead(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
type: ChatType
|
||||||
|
title: str | None = None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMemberRead(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
role: ChatMemberRole
|
||||||
|
joined_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDetailRead(ChatRead):
|
||||||
|
members: list[ChatMemberRead]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCreateRequest(BaseModel):
|
||||||
|
type: ChatType
|
||||||
|
title: str | None = Field(default=None, max_length=255)
|
||||||
|
member_ids: list[int] = Field(default_factory=list)
|
||||||
62
app/chats/service.py
Normal file
62
app/chats/service.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.chats import repository
|
||||||
|
from app.chats.models import Chat, ChatMemberRole, ChatType
|
||||||
|
from app.chats.schemas import ChatCreateRequest
|
||||||
|
from app.users.repository import get_user_by_id
|
||||||
|
|
||||||
|
|
||||||
|
async def create_chat_for_user(db: AsyncSession, *, creator_id: int, payload: ChatCreateRequest) -> Chat:
|
||||||
|
member_ids = list(dict.fromkeys(payload.member_ids))
|
||||||
|
member_ids = [member_id for member_id in member_ids if member_id != creator_id]
|
||||||
|
|
||||||
|
if payload.type == ChatType.PRIVATE and len(member_ids) != 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="Private chat requires exactly one target user.",
|
||||||
|
)
|
||||||
|
if payload.type in {ChatType.GROUP, ChatType.CHANNEL} and not payload.title:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="Group and channel chats require title.",
|
||||||
|
)
|
||||||
|
|
||||||
|
for member_id in member_ids:
|
||||||
|
user = await get_user_by_id(db, member_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {member_id} not found")
|
||||||
|
|
||||||
|
chat = await repository.create_chat(db, chat_type=payload.type, title=payload.title)
|
||||||
|
await repository.add_chat_member(db, chat_id=chat.id, user_id=creator_id, role=ChatMemberRole.OWNER)
|
||||||
|
|
||||||
|
default_role = ChatMemberRole.MEMBER
|
||||||
|
for member_id in member_ids:
|
||||||
|
await repository.add_chat_member(db, chat_id=chat.id, user_id=member_id, role=default_role)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
return chat
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chats_for_user(db: AsyncSession, *, user_id: int, limit: int = 50, before_id: int | None = None) -> list[Chat]:
|
||||||
|
safe_limit = max(1, min(limit, 100))
|
||||||
|
return await repository.list_user_chats(db, user_id=user_id, limit=safe_limit, before_id=before_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chat_for_user(db: AsyncSession, *, chat_id: int, user_id: int) -> tuple[Chat, list]:
|
||||||
|
chat = await repository.get_chat_by_id(db, chat_id)
|
||||||
|
if not chat:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||||
|
|
||||||
|
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||||
|
if not membership:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||||
|
|
||||||
|
members = await repository.list_chat_members(db, chat_id=chat_id)
|
||||||
|
return chat, members
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_chat_membership(db: AsyncSession, *, chat_id: int, user_id: int) -> None:
|
||||||
|
membership = await repository.get_chat_member(db, chat_id=chat_id, user_id=user_id)
|
||||||
|
if not membership:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You are not a member of this chat")
|
||||||
0
app/config/__init__.py
Normal file
0
app/config/__init__.py
Normal file
41
app/config/settings.py
Normal file
41
app/config/settings.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
app_name: str = "BenyaMessenger"
|
||||||
|
environment: str = "development"
|
||||||
|
debug: bool = True
|
||||||
|
api_v1_prefix: str = "/api/v1"
|
||||||
|
auto_create_tables: bool = True
|
||||||
|
|
||||||
|
secret_key: str = Field(default="change-me-please-12345", min_length=16)
|
||||||
|
access_token_expire_minutes: int = 30
|
||||||
|
refresh_token_expire_days: int = 30
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
email_verification_token_expire_hours: int = 24
|
||||||
|
password_reset_token_expire_hours: int = 1
|
||||||
|
|
||||||
|
postgres_dsn: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/messenger"
|
||||||
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
s3_endpoint_url: str = "http://localhost:9000"
|
||||||
|
s3_access_key: str = "minioadmin"
|
||||||
|
s3_secret_key: str = "minioadmin"
|
||||||
|
s3_region: str = "us-east-1"
|
||||||
|
s3_bucket_name: str = "messenger-media"
|
||||||
|
s3_presign_expire_seconds: int = 900
|
||||||
|
max_upload_size_bytes: int = 104857600
|
||||||
|
frontend_base_url: str = "http://localhost:5173"
|
||||||
|
|
||||||
|
smtp_host: str = "localhost"
|
||||||
|
smtp_port: int = 1025
|
||||||
|
smtp_username: str = ""
|
||||||
|
smtp_password: str = ""
|
||||||
|
smtp_use_tls: bool = False
|
||||||
|
smtp_from_email: str = "no-reply@benyamessenger.local"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
15
app/database/base.py
Normal file
15
app/database/base.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from sqlalchemy import MetaData
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
|
NAMING_CONVENTION = {
|
||||||
|
"ix": "ix_%(column_0_label)s",
|
||||||
|
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||||
|
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||||
|
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||||
|
"pk": "pk_%(table_name)s",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||||
19
app/database/models.py
Normal file
19
app/database/models.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||||
|
from app.chats.models import Chat, ChatMember
|
||||||
|
from app.email.models import EmailLog
|
||||||
|
from app.media.models import Attachment
|
||||||
|
from app.messages.models import Message
|
||||||
|
from app.notifications.models import NotificationLog
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Attachment",
|
||||||
|
"Chat",
|
||||||
|
"ChatMember",
|
||||||
|
"EmailLog",
|
||||||
|
"EmailVerificationToken",
|
||||||
|
"Message",
|
||||||
|
"NotificationLog",
|
||||||
|
"PasswordResetToken",
|
||||||
|
"User",
|
||||||
|
]
|
||||||
25
app/database/session.py
Normal file
25
app/database/session.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.postgres_dsn,
|
||||||
|
echo=settings.debug,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
AsyncSessionLocal = async_sessionmaker(
|
||||||
|
bind=engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
autoflush=False,
|
||||||
|
autocommit=False,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncIterator[AsyncSession]:
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
yield session
|
||||||
1
app/email/__init__.py
Normal file
1
app/email/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
16
app/email/models.py
Normal file
16
app/email/models.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String, Text, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class EmailLog(Base):
|
||||||
|
__tablename__ = "email_logs"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
recipient: Mapped[str] = mapped_column(String(255), index=True, nullable=False)
|
||||||
|
subject: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
body: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
7
app/email/repository.py
Normal file
7
app/email/repository.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.email.models import EmailLog
|
||||||
|
|
||||||
|
|
||||||
|
async def create_email_log(db: AsyncSession, *, recipient: str, subject: str, body: str) -> None:
|
||||||
|
db.add(EmailLog(recipient=recipient, subject=subject, body=body))
|
||||||
3
app/email/router.py
Normal file
3
app/email/router.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/email", tags=["email"])
|
||||||
7
app/email/schemas.py
Normal file
7
app/email/schemas.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class EmailPayload(BaseModel):
|
||||||
|
recipient: EmailStr
|
||||||
|
subject: str
|
||||||
|
body: str
|
||||||
23
app/email/service.py
Normal file
23
app/email/service.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailService:
|
||||||
|
async def send_verification_email(self, email: str, token: str) -> None:
|
||||||
|
verify_link = f"{settings.frontend_base_url}/verify-email?token={token}"
|
||||||
|
subject = "Verify your BenyaMessenger account"
|
||||||
|
body = f"Open this link to verify your account: {verify_link}"
|
||||||
|
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
|
||||||
|
|
||||||
|
async def send_password_reset_email(self, email: str, token: str) -> None:
|
||||||
|
reset_link = f"{settings.frontend_base_url}/reset-password?token={token}"
|
||||||
|
subject = "Reset your BenyaMessenger password"
|
||||||
|
body = f"Open this link to reset your password: {reset_link}"
|
||||||
|
logger.info("EMAIL to=%s subject=%s body=%s", email, subject, body)
|
||||||
|
|
||||||
|
|
||||||
|
def get_email_service() -> EmailService:
|
||||||
|
return EmailService()
|
||||||
43
app/main.py
Normal file
43
app/main.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.auth.router import router as auth_router
|
||||||
|
from app.chats.router import router as chats_router
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.database import models # noqa: F401
|
||||||
|
from app.database.base import Base
|
||||||
|
from app.database.session import engine
|
||||||
|
from app.media.router import router as media_router
|
||||||
|
from app.messages.router import router as messages_router
|
||||||
|
from app.notifications.router import router as notifications_router
|
||||||
|
from app.realtime.router import router as realtime_router
|
||||||
|
from app.realtime.service import realtime_gateway
|
||||||
|
from app.users.router import router as users_router
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_app: FastAPI):
|
||||||
|
await realtime_gateway.start()
|
||||||
|
if settings.auto_create_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
await realtime_gateway.stop()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health", tags=["health"])
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(auth_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(users_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(chats_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(messages_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(media_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(notifications_router, prefix=settings.api_v1_prefix)
|
||||||
|
app.include_router(realtime_router, prefix=settings.api_v1_prefix)
|
||||||
0
app/media/__init__.py
Normal file
0
app/media/__init__.py
Normal file
16
app/media/models.py
Normal file
16
app/media/models.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from sqlalchemy import ForeignKey, Integer, String
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Attachment(Base):
|
||||||
|
__tablename__ = "attachments"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
file_url: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||||
|
file_type: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
file_size: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
|
||||||
|
message = relationship("Message", back_populates="attachments")
|
||||||
26
app/media/repository.py
Normal file
26
app/media/repository.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.media.models import Attachment
|
||||||
|
|
||||||
|
|
||||||
|
async def create_attachment(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
message_id: int,
|
||||||
|
file_url: str,
|
||||||
|
file_type: str,
|
||||||
|
file_size: int,
|
||||||
|
) -> Attachment:
|
||||||
|
attachment = Attachment(
|
||||||
|
message_id=message_id,
|
||||||
|
file_url=file_url,
|
||||||
|
file_type=file_type,
|
||||||
|
file_size=file_size,
|
||||||
|
)
|
||||||
|
db.add(attachment)
|
||||||
|
await db.flush()
|
||||||
|
return attachment
|
||||||
|
|
||||||
|
|
||||||
|
async def get_attachment_by_id(db: AsyncSession, attachment_id: int) -> Attachment | None:
|
||||||
|
return await db.get(Attachment, attachment_id)
|
||||||
27
app/media/router.py
Normal file
27
app/media/router.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
|
||||||
|
from app.media.service import generate_upload_url, store_attachment_metadata
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/media", tags=["media"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload-url", response_model=UploadUrlResponse)
|
||||||
|
async def create_upload_url(
|
||||||
|
payload: UploadUrlRequest,
|
||||||
|
_current_user: User = Depends(get_current_user),
|
||||||
|
) -> UploadUrlResponse:
|
||||||
|
return await generate_upload_url(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/attachments", response_model=AttachmentRead)
|
||||||
|
async def create_attachment_metadata(
|
||||||
|
payload: AttachmentCreateRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> AttachmentRead:
|
||||||
|
return await store_attachment_metadata(db, user_id=current_user.id, payload=payload)
|
||||||
32
app/media/schemas.py
Normal file
32
app/media/schemas.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class UploadUrlRequest(BaseModel):
|
||||||
|
file_name: str = Field(min_length=1, max_length=255)
|
||||||
|
file_type: str = Field(min_length=1, max_length=64)
|
||||||
|
file_size: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadUrlResponse(BaseModel):
|
||||||
|
upload_url: str
|
||||||
|
file_url: str
|
||||||
|
object_key: str
|
||||||
|
expires_in: int
|
||||||
|
required_headers: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentCreateRequest(BaseModel):
|
||||||
|
message_id: int
|
||||||
|
file_url: str = Field(min_length=1, max_length=1024)
|
||||||
|
file_type: str = Field(min_length=1, max_length=64)
|
||||||
|
file_size: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentRead(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
message_id: int
|
||||||
|
file_url: str
|
||||||
|
file_type: str
|
||||||
|
file_size: int
|
||||||
127
app/media/service.py
Normal file
127
app/media/service.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import re
|
||||||
|
from urllib.parse import quote
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.client import Config
|
||||||
|
from botocore.exceptions import BotoCoreError, ClientError
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.media import repository
|
||||||
|
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, UploadUrlRequest, UploadUrlResponse
|
||||||
|
from app.messages.repository import get_message_by_id
|
||||||
|
|
||||||
|
ALLOWED_MIME_TYPES = {
|
||||||
|
"image/jpeg",
|
||||||
|
"image/png",
|
||||||
|
"image/webp",
|
||||||
|
"video/mp4",
|
||||||
|
"video/webm",
|
||||||
|
"audio/mpeg",
|
||||||
|
"audio/ogg",
|
||||||
|
"audio/wav",
|
||||||
|
"application/pdf",
|
||||||
|
"application/zip",
|
||||||
|
"text/plain",
|
||||||
|
}
|
||||||
|
|
||||||
|
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename(file_name: str) -> str:
|
||||||
|
sanitized = _SAFE_NAME_RE.sub("_", file_name).strip("._")
|
||||||
|
if not sanitized:
|
||||||
|
sanitized = "file.bin"
|
||||||
|
return sanitized[:120]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_file_url(bucket: str, object_key: str) -> str:
|
||||||
|
base = settings.s3_endpoint_url.rstrip("/")
|
||||||
|
encoded_key = quote(object_key)
|
||||||
|
return f"{base}/{bucket}/{encoded_key}"
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_file_url_prefix() -> str:
|
||||||
|
return f"{settings.s3_endpoint_url.rstrip('/')}/{settings.s3_bucket_name}/"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_media(file_type: str, file_size: int) -> None:
|
||||||
|
if file_type not in ALLOWED_MIME_TYPES:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file type")
|
||||||
|
if file_size > settings.max_upload_size_bytes:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="File size exceeds limit")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_s3_client():
|
||||||
|
return boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=settings.s3_endpoint_url,
|
||||||
|
aws_access_key_id=settings.s3_access_key,
|
||||||
|
aws_secret_access_key=settings.s3_secret_key,
|
||||||
|
region_name=settings.s3_region,
|
||||||
|
config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_upload_url(payload: UploadUrlRequest) -> UploadUrlResponse:
|
||||||
|
_validate_media(payload.file_type, payload.file_size)
|
||||||
|
|
||||||
|
file_name = _sanitize_filename(payload.file_name)
|
||||||
|
object_key = f"uploads/{uuid4()}-{file_name}"
|
||||||
|
bucket = settings.s3_bucket_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
s3_client = _get_s3_client()
|
||||||
|
upload_url = s3_client.generate_presigned_url(
|
||||||
|
"put_object",
|
||||||
|
Params={
|
||||||
|
"Bucket": bucket,
|
||||||
|
"Key": object_key,
|
||||||
|
"ContentType": payload.file_type,
|
||||||
|
},
|
||||||
|
ExpiresIn=settings.s3_presign_expire_seconds,
|
||||||
|
HttpMethod="PUT",
|
||||||
|
)
|
||||||
|
except (BotoCoreError, ClientError) as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Storage service unavailable") from exc
|
||||||
|
|
||||||
|
return UploadUrlResponse(
|
||||||
|
upload_url=upload_url,
|
||||||
|
file_url=_build_file_url(bucket, object_key),
|
||||||
|
object_key=object_key,
|
||||||
|
expires_in=settings.s3_presign_expire_seconds,
|
||||||
|
required_headers={"Content-Type": payload.file_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def store_attachment_metadata(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
payload: AttachmentCreateRequest,
|
||||||
|
) -> AttachmentRead:
|
||||||
|
_validate_media(payload.file_type, payload.file_size)
|
||||||
|
if not payload.file_url.startswith(_allowed_file_url_prefix()):
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file URL")
|
||||||
|
|
||||||
|
message = await get_message_by_id(db, payload.message_id)
|
||||||
|
if not message:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||||
|
if message.sender_id != user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Only the message sender can attach files",
|
||||||
|
)
|
||||||
|
|
||||||
|
attachment = await repository.create_attachment(
|
||||||
|
db,
|
||||||
|
message_id=payload.message_id,
|
||||||
|
file_url=payload.file_url,
|
||||||
|
file_type=payload.file_type,
|
||||||
|
file_size=payload.file_size,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(attachment)
|
||||||
|
return AttachmentRead.model_validate(attachment)
|
||||||
0
app/messages/__init__.py
Normal file
0
app/messages/__init__.py
Normal file
44
app/messages/models.py
Normal file
44
app/messages/models.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, Enum as SAEnum, ForeignKey, Text, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.chats.models import Chat
|
||||||
|
from app.media.models import Attachment
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
IMAGE = "image"
|
||||||
|
VIDEO = "video"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VOICE = "voice"
|
||||||
|
FILE = "file"
|
||||||
|
CIRCLE_VIDEO = "circle_video"
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Base):
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
chat_id: Mapped[int] = mapped_column(ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
sender_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
type: Mapped[MessageType] = mapped_column(SAEnum(MessageType), nullable=False, default=MessageType.TEXT)
|
||||||
|
text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat: Mapped["Chat"] = relationship(back_populates="messages")
|
||||||
|
sender: Mapped["User"] = relationship(back_populates="sent_messages")
|
||||||
|
attachments: Mapped[list["Attachment"]] = relationship(back_populates="message", cascade="all, delete-orphan")
|
||||||
41
app/messages/repository.py
Normal file
41
app/messages/repository.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.messages.models import Message, MessageType
|
||||||
|
|
||||||
|
|
||||||
|
async def create_message(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
chat_id: int,
|
||||||
|
sender_id: int,
|
||||||
|
message_type: MessageType,
|
||||||
|
text: str | None,
|
||||||
|
) -> Message:
|
||||||
|
message = Message(chat_id=chat_id, sender_id=sender_id, type=message_type, text=text)
|
||||||
|
db.add(message)
|
||||||
|
await db.flush()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
async def get_message_by_id(db: AsyncSession, message_id: int) -> Message | None:
|
||||||
|
result = await db.execute(select(Message).where(Message.id == message_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def list_chat_messages(
|
||||||
|
db: AsyncSession,
|
||||||
|
chat_id: int,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_id: int | None = None,
|
||||||
|
) -> list[Message]:
|
||||||
|
query = select(Message).where(Message.chat_id == chat_id)
|
||||||
|
if before_id is not None:
|
||||||
|
query = query.where(Message.id < before_id)
|
||||||
|
result = await db.execute(query.order_by(Message.id.desc()).limit(limit))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_message(db: AsyncSession, message: Message) -> None:
|
||||||
|
await db.delete(message)
|
||||||
49
app/messages/router.py
Normal file
49
app/messages/router.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.messages.schemas import MessageCreateRequest, MessageRead, MessageUpdateRequest
|
||||||
|
from app.messages.service import create_chat_message, delete_message, get_messages, update_message
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_message(
|
||||||
|
payload: MessageCreateRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> MessageRead:
|
||||||
|
return await create_chat_message(db, sender_id=current_user.id, payload=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{chat_id}", response_model=list[MessageRead])
|
||||||
|
async def list_messages(
|
||||||
|
chat_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
before_id: int | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> list[MessageRead]:
|
||||||
|
return await get_messages(db, chat_id=chat_id, user_id=current_user.id, limit=limit, before_id=before_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{message_id}", response_model=MessageRead)
|
||||||
|
async def edit_message(
|
||||||
|
message_id: int,
|
||||||
|
payload: MessageUpdateRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> MessageRead:
|
||||||
|
return await update_message(db, message_id=message_id, user_id=current_user.id, payload=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def remove_message(
|
||||||
|
message_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> None:
|
||||||
|
await delete_message(db, message_id=message_id, user_id=current_user.id)
|
||||||
27
app/messages/schemas.py
Normal file
27
app/messages/schemas.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from app.messages.models import MessageType
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRead(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
chat_id: int
|
||||||
|
sender_id: int
|
||||||
|
type: MessageType
|
||||||
|
text: str | None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class MessageCreateRequest(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
type: MessageType = MessageType.TEXT
|
||||||
|
text: str | None = Field(default=None, max_length=4096)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageUpdateRequest(BaseModel):
|
||||||
|
text: str = Field(min_length=1, max_length=4096)
|
||||||
67
app/messages/service.py
Normal file
67
app/messages/service.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.chats.service import ensure_chat_membership
|
||||||
|
from app.messages import repository
|
||||||
|
from app.messages.models import Message
|
||||||
|
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
|
||||||
|
|
||||||
|
|
||||||
|
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
|
||||||
|
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
|
||||||
|
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
|
||||||
|
|
||||||
|
message = await repository.create_message(
|
||||||
|
db,
|
||||||
|
chat_id=payload.chat_id,
|
||||||
|
sender_id=sender_id,
|
||||||
|
message_type=payload.type,
|
||||||
|
text=payload.text,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
async def get_messages(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
chat_id: int,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
before_id: int | None = None,
|
||||||
|
) -> list[Message]:
|
||||||
|
await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)
|
||||||
|
safe_limit = max(1, min(limit, 100))
|
||||||
|
return await repository.list_chat_messages(db, chat_id, limit=safe_limit, before_id=before_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_message(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
message_id: int,
|
||||||
|
user_id: int,
|
||||||
|
payload: MessageUpdateRequest,
|
||||||
|
) -> Message:
|
||||||
|
message = await repository.get_message_by_id(db, message_id)
|
||||||
|
if not message:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||||
|
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||||
|
if message.sender_id != user_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can edit only your own messages")
|
||||||
|
message.text = payload.text
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_message(db: AsyncSession, *, message_id: int, user_id: int) -> None:
|
||||||
|
message = await repository.get_message_by_id(db, message_id)
|
||||||
|
if not message:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
|
||||||
|
await ensure_chat_membership(db, chat_id=message.chat_id, user_id=user_id)
|
||||||
|
if message.sender_id != user_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can delete only your own messages")
|
||||||
|
await repository.delete_message(db, message)
|
||||||
|
await db.commit()
|
||||||
0
app/notifications/__init__.py
Normal file
0
app/notifications/__init__.py
Normal file
16
app/notifications/models.py
Normal file
16
app/notifications/models.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationLog(Base):
|
||||||
|
__tablename__ = "notification_logs"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(index=True)
|
||||||
|
event_type: Mapped[str] = mapped_column(String(64), index=True)
|
||||||
|
payload: Mapped[str] = mapped_column(String(1024))
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
7
app/notifications/repository.py
Normal file
7
app/notifications/repository.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.notifications.models import NotificationLog
|
||||||
|
|
||||||
|
|
||||||
|
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
|
||||||
|
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))
|
||||||
3
app/notifications/router.py
Normal file
3
app/notifications/router.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||||
7
app/notifications/schemas.py
Normal file
7
app/notifications/schemas.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationRequest(BaseModel):
|
||||||
|
user_id: int
|
||||||
|
event_type: str
|
||||||
|
payload: dict
|
||||||
13
app/notifications/service.py
Normal file
13
app/notifications/service.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.notifications.repository import create_notification_log
|
||||||
|
from app.notifications.schemas import NotificationRequest
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
|
||||||
|
await create_notification_log(
|
||||||
|
db,
|
||||||
|
user_id=payload.user_id,
|
||||||
|
event_type=payload.event_type,
|
||||||
|
payload=payload.payload.__repr__(),
|
||||||
|
)
|
||||||
0
app/realtime/__init__.py
Normal file
0
app/realtime/__init__.py
Normal file
10
app/realtime/models.py
Normal file
10
app/realtime/models.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ConnectionContext:
|
||||||
|
user_id: int
|
||||||
|
connection_id: str
|
||||||
|
websocket: WebSocket
|
||||||
48
app/realtime/repository.py
Normal file
48
app/realtime/repository.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class RedisRealtimeRepository:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._redis: Redis | None = None
|
||||||
|
self._pubsub = None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
if self._redis:
|
||||||
|
return
|
||||||
|
self._redis = Redis.from_url(settings.redis_url, decode_responses=True)
|
||||||
|
self._pubsub = self._redis.pubsub()
|
||||||
|
await self._pubsub.psubscribe("chat:*")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._pubsub:
|
||||||
|
await self._pubsub.close()
|
||||||
|
self._pubsub = None
|
||||||
|
if self._redis:
|
||||||
|
await self._redis.aclose()
|
||||||
|
self._redis = None
|
||||||
|
|
||||||
|
async def publish_event(self, channel: str, payload: dict) -> None:
|
||||||
|
if not self._redis:
|
||||||
|
await self.connect()
|
||||||
|
assert self._redis is not None
|
||||||
|
await self._redis.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
async def consume(self, handler: Callable[[str, dict], Awaitable[None]]) -> None:
|
||||||
|
if not self._pubsub:
|
||||||
|
await self.connect()
|
||||||
|
assert self._pubsub is not None
|
||||||
|
while True:
|
||||||
|
message = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||||
|
if not message:
|
||||||
|
continue
|
||||||
|
channel = message.get("channel")
|
||||||
|
data = message.get("data")
|
||||||
|
if not channel or not isinstance(data, str):
|
||||||
|
continue
|
||||||
|
payload = json.loads(data)
|
||||||
|
await handler(channel, payload)
|
||||||
76
app/realtime/router.py
Normal file
76
app/realtime/router.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, status
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user_for_ws
|
||||||
|
from app.database.session import AsyncSessionLocal
|
||||||
|
from app.realtime.schemas import (
|
||||||
|
ChatEventPayload,
|
||||||
|
IncomingRealtimeEvent,
|
||||||
|
MessageStatusPayload,
|
||||||
|
OutgoingRealtimeEvent,
|
||||||
|
SendMessagePayload,
|
||||||
|
)
|
||||||
|
from app.realtime.service import realtime_gateway
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/realtime", tags=["realtime"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws")
|
||||||
|
async def websocket_gateway(websocket: WebSocket) -> None:
|
||||||
|
token = websocket.query_params.get("token")
|
||||||
|
if not token:
|
||||||
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with AsyncSessionLocal() as db:
|
||||||
|
try:
|
||||||
|
user = await get_current_user_for_ws(token, db)
|
||||||
|
except Exception:
|
||||||
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
return
|
||||||
|
|
||||||
|
user_chat_ids = await realtime_gateway.load_user_chat_ids(db, user.id)
|
||||||
|
await websocket.accept()
|
||||||
|
connection_id = await realtime_gateway.register(user.id, websocket, user_chat_ids)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_data = await websocket.receive_json()
|
||||||
|
try:
|
||||||
|
event = IncomingRealtimeEvent.model_validate(raw_data)
|
||||||
|
await _dispatch_event(db, user.id, event)
|
||||||
|
except ValidationError:
|
||||||
|
await websocket.send_json(
|
||||||
|
OutgoingRealtimeEvent(
|
||||||
|
event="error",
|
||||||
|
payload={"detail": "Invalid event payload"},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
).model_dump(mode="json")
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
await websocket.send_json(
|
||||||
|
OutgoingRealtimeEvent(
|
||||||
|
event="error",
|
||||||
|
payload={"detail": str(exc)},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
).model_dump(mode="json")
|
||||||
|
)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
await realtime_gateway.unregister(user.id, connection_id, user_chat_ids)
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch_event(db, user_id: int, event: IncomingRealtimeEvent) -> None:
|
||||||
|
if event.event == "send_message":
|
||||||
|
payload = SendMessagePayload.model_validate(event.payload)
|
||||||
|
await realtime_gateway.handle_send_message(db, user_id, payload)
|
||||||
|
return
|
||||||
|
if event.event in {"typing_start", "typing_stop"}:
|
||||||
|
payload = ChatEventPayload.model_validate(event.payload)
|
||||||
|
await realtime_gateway.handle_typing_event(db, user_id, payload, event.event)
|
||||||
|
return
|
||||||
|
if event.event in {"message_read", "message_delivered"}:
|
||||||
|
payload = MessageStatusPayload.model_validate(event.payload)
|
||||||
|
await realtime_gateway.handle_message_status(db, user_id, payload, event.event)
|
||||||
|
return
|
||||||
48
app/realtime/schemas.py
Normal file
48
app/realtime/schemas.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from app.messages.models import MessageType
|
||||||
|
|
||||||
|
|
||||||
|
RealtimeEventName = Literal[
|
||||||
|
"connect",
|
||||||
|
"disconnect",
|
||||||
|
"send_message",
|
||||||
|
"receive_message",
|
||||||
|
"typing_start",
|
||||||
|
"typing_stop",
|
||||||
|
"message_read",
|
||||||
|
"message_delivered",
|
||||||
|
"error",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SendMessagePayload(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
type: MessageType = MessageType.TEXT
|
||||||
|
text: str | None = Field(default=None, max_length=4096)
|
||||||
|
temp_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEventPayload(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class MessageStatusPayload(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
message_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class IncomingRealtimeEvent(BaseModel):
|
||||||
|
event: Literal["send_message", "typing_start", "typing_stop", "message_read", "message_delivered"]
|
||||||
|
payload: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class OutgoingRealtimeEvent(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
event: RealtimeEventName
|
||||||
|
payload: dict[str, Any]
|
||||||
|
timestamp: datetime
|
||||||
178
app/realtime/service.py
Normal file
178
app/realtime/service.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.chats.repository import list_user_chat_ids
|
||||||
|
from app.chats.service import ensure_chat_membership
|
||||||
|
from app.messages.schemas import MessageCreateRequest, MessageRead
|
||||||
|
from app.messages.service import create_chat_message
|
||||||
|
from app.realtime.models import ConnectionContext
|
||||||
|
from app.realtime.repository import RedisRealtimeRepository
|
||||||
|
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeGateway:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._repo = RedisRealtimeRepository()
|
||||||
|
self._consume_task: asyncio.Task | None = None
|
||||||
|
self._distributed_enabled = False
|
||||||
|
self._connections: dict[int, dict[str, ConnectionContext]] = defaultdict(dict)
|
||||||
|
self._chat_subscribers: dict[int, set[int]] = defaultdict(set)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
try:
|
||||||
|
await self._repo.connect()
|
||||||
|
if not self._consume_task:
|
||||||
|
self._consume_task = asyncio.create_task(self._repo.consume(self._handle_redis_event))
|
||||||
|
self._distributed_enabled = True
|
||||||
|
except Exception:
|
||||||
|
self._distributed_enabled = False
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
if self._consume_task:
|
||||||
|
self._consume_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._consume_task
|
||||||
|
self._consume_task = None
|
||||||
|
await self._repo.close()
|
||||||
|
self._distributed_enabled = False
|
||||||
|
|
||||||
|
async def register(self, user_id: int, websocket: WebSocket, user_chat_ids: list[int]) -> str:
|
||||||
|
connection_id = str(uuid4())
|
||||||
|
self._connections[user_id][connection_id] = ConnectionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
connection_id=connection_id,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
for chat_id in user_chat_ids:
|
||||||
|
self._chat_subscribers[chat_id].add(user_id)
|
||||||
|
await self._send_user_event(
|
||||||
|
user_id,
|
||||||
|
OutgoingRealtimeEvent(
|
||||||
|
event="connect",
|
||||||
|
payload={"connection_id": connection_id},
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return connection_id
|
||||||
|
|
||||||
|
async def unregister(self, user_id: int, connection_id: str, user_chat_ids: list[int]) -> None:
|
||||||
|
user_connections = self._connections.get(user_id, {})
|
||||||
|
user_connections.pop(connection_id, None)
|
||||||
|
if not user_connections:
|
||||||
|
self._connections.pop(user_id, None)
|
||||||
|
for chat_id in user_chat_ids:
|
||||||
|
subscribers = self._chat_subscribers.get(chat_id)
|
||||||
|
if not subscribers:
|
||||||
|
continue
|
||||||
|
subscribers.discard(user_id)
|
||||||
|
if not subscribers:
|
||||||
|
self._chat_subscribers.pop(chat_id, None)
|
||||||
|
|
||||||
|
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
|
||||||
|
message = await create_chat_message(
|
||||||
|
db,
|
||||||
|
sender_id=user_id,
|
||||||
|
payload=MessageCreateRequest(chat_id=payload.chat_id, type=payload.type, text=payload.text),
|
||||||
|
)
|
||||||
|
message_data = MessageRead.model_validate(message).model_dump(mode="json")
|
||||||
|
await self._publish_chat_event(
|
||||||
|
payload.chat_id,
|
||||||
|
event="receive_message",
|
||||||
|
payload={
|
||||||
|
"chat_id": payload.chat_id,
|
||||||
|
"message": message_data,
|
||||||
|
"temp_id": payload.temp_id,
|
||||||
|
"sender_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_typing_event(self, db: AsyncSession, user_id: int, payload: ChatEventPayload, event: str) -> None:
|
||||||
|
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||||
|
await self._publish_chat_event(
|
||||||
|
payload.chat_id,
|
||||||
|
event=event,
|
||||||
|
payload={"chat_id": payload.chat_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_message_status(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
payload: MessageStatusPayload,
|
||||||
|
event: str,
|
||||||
|
) -> None:
|
||||||
|
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=user_id)
|
||||||
|
await self._publish_chat_event(
|
||||||
|
payload.chat_id,
|
||||||
|
event=event,
|
||||||
|
payload={
|
||||||
|
"chat_id": payload.chat_id,
|
||||||
|
"message_id": payload.message_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_user_chat_ids(self, db: AsyncSession, user_id: int) -> list[int]:
|
||||||
|
return await list_user_chat_ids(db, user_id=user_id)
|
||||||
|
|
||||||
|
async def _handle_redis_event(self, channel: str, payload: dict) -> None:
|
||||||
|
chat_id = self._extract_chat_id(channel)
|
||||||
|
if chat_id is None:
|
||||||
|
return
|
||||||
|
subscribers = self._chat_subscribers.get(chat_id, set())
|
||||||
|
if not subscribers:
|
||||||
|
return
|
||||||
|
event = OutgoingRealtimeEvent(
|
||||||
|
event=payload.get("event", "error"),
|
||||||
|
payload=payload.get("payload", {}),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
await asyncio.gather(*(self._send_user_event(user_id, event) for user_id in subscribers), return_exceptions=True)
|
||||||
|
|
||||||
|
async def _publish_chat_event(self, chat_id: int, *, event: str, payload: dict) -> None:
|
||||||
|
event_payload = {
|
||||||
|
"event": event,
|
||||||
|
"payload": payload,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
if self._distributed_enabled:
|
||||||
|
await self._repo.publish_event(f"chat:{chat_id}", event_payload)
|
||||||
|
return
|
||||||
|
await self._handle_redis_event(f"chat:{chat_id}", event_payload)
|
||||||
|
|
||||||
|
async def _send_user_event(self, user_id: int, event: OutgoingRealtimeEvent) -> None:
|
||||||
|
user_connections = self._connections.get(user_id, {})
|
||||||
|
if not user_connections:
|
||||||
|
return
|
||||||
|
disconnected: list[str] = []
|
||||||
|
for connection_id, context in user_connections.items():
|
||||||
|
try:
|
||||||
|
await context.websocket.send_json(event.model_dump(mode="json"))
|
||||||
|
except Exception:
|
||||||
|
disconnected.append(connection_id)
|
||||||
|
for connection_id in disconnected:
|
||||||
|
user_connections.pop(connection_id, None)
|
||||||
|
if not user_connections:
|
||||||
|
self._connections.pop(user_id, None)
|
||||||
|
for chat_id, subscribers in list(self._chat_subscribers.items()):
|
||||||
|
subscribers.discard(user_id)
|
||||||
|
if not subscribers:
|
||||||
|
self._chat_subscribers.pop(chat_id, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_chat_id(channel: str) -> int | None:
|
||||||
|
if not channel.startswith("chat:"):
|
||||||
|
return None
|
||||||
|
chat_id = channel.split(":", maxsplit=1)[1]
|
||||||
|
if not chat_id.isdigit():
|
||||||
|
return None
|
||||||
|
return int(chat_id)
|
||||||
|
|
||||||
|
|
||||||
|
realtime_gateway = RealtimeGateway()
|
||||||
0
app/users/__init__.py
Normal file
0
app/users/__init__.py
Normal file
41
app/users/models.py
Normal file
41
app/users/models.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, DateTime, String, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.auth.models import EmailVerificationToken, PasswordResetToken
|
||||||
|
from app.chats.models import ChatMember
|
||||||
|
from app.messages.models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||||
|
username: Mapped[str] = mapped_column(String(50), unique=True, index=True)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255))
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||||
|
email_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
memberships: Mapped[list["ChatMember"]] = relationship(back_populates="user", cascade="all, delete-orphan")
|
||||||
|
sent_messages: Mapped[list["Message"]] = relationship(back_populates="sender")
|
||||||
|
email_verification_tokens: Mapped[list["EmailVerificationToken"]] = relationship(
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
password_reset_tokens: Mapped[list["PasswordResetToken"]] = relationship(
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
26
app/users/repository.py
Normal file
26
app/users/repository.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def create_user(db: AsyncSession, *, email: str, username: str, password_hash: str) -> User:
|
||||||
|
user = User(email=email, username=username, password_hash=password_hash, email_verified=False)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||||
|
result = await db.execute(select(User).where(User.username == username))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
43
app/users/router.py
Normal file
43
app/users/router.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.users.models import User
|
||||||
|
from app.users.schemas import UserProfileUpdate, UserRead
|
||||||
|
from app.users.service import get_user_by_id, get_user_by_username, update_user_profile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserRead)
|
||||||
|
async def read_me(current_user: User = Depends(get_current_user)) -> UserRead:
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{user_id}", response_model=UserRead)
|
||||||
|
async def read_user(user_id: int, db: AsyncSession = Depends(get_db), _current_user: User = Depends(get_current_user)) -> UserRead:
|
||||||
|
user = await get_user_by_id(db, user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/profile", response_model=UserRead)
|
||||||
|
async def update_profile(
|
||||||
|
payload: UserProfileUpdate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> UserRead:
|
||||||
|
if payload.username and payload.username != current_user.username:
|
||||||
|
username_owner = await get_user_by_username(db, payload.username)
|
||||||
|
if username_owner:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Username already taken")
|
||||||
|
|
||||||
|
updated = await update_user_profile(
|
||||||
|
db,
|
||||||
|
current_user,
|
||||||
|
username=payload.username,
|
||||||
|
avatar_url=payload.avatar_url,
|
||||||
|
)
|
||||||
|
return updated
|
||||||
27
app/users/schemas.py
Normal file
27
app/users/schemas.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
username: str = Field(min_length=3, max_length=50)
|
||||||
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str = Field(min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRead(UserBase):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
avatar_url: str | None = None
|
||||||
|
email_verified: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfileUpdate(BaseModel):
|
||||||
|
username: str | None = Field(default=None, min_length=3, max_length=50)
|
||||||
|
avatar_url: str | None = Field(default=None, max_length=512)
|
||||||
32
app/users/service.py
Normal file
32
app/users/service.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.users import repository
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||||
|
return await repository.get_user_by_id(db, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||||
|
return await repository.get_user_by_email(db, email)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||||
|
return await repository.get_user_by_username(db, username)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_user_profile(
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
*,
|
||||||
|
username: str | None = None,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
if username is not None:
|
||||||
|
user.username = username
|
||||||
|
if avatar_url is not None:
|
||||||
|
user.avatar_url = avatar_url
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
51
app/utils/security.py
Normal file
51
app/utils/security.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from secrets import token_urlsafe
|
||||||
|
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, hashed_password: str) -> bool:
|
||||||
|
return pwd_context.verify(password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
payload = {"sub": subject, "type": token_type, "exp": expire}
|
||||||
|
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(subject: str) -> str:
|
||||||
|
return _create_token(
|
||||||
|
subject=subject,
|
||||||
|
token_type="access",
|
||||||
|
expires_delta=timedelta(minutes=settings.access_token_expire_minutes),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(subject: str) -> str:
|
||||||
|
return _create_token(
|
||||||
|
subject=subject,
|
||||||
|
token_type="refresh",
|
||||||
|
expires_delta=timedelta(days=settings.refresh_token_expire_days),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str) -> dict:
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, settings.secret_key, algorithms=[settings.jwt_algorithm])
|
||||||
|
except JWTError as exc:
|
||||||
|
raise ValueError("Invalid token") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_token() -> str:
|
||||||
|
return token_urlsafe(48)
|
||||||
5
main.py
Normal file
5
main.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
15
requirements.txt
Normal file
15
requirements.txt
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
fastapi==0.116.1
|
||||||
|
uvicorn[standard]==0.35.0
|
||||||
|
sqlalchemy==2.0.43
|
||||||
|
asyncpg==0.30.0
|
||||||
|
pydantic==2.11.7
|
||||||
|
pydantic-settings==2.10.1
|
||||||
|
python-jose[cryptography]==3.5.0
|
||||||
|
passlib[bcrypt]==1.7.4
|
||||||
|
email-validator==2.2.0
|
||||||
|
python-multipart==0.0.20
|
||||||
|
redis==6.4.0
|
||||||
|
celery==5.5.3
|
||||||
|
boto3==1.40.31
|
||||||
|
aiosmtplib==4.0.2
|
||||||
|
alembic==1.16.5
|
||||||
Reference in New Issue
Block a user