Implement security hardening, notification pipeline, and CI test suite
All checks were successful
CI / test (push) Successful in 9m2s

Security hardening:

- Added IP/user rate limiting with Redis-backed counters and fail-open behavior.

- Added message anti-spam controls (per-chat rate + duplicate cooldown).

- Implemented refresh token rotation with JTI tracking and revoke support.

Notification pipeline:

- Added Celery app and async notification tasks for mention/offline delivery.

- Added Redis-based presence tracking and integrated it into realtime connect/disconnect.

- Added notification dispatch from message flow and notifications listing endpoint.

Quality gates and CI:

- Added pytest async integration tests for auth and chat/message lifecycle.

- Added pytest config, test fixtures, and GitHub Actions CI workflow.

- Fixed bcrypt/passlib compatibility by pinning bcrypt version.

- Documented worker and quality-gate commands in README.
This commit is contained in:
2026-03-07 21:46:30 +03:00
parent a879ba7b50
commit 85631b566a
29 changed files with 723 additions and 11 deletions

View File

@@ -30,3 +30,11 @@ SMTP_USERNAME=
SMTP_PASSWORD= SMTP_PASSWORD=
SMTP_USE_TLS=false SMTP_USE_TLS=false
SMTP_FROM_EMAIL=no-reply@benyamessenger.local SMTP_FROM_EMAIL=no-reply@benyamessenger.local
LOGIN_RATE_LIMIT_PER_MINUTE=10
REGISTER_RATE_LIMIT_PER_MINUTE=5
RESET_RATE_LIMIT_PER_MINUTE=5
REFRESH_RATE_LIMIT_PER_MINUTE=30
MESSAGE_RATE_LIMIT_PER_MINUTE=30
DUPLICATE_MESSAGE_COOLDOWN_SECONDS=10
CELERY_TASK_ALWAYS_EAGER=false

32
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,32 @@
name: CI
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Compile check
run: |
python -m compileall app main.py
- name: Run tests
run: |
pytest -q

1
.gitignore vendored
View File

@@ -4,3 +4,4 @@ __pycache__/
*.pyc *.pyc
.idea/ .idea/
.env .env
test.db

View File

@@ -10,3 +10,16 @@ Backend foundation for a Telegram-like real-time messaging platform.
3. Configure environment from `.env.example`. 3. Configure environment from `.env.example`.
4. Start API: 4. Start API:
uvicorn app.main:app --reload --port 8000 uvicorn app.main:app --reload --port 8000
## Celery Worker
Run worker for async notification jobs:
celery -A app.celery_app:celery_app worker --loglevel=info
## Quality Gates
- Compile check:
python -m compileall app main.py
- Tests:
pytest -q

View File

@@ -1,10 +1,11 @@
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, Request, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.schemas import ( from app.auth.schemas import (
AuthUserResponse, AuthUserResponse,
LoginRequest, LoginRequest,
MessageResponse, MessageResponse,
RefreshTokenRequest,
RegisterRequest, RegisterRequest,
RequestPasswordResetRequest, RequestPasswordResetRequest,
ResendVerificationRequest, ResendVerificationRequest,
@@ -16,6 +17,7 @@ from app.auth.service import (
get_current_user, get_current_user,
get_email_sender, get_email_sender,
login_user, login_user,
refresh_tokens,
register_user, register_user,
request_password_reset, request_password_reset,
resend_verification_email, resend_verification_email,
@@ -24,6 +26,8 @@ from app.auth.service import (
) )
from app.database.session import get_db from app.database.session import get_db
from app.email.service import EmailService from app.email.service import EmailService
from app.config.settings import settings
from app.utils.rate_limit import enforce_ip_rate_limit
from app.users.models import User from app.users.models import User
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@@ -32,18 +36,43 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED) @router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
async def register( async def register(
payload: RegisterRequest, payload: RegisterRequest,
request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender), email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse: ) -> MessageResponse:
await enforce_ip_rate_limit(
request,
scope="auth_register",
limit=settings.register_rate_limit_per_minute,
)
await register_user(db, payload, email_service) await register_user(db, payload, email_service)
return MessageResponse(message="Registration successful. Verification email sent.") return MessageResponse(message="Registration successful. Verification email sent.")
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse: async def login(payload: LoginRequest, request: Request, db: AsyncSession = Depends(get_db)) -> TokenResponse:
await enforce_ip_rate_limit(
request,
scope="auth_login",
limit=settings.login_rate_limit_per_minute,
)
return await login_user(db, payload) return await login_user(db, payload)
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
payload: RefreshTokenRequest,
request: Request,
db: AsyncSession = Depends(get_db),
) -> TokenResponse:
await enforce_ip_rate_limit(
request,
scope="auth_refresh",
limit=settings.refresh_rate_limit_per_minute,
)
return await refresh_tokens(db, payload)
@router.post("/verify-email", response_model=MessageResponse) @router.post("/verify-email", response_model=MessageResponse)
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse: async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
await verify_email(db, payload) await verify_email(db, payload)
@@ -53,9 +82,15 @@ async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession =
@router.post("/resend-verification", response_model=MessageResponse) @router.post("/resend-verification", response_model=MessageResponse)
async def resend_verification( async def resend_verification(
payload: ResendVerificationRequest, payload: ResendVerificationRequest,
request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender), email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse: ) -> MessageResponse:
await enforce_ip_rate_limit(
request,
scope="auth_resend_verification",
limit=settings.reset_rate_limit_per_minute,
)
await resend_verification_email(db, payload, email_service) await resend_verification_email(db, payload, email_service)
return MessageResponse(message="If the account exists, a verification email was sent.") return MessageResponse(message="If the account exists, a verification email was sent.")
@@ -63,9 +98,15 @@ async def resend_verification(
@router.post("/request-password-reset", response_model=MessageResponse) @router.post("/request-password-reset", response_model=MessageResponse)
async def request_password_reset_endpoint( async def request_password_reset_endpoint(
payload: RequestPasswordResetRequest, payload: RequestPasswordResetRequest,
request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
email_service: EmailService = Depends(get_email_sender), email_service: EmailService = Depends(get_email_sender),
) -> MessageResponse: ) -> MessageResponse:
await enforce_ip_rate_limit(
request,
scope="auth_request_reset",
limit=settings.reset_rate_limit_per_minute,
)
await request_password_reset(db, payload, email_service) await request_password_reset(db, payload, email_service)
return MessageResponse(message="If the account exists, a reset email was sent.") return MessageResponse(message="If the account exists, a reset email was sent.")

View File

@@ -14,6 +14,10 @@ class LoginRequest(BaseModel):
password: str = Field(min_length=8, max_length=128) password: str = Field(min_length=8, max_length=128)
class RefreshTokenRequest(BaseModel):
refresh_token: str = Field(min_length=16)
class VerifyEmailRequest(BaseModel): class VerifyEmailRequest(BaseModel):
token: str = Field(min_length=16, max_length=512) token: str = Field(min_length=16, max_length=512)

View File

@@ -1,12 +1,15 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from uuid import uuid4
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth import repository as auth_repository from app.auth import repository as auth_repository
from app.auth.token_store import get_refresh_token_user_id, revoke_refresh_token_jti, store_refresh_token_jti
from app.auth.schemas import ( from app.auth.schemas import (
LoginRequest, LoginRequest,
RefreshTokenRequest,
RegisterRequest, RegisterRequest,
RequestPasswordResetRequest, RequestPasswordResetRequest,
ResendVerificationRequest, ResendVerificationRequest,
@@ -31,6 +34,10 @@ from app.utils.security import (
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
def _refresh_ttl_seconds() -> int:
return settings.refresh_token_expire_days * 24 * 60 * 60
async def register_user( async def register_user(
db: AsyncSession, db: AsyncSession,
payload: RegisterRequest, payload: RegisterRequest,
@@ -105,9 +112,46 @@ async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
if not user.email_verified: if not user.email_verified:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
refresh_jti = str(uuid4())
refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
await store_refresh_token_jti(user_id=user.id, jti=refresh_jti, ttl_seconds=_refresh_ttl_seconds())
return TokenResponse( return TokenResponse(
access_token=create_access_token(str(user.id)), access_token=create_access_token(str(user.id)),
refresh_token=create_refresh_token(str(user.id)), refresh_token=refresh_token,
)
async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> TokenResponse:
credentials_error = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
try:
token_payload = decode_token(payload.refresh_token)
except ValueError as exc:
raise credentials_error from exc
if token_payload.get("type") != "refresh":
raise credentials_error
user_id = token_payload.get("sub")
refresh_jti = token_payload.get("jti")
if not user_id or not str(user_id).isdigit() or not refresh_jti:
raise credentials_error
active_user_id = await get_refresh_token_user_id(jti=refresh_jti)
if active_user_id is None or active_user_id != int(user_id):
raise credentials_error
user = await get_user_by_id(db, int(user_id))
if not user:
raise credentials_error
await revoke_refresh_token_jti(jti=refresh_jti)
new_jti = str(uuid4())
await store_refresh_token_jti(user_id=int(user_id), jti=new_jti, ttl_seconds=_refresh_ttl_seconds())
return TokenResponse(
access_token=create_access_token(str(user_id)),
refresh_token=create_refresh_token(str(user_id), jti=new_jti),
) )

46
app/auth/token_store.py Normal file
View File

@@ -0,0 +1,46 @@
import time
from redis.exceptions import RedisError
from app.utils.redis_client import get_redis_client
_fallback_tokens: dict[str, tuple[int, float]] = {}
def _cleanup_fallback() -> None:
now = time.time()
expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
for jti in expired:
_fallback_tokens.pop(jti, None)
async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
try:
redis = get_redis_client()
await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
except RedisError:
_cleanup_fallback()
_fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)
async def get_refresh_token_user_id(*, jti: str) -> int | None:
try:
redis = get_redis_client()
value = await redis.get(f"auth:refresh:{jti}")
if not value or not str(value).isdigit():
return None
return int(value)
except RedisError:
_cleanup_fallback()
data = _fallback_tokens.get(jti)
if not data:
return None
return data[0]
async def revoke_refresh_token_jti(*, jti: str) -> None:
try:
redis = get_redis_client()
await redis.delete(f"auth:refresh:{jti}")
except RedisError:
_fallback_tokens.pop(jti, None)

21
app/celery_app.py Normal file
View File

@@ -0,0 +1,21 @@
from celery import Celery
from app.config.settings import settings
celery_app = Celery(
"benya_messenger",
broker=settings.redis_url,
backend=settings.redis_url,
)
celery_app.conf.update(
task_serializer="json",
accept_content=["json"],
result_serializer="json",
timezone="UTC",
enable_utc=True,
task_always_eager=settings.celery_task_always_eager,
task_eager_propagates=True,
)
celery_app.autodiscover_tasks(["app.notifications"])

View File

@@ -35,6 +35,14 @@ class Settings(BaseSettings):
smtp_use_tls: bool = False smtp_use_tls: bool = False
smtp_from_email: str = "no-reply@benyamessenger.local" smtp_from_email: str = "no-reply@benyamessenger.local"
login_rate_limit_per_minute: int = 10
register_rate_limit_per_minute: int = 5
reset_rate_limit_per_minute: int = 5
refresh_rate_limit_per_minute: int = 30
message_rate_limit_per_minute: int = 30
duplicate_message_cooldown_seconds: int = 10
celery_task_always_eager: bool = False
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

View File

@@ -14,6 +14,7 @@ from app.notifications.router import router as notifications_router
from app.realtime.router import router as realtime_router from app.realtime.router import router as realtime_router
from app.realtime.service import realtime_gateway from app.realtime.service import realtime_gateway
from app.users.router import router as users_router from app.users.router import router as users_router
from app.utils.redis_client import close_redis_client
@asynccontextmanager @asynccontextmanager
@@ -24,6 +25,7 @@ async def lifespan(_app: FastAPI):
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
yield yield
await realtime_gateway.stop() await realtime_gateway.stop()
await close_redis_client()
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan) app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)

View File

@@ -4,13 +4,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.chats.service import ensure_chat_membership from app.chats.service import ensure_chat_membership
from app.messages import repository from app.messages import repository
from app.messages.models import Message from app.messages.models import Message
from app.messages.spam_guard import enforce_message_spam_policy
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
from app.notifications.service import dispatch_message_notifications
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message: async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id) await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
if payload.type.value == "text" and not (payload.text and payload.text.strip()): if payload.type.value == "text" and not (payload.text and payload.text.strip()):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty") raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
await enforce_message_spam_policy(user_id=sender_id, chat_id=payload.chat_id, text=payload.text)
message = await repository.create_message( message = await repository.create_message(
db, db,
@@ -21,6 +24,11 @@ async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: Mess
) )
await db.commit() await db.commit()
await db.refresh(message) await db.refresh(message)
try:
await dispatch_message_notifications(db, message)
except Exception:
# Notifications should not block message delivery.
pass
return message return message

View File

@@ -0,0 +1,37 @@
import hashlib
from fastapi import HTTPException, status
from redis.exceptions import RedisError
from app.config.settings import settings
from app.utils.redis_client import get_redis_client
def _hash_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
async def enforce_message_spam_policy(*, user_id: int, chat_id: int, text: str | None) -> None:
redis = get_redis_client()
rate_key = f"spam:msg_rate:{user_id}:{chat_id}"
try:
count = await redis.incr(rate_key)
if count == 1:
await redis.expire(rate_key, 60)
if count > settings.message_rate_limit_per_minute:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Message rate limit exceeded for this chat.",
)
normalized = (text or "").strip()
if normalized:
dup_key = f"spam:dup:{user_id}:{chat_id}:{_hash_text(normalized)}"
if await redis.exists(dup_key):
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Duplicate message cooldown is active.",
)
await redis.set(dup_key, "1", ex=settings.duplicate_message_cooldown_seconds)
except RedisError:
return

View File

@@ -5,3 +5,15 @@ from app.notifications.models import NotificationLog
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None: async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload)) db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))
async def list_user_notifications(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationLog]:
from sqlalchemy import select
result = await db.execute(
select(NotificationLog)
.where(NotificationLog.user_id == user_id)
.order_by(NotificationLog.id.desc())
.limit(limit)
)
return list(result.scalars().all())

View File

@@ -1,3 +1,19 @@
from fastapi import APIRouter from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.service import get_current_user
from app.database.session import get_db
from app.notifications.schemas import NotificationRead
from app.notifications.service import get_notifications_for_user
from app.users.models import User
router = APIRouter(prefix="/notifications", tags=["notifications"]) router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("", response_model=list[NotificationRead])
async def list_my_notifications(
limit: int = 50,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> list[NotificationRead]:
return await get_notifications_for_user(db, user_id=current_user.id, limit=limit)

View File

@@ -1,7 +1,27 @@
from pydantic import BaseModel from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict
class NotificationRequest(BaseModel): class NotificationRequest(BaseModel):
user_id: int user_id: int
event_type: str event_type: str
payload: dict payload: dict
class NotificationRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
user_id: int
event_type: str
payload: str
created_at: datetime
class PushTaskPayload(BaseModel):
user_id: int
title: str
body: str
data: dict[str, Any]

View File

@@ -1,7 +1,23 @@
import json
import re
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.notifications.repository import create_notification_log from app.chats.repository import list_chat_members
from app.notifications.schemas import NotificationRequest from app.messages.models import Message
from app.notifications.repository import create_notification_log, list_user_notifications
from app.notifications.schemas import NotificationRead, NotificationRequest
from app.notifications.tasks import send_mention_notification_task, send_push_notification_task
from app.realtime.presence import is_user_online
from app.users.repository import list_users_by_ids
_MENTION_RE = re.compile(r"@([A-Za-z0-9_]{3,50})")
def _extract_mentions(text: str | None) -> set[str]:
if not text:
return set()
return {match.group(1).lower() for match in _MENTION_RE.finditer(text)}
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None: async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
@@ -9,5 +25,73 @@ async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -
db, db,
user_id=payload.user_id, user_id=payload.user_id,
event_type=payload.event_type, event_type=payload.event_type,
payload=payload.payload.__repr__(), payload=json.dumps(payload.payload, ensure_ascii=True),
) )
async def dispatch_message_notifications(db: AsyncSession, message: Message) -> None:
members = await list_chat_members(db, chat_id=message.chat_id)
recipient_ids = [m.user_id for m in members if m.user_id != message.sender_id]
if not recipient_ids:
return
users = await list_users_by_ids(db, recipient_ids)
user_by_username = {user.username.lower(): user for user in users}
mentioned_usernames = _extract_mentions(message.text)
mentioned_user_ids = {user_by_username[name].id for name in mentioned_usernames if name in user_by_username}
sender_users = await list_users_by_ids(db, [message.sender_id])
sender_name = sender_users[0].username if sender_users else "Someone"
for recipient in users:
base_payload = {
"chat_id": message.chat_id,
"message_id": message.id,
"sender_id": message.sender_id,
}
if recipient.id in mentioned_user_ids:
payload = {
**base_payload,
"type": "mention",
"text_preview": (message.text or "")[:120],
}
await create_notification_log(
db,
user_id=recipient.id,
event_type="mention",
payload=json.dumps(payload, ensure_ascii=True),
)
send_mention_notification_task.delay(
recipient.id,
f"{sender_name} mentioned you",
(message.text or "")[:120],
payload,
)
continue
if not await is_user_online(recipient.id):
payload = {
**base_payload,
"type": "offline_message",
"text_preview": (message.text or "")[:120],
}
await create_notification_log(
db,
user_id=recipient.id,
event_type="offline_message",
payload=json.dumps(payload, ensure_ascii=True),
)
send_push_notification_task.delay(
recipient.id,
f"New message from {sender_name}",
(message.text or "")[:120],
payload,
)
await db.commit()
async def get_notifications_for_user(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationRead]:
safe_limit = max(1, min(limit, 100))
rows = await list_user_notifications(db, user_id=user_id, limit=safe_limit)
return [NotificationRead.model_validate(item) for item in rows]

View File

@@ -0,0 +1,15 @@
import logging
from app.celery_app import celery_app
logger = logging.getLogger(__name__)
@celery_app.task(name="notifications.send_push")
def send_push_notification_task(user_id: int, title: str, body: str, data: dict) -> None:
logger.info("PUSH user=%s title=%s body=%s data=%s", user_id, title, body, data)
@celery_app.task(name="notifications.send_mention")
def send_mention_notification_task(user_id: int, title: str, body: str, data: dict) -> None:
logger.info("MENTION user=%s title=%s body=%s data=%s", user_id, title, body, data)

34
app/realtime/presence.py Normal file
View File

@@ -0,0 +1,34 @@
from redis.exceptions import RedisError
from app.utils.redis_client import get_redis_client
async def mark_user_online(user_id: int) -> None:
try:
redis = get_redis_client()
key = f"presence:user:{user_id}"
count = await redis.incr(key)
if count == 1:
await redis.expire(key, 3600)
except RedisError:
return
async def mark_user_offline(user_id: int) -> None:
try:
redis = get_redis_client()
key = f"presence:user:{user_id}"
value = await redis.decr(key)
if value <= 0:
await redis.delete(key)
except RedisError:
return
async def is_user_online(user_id: int) -> bool:
try:
redis = get_redis_client()
value = await redis.get(f"presence:user:{user_id}")
return bool(value and str(value).isdigit() and int(value) > 0)
except RedisError:
return False

View File

@@ -12,6 +12,7 @@ from app.chats.service import ensure_chat_membership
from app.messages.schemas import MessageCreateRequest, MessageRead from app.messages.schemas import MessageCreateRequest, MessageRead
from app.messages.service import create_chat_message from app.messages.service import create_chat_message
from app.realtime.models import ConnectionContext from app.realtime.models import ConnectionContext
from app.realtime.presence import mark_user_offline, mark_user_online
from app.realtime.repository import RedisRealtimeRepository from app.realtime.repository import RedisRealtimeRepository
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
@@ -51,6 +52,7 @@ class RealtimeGateway:
) )
for chat_id in user_chat_ids: for chat_id in user_chat_ids:
self._chat_subscribers[chat_id].add(user_id) self._chat_subscribers[chat_id].add(user_id)
await mark_user_online(user_id)
await self._send_user_event( await self._send_user_event(
user_id, user_id,
OutgoingRealtimeEvent( OutgoingRealtimeEvent(
@@ -73,6 +75,7 @@ class RealtimeGateway:
subscribers.discard(user_id) subscribers.discard(user_id)
if not subscribers: if not subscribers:
self._chat_subscribers.pop(chat_id, None) self._chat_subscribers.pop(chat_id, None)
await mark_user_offline(user_id)
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None: async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
message = await create_chat_message( message = await create_chat_message(
@@ -164,6 +167,7 @@ class RealtimeGateway:
subscribers.discard(user_id) subscribers.discard(user_id)
if not subscribers: if not subscribers:
self._chat_subscribers.pop(chat_id, None) self._chat_subscribers.pop(chat_id, None)
await mark_user_offline(user_id)
@staticmethod @staticmethod
def _extract_chat_id(channel: str) -> int | None: def _extract_chat_id(channel: str) -> int | None:

View File

@@ -24,3 +24,10 @@ async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
async def get_user_by_username(db: AsyncSession, username: str) -> User | None: async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
result = await db.execute(select(User).where(User.username == username)) result = await db.execute(select(User).where(User.username == username))
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def list_users_by_ids(db: AsyncSession, user_ids: list[int]) -> list[User]:
if not user_ids:
return []
result = await db.execute(select(User).where(User.id.in_(user_ids)))
return list(result.scalars().all())

54
app/utils/rate_limit.py Normal file
View File

@@ -0,0 +1,54 @@
from fastapi import HTTPException, Request, status
from redis.exceptions import RedisError
from app.utils.redis_client import get_redis_client
def _safe_ip(request: Request) -> str:
if not request.client or not request.client.host:
return "unknown"
return request.client.host
async def enforce_ip_rate_limit(
request: Request,
*,
scope: str,
limit: int,
window_seconds: int = 60,
) -> None:
if limit <= 0:
return
key = f"rl:{scope}:ip:{_safe_ip(request)}"
await _enforce(key=key, limit=limit, window_seconds=window_seconds)
async def enforce_user_rate_limit(
user_id: int,
*,
scope: str,
limit: int,
window_seconds: int = 60,
) -> None:
if limit <= 0:
return
key = f"rl:{scope}:user:{user_id}"
await _enforce(key=key, limit=limit, window_seconds=window_seconds)
async def _enforce(*, key: str, limit: int, window_seconds: int) -> None:
try:
redis = get_redis_client()
current = await redis.incr(key)
if current == 1:
await redis.expire(key, window_seconds)
if current > limit:
ttl = await redis.ttl(key)
retry_after = max(1, ttl if ttl and ttl > 0 else window_seconds)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded. Retry in {retry_after} seconds.",
)
except RedisError:
# Fail-open in case of Redis outage to keep core API available.
return

19
app/utils/redis_client.py Normal file
View File

@@ -0,0 +1,19 @@
from redis.asyncio import Redis
from app.config.settings import settings
_redis_client: Redis | None = None
def get_redis_client() -> Redis:
global _redis_client
if _redis_client is None:
_redis_client = Redis.from_url(settings.redis_url, decode_responses=True)
return _redis_client
async def close_redis_client() -> None:
global _redis_client
if _redis_client is not None:
await _redis_client.aclose()
_redis_client = None

View File

@@ -1,5 +1,6 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from secrets import token_urlsafe from secrets import token_urlsafe
from uuid import uuid4
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
@@ -18,9 +19,11 @@ def verify_password(password: str, hashed_password: str) -> bool:
return pwd_context.verify(password, hashed_password) return pwd_context.verify(password, hashed_password)
def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str: def _create_token(subject: str, token_type: str, expires_delta: timedelta, jti: str | None = None) -> str:
now = datetime.now(timezone.utc)
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
payload = {"sub": subject, "type": token_type, "exp": expire} payload = {"sub": subject, "type": token_type, "exp": expire, "iat": now}
payload["jti"] = jti or str(uuid4())
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm) return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
@@ -32,11 +35,12 @@ def create_access_token(subject: str) -> str:
) )
def create_refresh_token(subject: str) -> str: def create_refresh_token(subject: str, *, jti: str | None = None) -> str:
return _create_token( return _create_token(
subject=subject, subject=subject,
token_type="refresh", token_type="refresh",
expires_delta=timedelta(days=settings.refresh_token_expire_days), expires_delta=timedelta(days=settings.refresh_token_expire_days),
jti=jti,
) )

3
pytest.ini Normal file
View File

@@ -0,0 +1,3 @@
[pytest]
asyncio_mode = auto
testpaths = tests

View File

@@ -6,6 +6,7 @@ pydantic==2.11.7
pydantic-settings==2.10.1 pydantic-settings==2.10.1
python-jose[cryptography]==3.5.0 python-jose[cryptography]==3.5.0
passlib[bcrypt]==1.7.4 passlib[bcrypt]==1.7.4
bcrypt==4.0.1
email-validator==2.2.0 email-validator==2.2.0
python-multipart==0.0.20 python-multipart==0.0.20
redis==6.4.0 redis==6.4.0
@@ -13,3 +14,7 @@ celery==5.5.3
boto3==1.40.31 boto3==1.40.31
aiosmtplib==4.0.2 aiosmtplib==4.0.2
alembic==1.16.5 alembic==1.16.5
pytest==8.4.2
pytest-asyncio==1.2.0
httpx==0.28.1
aiosqlite==0.21.0

40
tests/conftest.py Normal file
View File

@@ -0,0 +1,40 @@
import os
import sys
from pathlib import Path
# Set test env before importing app modules.
os.environ["POSTGRES_DSN"] = "sqlite+aiosqlite:///./test.db"
os.environ["AUTO_CREATE_TABLES"] = "false"
os.environ["REDIS_URL"] = "redis://localhost:6399/15"
os.environ["SECRET_KEY"] = "test-secret-key-1234567890"
os.environ["CELERY_TASK_ALWAYS_EAGER"] = "true"
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))
import pytest
from httpx import ASGITransport, AsyncClient
from app.database.base import Base
from app.database.session import AsyncSessionLocal, engine
from app.main import app
@pytest.fixture(autouse=True)
async def reset_db() -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield
@pytest.fixture
async def client() -> AsyncClient:
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
yield ac
@pytest.fixture
async def db_session():
async with AsyncSessionLocal() as session:
yield session

69
tests/test_auth_flow.py Normal file
View File

@@ -0,0 +1,69 @@
from sqlalchemy import select
from app.auth.models import EmailVerificationToken
async def test_register_verify_login_and_me(client, db_session):
register_payload = {
"email": "alice@example.com",
"username": "alice",
"password": "strongpass123",
}
register_response = await client.post("/api/v1/auth/register", json=register_payload)
assert register_response.status_code == 201
login_response_before_verify = await client.post(
"/api/v1/auth/login",
json={"email": register_payload["email"], "password": register_payload["password"]},
)
assert login_response_before_verify.status_code == 403
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
verify_token = token_row.scalar_one().token
verify_response = await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
assert verify_response.status_code == 200
login_response = await client.post(
"/api/v1/auth/login",
json={"email": register_payload["email"], "password": register_payload["password"]},
)
assert login_response.status_code == 200
token_data = login_response.json()
assert "access_token" in token_data
assert "refresh_token" in token_data
me_response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token_data['access_token']}"},
)
assert me_response.status_code == 200
me_data = me_response.json()
assert me_data["email"] == "alice@example.com"
assert me_data["email_verified"] is True
async def test_refresh_token_rotation(client, db_session):
payload = {
"email": "bob@example.com",
"username": "bob",
"password": "strongpass123",
}
await client.post("/api/v1/auth/register", json=payload)
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
verify_token = token_row.scalar_one().token
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
login_response = await client.post(
"/api/v1/auth/login",
json={"email": payload["email"], "password": payload["password"]},
)
refresh_token = login_response.json()["refresh_token"]
refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
assert refresh_response.status_code == 200
rotated_refresh_token = refresh_response.json()["refresh_token"]
assert rotated_refresh_token != refresh_token
old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
assert old_refresh_reuse.status_code == 401

View File

@@ -0,0 +1,61 @@
from sqlalchemy import select
from app.auth.models import EmailVerificationToken
from app.chats.models import ChatType
async def _create_verified_user(client, db_session, email: str, username: str, password: str) -> dict:
await client.post(
"/api/v1/auth/register",
json={"email": email, "username": username, "password": password},
)
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
verify_token = token_row.scalar_one().token
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
login_response = await client.post("/api/v1/auth/login", json={"email": email, "password": password})
return login_response.json()
async def test_private_chat_message_lifecycle(client, db_session):
u1 = await _create_verified_user(client, db_session, "u1@example.com", "user_one", "strongpass123")
u2 = await _create_verified_user(client, db_session, "u2@example.com", "user_two", "strongpass123")
me_u2 = await client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {u2['access_token']}"})
u2_id = me_u2.json()["id"]
create_chat_response = await client.post(
"/api/v1/chats",
headers={"Authorization": f"Bearer {u1['access_token']}"},
json={"type": ChatType.PRIVATE.value, "title": None, "member_ids": [u2_id]},
)
assert create_chat_response.status_code == 200
chat_id = create_chat_response.json()["id"]
send_message_response = await client.post(
"/api/v1/messages",
headers={"Authorization": f"Bearer {u1['access_token']}"},
json={"chat_id": chat_id, "type": "text", "text": "hello @user_two"},
)
assert send_message_response.status_code == 201
message_id = send_message_response.json()["id"]
list_messages_response = await client.get(
f"/api/v1/messages/{chat_id}",
headers={"Authorization": f"Bearer {u2['access_token']}"},
)
assert list_messages_response.status_code == 200
assert len(list_messages_response.json()) == 1
edit_message_response = await client.put(
f"/api/v1/messages/{message_id}",
headers={"Authorization": f"Bearer {u1['access_token']}"},
json={"text": "edited text"},
)
assert edit_message_response.status_code == 200
assert edit_message_response.json()["text"] == "edited text"
delete_message_response = await client.delete(
f"/api/v1/messages/{message_id}",
headers={"Authorization": f"Bearer {u1['access_token']}"},
)
assert delete_message_response.status_code == 204