From 85631b566afef0dc2e05b76c7b5ee01575c1cb52 Mon Sep 17 00:00:00 2001 From: benya Date: Sat, 7 Mar 2026 21:46:30 +0300 Subject: [PATCH] Implement security hardening, notification pipeline, and CI test suite Security hardening: - Added IP/user rate limiting with Redis-backed counters and fail-open behavior. - Added message anti-spam controls (per-chat rate + duplicate cooldown). - Implemented refresh token rotation with JTI tracking and revoke support. Notification pipeline: - Added Celery app and async notification tasks for mention/offline delivery. - Added Redis-based presence tracking and integrated it into realtime connect/disconnect. - Added notification dispatch from message flow and notifications listing endpoint. Quality gates and CI: - Added pytest async integration tests for auth and chat/message lifecycle. - Added pytest config, test fixtures, and GitHub Actions CI workflow. - Fixed bcrypt/passlib compatibility by pinning bcrypt version. - Documented worker and quality-gate commands in README. --- .env.example | 8 +++ .github/workflows/ci.yml | 32 ++++++++++++ .gitignore | 1 + README.md | 13 +++++ app/auth/router.py | 45 ++++++++++++++++- app/auth/schemas.py | 4 ++ app/auth/service.py | 46 ++++++++++++++++- app/auth/token_store.py | 46 +++++++++++++++++ app/celery_app.py | 21 ++++++++ app/config/settings.py | 8 +++ app/main.py | 2 + app/messages/service.py | 8 +++ app/messages/spam_guard.py | 37 ++++++++++++++ app/notifications/repository.py | 12 +++++ app/notifications/router.py | 18 ++++++- app/notifications/schemas.py | 22 +++++++- app/notifications/service.py | 90 +++++++++++++++++++++++++++++++-- app/notifications/tasks.py | 15 ++++++ app/realtime/presence.py | 34 +++++++++++++ app/realtime/service.py | 4 ++ app/users/repository.py | 7 +++ app/utils/rate_limit.py | 54 ++++++++++++++++++++ app/utils/redis_client.py | 19 +++++++ app/utils/security.py | 10 ++-- pytest.ini | 3 ++ requirements.txt | 5 ++ tests/conftest.py | 40 +++++++++++++++ tests/test_auth_flow.py | 69 +++++++++++++++++++++++++ tests/test_chat_message_flow.py | 61 ++++++++++++++++++++++ 29 files changed, 723 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 app/auth/token_store.py create mode 100644 app/celery_app.py create mode 100644 app/messages/spam_guard.py create mode 100644 app/notifications/tasks.py create mode 100644 app/realtime/presence.py create mode 100644 app/utils/rate_limit.py create mode 100644 app/utils/redis_client.py create mode 100644 pytest.ini create mode 100644 tests/conftest.py create mode 100644 tests/test_auth_flow.py create mode 100644 tests/test_chat_message_flow.py diff --git a/.env.example b/.env.example index 10c505a..4c41b4d 100644 --- a/.env.example +++ b/.env.example @@ -30,3 +30,11 @@ SMTP_USERNAME= SMTP_PASSWORD= SMTP_USE_TLS=false SMTP_FROM_EMAIL=no-reply@benyamessenger.local + +LOGIN_RATE_LIMIT_PER_MINUTE=10 +REGISTER_RATE_LIMIT_PER_MINUTE=5 +RESET_RATE_LIMIT_PER_MINUTE=5 +REFRESH_RATE_LIMIT_PER_MINUTE=30 +MESSAGE_RATE_LIMIT_PER_MINUTE=30 +DUPLICATE_MESSAGE_COOLDOWN_SECONDS=10 +CELERY_TASK_ALWAYS_EAGER=false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a9ceb63 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Compile check + run: | + python -m compileall app main.py + + - name: Run tests + run: | + pytest -q diff --git a/.gitignore b/.gitignore index e037df0..1961fc5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__/ *.pyc .idea/ .env +test.db diff --git a/README.md b/README.md index e138e11..3ddf574 100644 --- a/README.md +++ b/README.md @@ -10,3 +10,16 @@ Backend foundation for a Telegram-like real-time messaging platform. 3. Configure environment from `.env.example`. 4. Start API: uvicorn app.main:app --reload --port 8000 + +## Celery Worker + +Run worker for async notification jobs: + +celery -A app.celery_app:celery_app worker --loglevel=info + +## Quality Gates + +- Compile check: + python -m compileall app main.py +- Tests: + pytest -q diff --git a/app/auth/router.py b/app/auth/router.py index c87dae4..4056a96 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,10 +1,11 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, Request, status from sqlalchemy.ext.asyncio import AsyncSession from app.auth.schemas import ( AuthUserResponse, LoginRequest, MessageResponse, + RefreshTokenRequest, RegisterRequest, RequestPasswordResetRequest, ResendVerificationRequest, @@ -16,6 +17,7 @@ from app.auth.service import ( get_current_user, get_email_sender, login_user, + refresh_tokens, register_user, request_password_reset, resend_verification_email, @@ -24,6 +26,8 @@ from app.auth.service import ( ) from app.database.session import get_db from app.email.service import EmailService +from app.config.settings import settings +from app.utils.rate_limit import enforce_ip_rate_limit from app.users.models import User router = APIRouter(prefix="/auth", tags=["auth"]) @@ -32,18 +36,43 @@ router = APIRouter(prefix="/auth", tags=["auth"]) @router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED) async def register( payload: RegisterRequest, + request: Request, db: AsyncSession = Depends(get_db), email_service: EmailService = Depends(get_email_sender), ) -> MessageResponse: + await enforce_ip_rate_limit( + request, + scope="auth_register", + limit=settings.register_rate_limit_per_minute, + ) await register_user(db, payload, email_service) return MessageResponse(message="Registration successful. Verification email sent.") @router.post("/login", response_model=TokenResponse) -async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse: +async def login(payload: LoginRequest, request: Request, db: AsyncSession = Depends(get_db)) -> TokenResponse: + await enforce_ip_rate_limit( + request, + scope="auth_login", + limit=settings.login_rate_limit_per_minute, + ) return await login_user(db, payload) +@router.post("/refresh", response_model=TokenResponse) +async def refresh( + payload: RefreshTokenRequest, + request: Request, + db: AsyncSession = Depends(get_db), +) -> TokenResponse: + await enforce_ip_rate_limit( + request, + scope="auth_refresh", + limit=settings.refresh_rate_limit_per_minute, + ) + return await refresh_tokens(db, payload) + + @router.post("/verify-email", response_model=MessageResponse) async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse: await verify_email(db, payload) @@ -53,9 +82,15 @@ async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = @router.post("/resend-verification", response_model=MessageResponse) async def resend_verification( payload: ResendVerificationRequest, + request: Request, db: AsyncSession = Depends(get_db), email_service: EmailService = Depends(get_email_sender), ) -> MessageResponse: + await enforce_ip_rate_limit( + request, + scope="auth_resend_verification", + limit=settings.reset_rate_limit_per_minute, + ) await resend_verification_email(db, payload, email_service) return MessageResponse(message="If the account exists, a verification email was sent.") @@ -63,9 +98,15 @@ async def resend_verification( @router.post("/request-password-reset", response_model=MessageResponse) async def request_password_reset_endpoint( payload: RequestPasswordResetRequest, + request: Request, db: AsyncSession = Depends(get_db), email_service: EmailService = Depends(get_email_sender), ) -> MessageResponse: + await enforce_ip_rate_limit( + request, + scope="auth_request_reset", + limit=settings.reset_rate_limit_per_minute, + ) await request_password_reset(db, payload, email_service) return MessageResponse(message="If the account exists, a reset email was sent.") diff --git a/app/auth/schemas.py b/app/auth/schemas.py index 7ebe85b..ed1526d 100644 --- a/app/auth/schemas.py +++ b/app/auth/schemas.py @@ -14,6 +14,10 @@ class LoginRequest(BaseModel): password: str = Field(min_length=8, max_length=128) +class RefreshTokenRequest(BaseModel): + refresh_token: str = Field(min_length=16) + + class VerifyEmailRequest(BaseModel): token: str = Field(min_length=16, max_length=512) diff --git a/app/auth/service.py b/app/auth/service.py index a45a39c..3085183 100644 --- a/app/auth/service.py +++ b/app/auth/service.py @@ -1,12 +1,15 @@ from datetime import datetime, timedelta, timezone +from uuid import uuid4 from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from app.auth import repository as auth_repository +from app.auth.token_store import get_refresh_token_user_id, revoke_refresh_token_jti, store_refresh_token_jti from app.auth.schemas import ( LoginRequest, + RefreshTokenRequest, RegisterRequest, RequestPasswordResetRequest, ResendVerificationRequest, @@ -31,6 +34,10 @@ from app.utils.security import ( oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login") +def _refresh_ttl_seconds() -> int: + return settings.refresh_token_expire_days * 24 * 60 * 60 + + async def register_user( db: AsyncSession, payload: RegisterRequest, @@ -105,9 +112,46 @@ async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse: if not user.email_verified: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified") + refresh_jti = str(uuid4()) + refresh_token = create_refresh_token(str(user.id), jti=refresh_jti) + await store_refresh_token_jti(user_id=user.id, jti=refresh_jti, ttl_seconds=_refresh_ttl_seconds()) return TokenResponse( access_token=create_access_token(str(user.id)), - refresh_token=create_refresh_token(str(user.id)), + refresh_token=refresh_token, + ) + + +async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> TokenResponse: + credentials_error = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) + try: + token_payload = decode_token(payload.refresh_token) + except ValueError as exc: + raise credentials_error from exc + + if token_payload.get("type") != "refresh": + raise credentials_error + + user_id = token_payload.get("sub") + refresh_jti = token_payload.get("jti") + if not user_id or not str(user_id).isdigit() or not refresh_jti: + raise credentials_error + + active_user_id = await get_refresh_token_user_id(jti=refresh_jti) + if active_user_id is None or active_user_id != int(user_id): + raise credentials_error + user = await get_user_by_id(db, int(user_id)) + if not user: + raise credentials_error + + await revoke_refresh_token_jti(jti=refresh_jti) + new_jti = str(uuid4()) + await store_refresh_token_jti(user_id=int(user_id), jti=new_jti, ttl_seconds=_refresh_ttl_seconds()) + return TokenResponse( + access_token=create_access_token(str(user_id)), + refresh_token=create_refresh_token(str(user_id), jti=new_jti), ) diff --git a/app/auth/token_store.py b/app/auth/token_store.py new file mode 100644 index 0000000..5ee6ad2 --- /dev/null +++ b/app/auth/token_store.py @@ -0,0 +1,46 @@ +import time + +from redis.exceptions import RedisError + +from app.utils.redis_client import get_redis_client + +_fallback_tokens: dict[str, tuple[int, float]] = {} + + +def _cleanup_fallback() -> None: + now = time.time() + expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now] + for jti in expired: + _fallback_tokens.pop(jti, None) + + +async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None: + try: + redis = get_redis_client() + await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds) + except RedisError: + _cleanup_fallback() + _fallback_tokens[jti] = (user_id, time.time() + ttl_seconds) + + +async def get_refresh_token_user_id(*, jti: str) -> int | None: + try: + redis = get_redis_client() + value = await redis.get(f"auth:refresh:{jti}") + if not value or not str(value).isdigit(): + return None + return int(value) + except RedisError: + _cleanup_fallback() + data = _fallback_tokens.get(jti) + if not data: + return None + return data[0] + + +async def revoke_refresh_token_jti(*, jti: str) -> None: + try: + redis = get_redis_client() + await redis.delete(f"auth:refresh:{jti}") + except RedisError: + _fallback_tokens.pop(jti, None) diff --git a/app/celery_app.py b/app/celery_app.py new file mode 100644 index 0000000..9943571 --- /dev/null +++ b/app/celery_app.py @@ -0,0 +1,21 @@ +from celery import Celery + +from app.config.settings import settings + +celery_app = Celery( + "benya_messenger", + broker=settings.redis_url, + backend=settings.redis_url, +) + +celery_app.conf.update( + task_serializer="json", + accept_content=["json"], + result_serializer="json", + timezone="UTC", + enable_utc=True, + task_always_eager=settings.celery_task_always_eager, + task_eager_propagates=True, +) + +celery_app.autodiscover_tasks(["app.notifications"]) diff --git a/app/config/settings.py b/app/config/settings.py index 0a1f0a2..85855ef 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -35,6 +35,14 @@ class Settings(BaseSettings): smtp_use_tls: bool = False smtp_from_email: str = "no-reply@benyamessenger.local" + login_rate_limit_per_minute: int = 10 + register_rate_limit_per_minute: int = 5 + reset_rate_limit_per_minute: int = 5 + refresh_rate_limit_per_minute: int = 30 + message_rate_limit_per_minute: int = 30 + duplicate_message_cooldown_seconds: int = 10 + celery_task_always_eager: bool = False + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") diff --git a/app/main.py b/app/main.py index 7940418..19846ca 100644 --- a/app/main.py +++ b/app/main.py @@ -14,6 +14,7 @@ from app.notifications.router import router as notifications_router from app.realtime.router import router as realtime_router from app.realtime.service import realtime_gateway from app.users.router import router as users_router +from app.utils.redis_client import close_redis_client @asynccontextmanager @@ -24,6 +25,7 @@ async def lifespan(_app: FastAPI): await conn.run_sync(Base.metadata.create_all) yield await realtime_gateway.stop() + await close_redis_client() app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan) diff --git a/app/messages/service.py b/app/messages/service.py index 5b8161a..3485d25 100644 --- a/app/messages/service.py +++ b/app/messages/service.py @@ -4,13 +4,16 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.chats.service import ensure_chat_membership from app.messages import repository from app.messages.models import Message +from app.messages.spam_guard import enforce_message_spam_policy from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest +from app.notifications.service import dispatch_message_notifications async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message: await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id) if payload.type.value == "text" and not (payload.text and payload.text.strip()): raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty") + await enforce_message_spam_policy(user_id=sender_id, chat_id=payload.chat_id, text=payload.text) message = await repository.create_message( db, @@ -21,6 +24,11 @@ async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: Mess ) await db.commit() await db.refresh(message) + try: + await dispatch_message_notifications(db, message) + except Exception: + # Notifications should not block message delivery. + pass return message diff --git a/app/messages/spam_guard.py b/app/messages/spam_guard.py new file mode 100644 index 0000000..ef91bb5 --- /dev/null +++ b/app/messages/spam_guard.py @@ -0,0 +1,37 @@ +import hashlib + +from fastapi import HTTPException, status +from redis.exceptions import RedisError + +from app.config.settings import settings +from app.utils.redis_client import get_redis_client + + +def _hash_text(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +async def enforce_message_spam_policy(*, user_id: int, chat_id: int, text: str | None) -> None: + redis = get_redis_client() + rate_key = f"spam:msg_rate:{user_id}:{chat_id}" + try: + count = await redis.incr(rate_key) + if count == 1: + await redis.expire(rate_key, 60) + if count > settings.message_rate_limit_per_minute: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Message rate limit exceeded for this chat.", + ) + + normalized = (text or "").strip() + if normalized: + dup_key = f"spam:dup:{user_id}:{chat_id}:{_hash_text(normalized)}" + if await redis.exists(dup_key): + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Duplicate message cooldown is active.", + ) + await redis.set(dup_key, "1", ex=settings.duplicate_message_cooldown_seconds) + except RedisError: + return diff --git a/app/notifications/repository.py b/app/notifications/repository.py index ded25bd..5fbcad7 100644 --- a/app/notifications/repository.py +++ b/app/notifications/repository.py @@ -5,3 +5,15 @@ from app.notifications.models import NotificationLog async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None: db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload)) + + +async def list_user_notifications(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationLog]: + from sqlalchemy import select + + result = await db.execute( + select(NotificationLog) + .where(NotificationLog.user_id == user_id) + .order_by(NotificationLog.id.desc()) + .limit(limit) + ) + return list(result.scalars().all()) diff --git a/app/notifications/router.py b/app/notifications/router.py index 384c289..03cbcc3 100644 --- a/app/notifications/router.py +++ b/app/notifications/router.py @@ -1,3 +1,19 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.service import get_current_user +from app.database.session import get_db +from app.notifications.schemas import NotificationRead +from app.notifications.service import get_notifications_for_user +from app.users.models import User router = APIRouter(prefix="/notifications", tags=["notifications"]) + + +@router.get("", response_model=list[NotificationRead]) +async def list_my_notifications( + limit: int = 50, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +) -> list[NotificationRead]: + return await get_notifications_for_user(db, user_id=current_user.id, limit=limit) diff --git a/app/notifications/schemas.py b/app/notifications/schemas.py index 5969a0a..f5f0318 100644 --- a/app/notifications/schemas.py +++ b/app/notifications/schemas.py @@ -1,7 +1,27 @@ -from pydantic import BaseModel +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict class NotificationRequest(BaseModel): user_id: int event_type: str payload: dict + + +class NotificationRead(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + user_id: int + event_type: str + payload: str + created_at: datetime + + +class PushTaskPayload(BaseModel): + user_id: int + title: str + body: str + data: dict[str, Any] diff --git a/app/notifications/service.py b/app/notifications/service.py index 2054332..79a2604 100644 --- a/app/notifications/service.py +++ b/app/notifications/service.py @@ -1,7 +1,23 @@ +import json +import re + from sqlalchemy.ext.asyncio import AsyncSession -from app.notifications.repository import create_notification_log -from app.notifications.schemas import NotificationRequest +from app.chats.repository import list_chat_members +from app.messages.models import Message +from app.notifications.repository import create_notification_log, list_user_notifications +from app.notifications.schemas import NotificationRead, NotificationRequest +from app.notifications.tasks import send_mention_notification_task, send_push_notification_task +from app.realtime.presence import is_user_online +from app.users.repository import list_users_by_ids + +_MENTION_RE = re.compile(r"@([A-Za-z0-9_]{3,50})") + + +def _extract_mentions(text: str | None) -> set[str]: + if not text: + return set() + return {match.group(1).lower() for match in _MENTION_RE.finditer(text)} async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None: @@ -9,5 +25,73 @@ async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) - db, user_id=payload.user_id, event_type=payload.event_type, - payload=payload.payload.__repr__(), + payload=json.dumps(payload.payload, ensure_ascii=True), ) + + +async def dispatch_message_notifications(db: AsyncSession, message: Message) -> None: + members = await list_chat_members(db, chat_id=message.chat_id) + recipient_ids = [m.user_id for m in members if m.user_id != message.sender_id] + if not recipient_ids: + return + + users = await list_users_by_ids(db, recipient_ids) + user_by_username = {user.username.lower(): user for user in users} + mentioned_usernames = _extract_mentions(message.text) + mentioned_user_ids = {user_by_username[name].id for name in mentioned_usernames if name in user_by_username} + + sender_users = await list_users_by_ids(db, [message.sender_id]) + sender_name = sender_users[0].username if sender_users else "Someone" + + for recipient in users: + base_payload = { + "chat_id": message.chat_id, + "message_id": message.id, + "sender_id": message.sender_id, + } + if recipient.id in mentioned_user_ids: + payload = { + **base_payload, + "type": "mention", + "text_preview": (message.text or "")[:120], + } + await create_notification_log( + db, + user_id=recipient.id, + event_type="mention", + payload=json.dumps(payload, ensure_ascii=True), + ) + send_mention_notification_task.delay( + recipient.id, + f"{sender_name} mentioned you", + (message.text or "")[:120], + payload, + ) + continue + + if not await is_user_online(recipient.id): + payload = { + **base_payload, + "type": "offline_message", + "text_preview": (message.text or "")[:120], + } + await create_notification_log( + db, + user_id=recipient.id, + event_type="offline_message", + payload=json.dumps(payload, ensure_ascii=True), + ) + send_push_notification_task.delay( + recipient.id, + f"New message from {sender_name}", + (message.text or "")[:120], + payload, + ) + + await db.commit() + + +async def get_notifications_for_user(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationRead]: + safe_limit = max(1, min(limit, 100)) + rows = await list_user_notifications(db, user_id=user_id, limit=safe_limit) + return [NotificationRead.model_validate(item) for item in rows] diff --git a/app/notifications/tasks.py b/app/notifications/tasks.py new file mode 100644 index 0000000..81b6b8a --- /dev/null +++ b/app/notifications/tasks.py @@ -0,0 +1,15 @@ +import logging + +from app.celery_app import celery_app + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="notifications.send_push") +def send_push_notification_task(user_id: int, title: str, body: str, data: dict) -> None: + logger.info("PUSH user=%s title=%s body=%s data=%s", user_id, title, body, data) + + +@celery_app.task(name="notifications.send_mention") +def send_mention_notification_task(user_id: int, title: str, body: str, data: dict) -> None: + logger.info("MENTION user=%s title=%s body=%s data=%s", user_id, title, body, data) diff --git a/app/realtime/presence.py b/app/realtime/presence.py new file mode 100644 index 0000000..b6a6bcc --- /dev/null +++ b/app/realtime/presence.py @@ -0,0 +1,34 @@ +from redis.exceptions import RedisError + +from app.utils.redis_client import get_redis_client + + +async def mark_user_online(user_id: int) -> None: + try: + redis = get_redis_client() + key = f"presence:user:{user_id}" + count = await redis.incr(key) + if count == 1: + await redis.expire(key, 3600) + except RedisError: + return + + +async def mark_user_offline(user_id: int) -> None: + try: + redis = get_redis_client() + key = f"presence:user:{user_id}" + value = await redis.decr(key) + if value <= 0: + await redis.delete(key) + except RedisError: + return + + +async def is_user_online(user_id: int) -> bool: + try: + redis = get_redis_client() + value = await redis.get(f"presence:user:{user_id}") + return bool(value and str(value).isdigit() and int(value) > 0) + except RedisError: + return False diff --git a/app/realtime/service.py b/app/realtime/service.py index 6813f75..15f7890 100644 --- a/app/realtime/service.py +++ b/app/realtime/service.py @@ -12,6 +12,7 @@ from app.chats.service import ensure_chat_membership from app.messages.schemas import MessageCreateRequest, MessageRead from app.messages.service import create_chat_message from app.realtime.models import ConnectionContext +from app.realtime.presence import mark_user_offline, mark_user_online from app.realtime.repository import RedisRealtimeRepository from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload @@ -51,6 +52,7 @@ class RealtimeGateway: ) for chat_id in user_chat_ids: self._chat_subscribers[chat_id].add(user_id) + await mark_user_online(user_id) await self._send_user_event( user_id, OutgoingRealtimeEvent( @@ -73,6 +75,7 @@ class RealtimeGateway: subscribers.discard(user_id) if not subscribers: self._chat_subscribers.pop(chat_id, None) + await mark_user_offline(user_id) async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None: message = await create_chat_message( @@ -164,6 +167,7 @@ class RealtimeGateway: subscribers.discard(user_id) if not subscribers: self._chat_subscribers.pop(chat_id, None) + await mark_user_offline(user_id) @staticmethod def _extract_chat_id(channel: str) -> int | None: diff --git a/app/users/repository.py b/app/users/repository.py index 9eb9282..080e5fd 100644 --- a/app/users/repository.py +++ b/app/users/repository.py @@ -24,3 +24,10 @@ async def get_user_by_email(db: AsyncSession, email: str) -> User | None: async def get_user_by_username(db: AsyncSession, username: str) -> User | None: result = await db.execute(select(User).where(User.username == username)) return result.scalar_one_or_none() + + +async def list_users_by_ids(db: AsyncSession, user_ids: list[int]) -> list[User]: + if not user_ids: + return [] + result = await db.execute(select(User).where(User.id.in_(user_ids))) + return list(result.scalars().all()) diff --git a/app/utils/rate_limit.py b/app/utils/rate_limit.py new file mode 100644 index 0000000..2cb79be --- /dev/null +++ b/app/utils/rate_limit.py @@ -0,0 +1,54 @@ +from fastapi import HTTPException, Request, status +from redis.exceptions import RedisError + +from app.utils.redis_client import get_redis_client + + +def _safe_ip(request: Request) -> str: + if not request.client or not request.client.host: + return "unknown" + return request.client.host + + +async def enforce_ip_rate_limit( + request: Request, + *, + scope: str, + limit: int, + window_seconds: int = 60, +) -> None: + if limit <= 0: + return + key = f"rl:{scope}:ip:{_safe_ip(request)}" + await _enforce(key=key, limit=limit, window_seconds=window_seconds) + + +async def enforce_user_rate_limit( + user_id: int, + *, + scope: str, + limit: int, + window_seconds: int = 60, +) -> None: + if limit <= 0: + return + key = f"rl:{scope}:user:{user_id}" + await _enforce(key=key, limit=limit, window_seconds=window_seconds) + + +async def _enforce(*, key: str, limit: int, window_seconds: int) -> None: + try: + redis = get_redis_client() + current = await redis.incr(key) + if current == 1: + await redis.expire(key, window_seconds) + if current > limit: + ttl = await redis.ttl(key) + retry_after = max(1, ttl if ttl and ttl > 0 else window_seconds) + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=f"Rate limit exceeded. Retry in {retry_after} seconds.", + ) + except RedisError: + # Fail-open in case of Redis outage to keep core API available. + return diff --git a/app/utils/redis_client.py b/app/utils/redis_client.py new file mode 100644 index 0000000..c63ad7d --- /dev/null +++ b/app/utils/redis_client.py @@ -0,0 +1,19 @@ +from redis.asyncio import Redis + +from app.config.settings import settings + +_redis_client: Redis | None = None + + +def get_redis_client() -> Redis: + global _redis_client + if _redis_client is None: + _redis_client = Redis.from_url(settings.redis_url, decode_responses=True) + return _redis_client + + +async def close_redis_client() -> None: + global _redis_client + if _redis_client is not None: + await _redis_client.aclose() + _redis_client = None diff --git a/app/utils/security.py b/app/utils/security.py index 3bed0ac..47a14a4 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta, timezone from secrets import token_urlsafe +from uuid import uuid4 from jose import JWTError, jwt from passlib.context import CryptContext @@ -18,9 +19,11 @@ def verify_password(password: str, hashed_password: str) -> bool: return pwd_context.verify(password, hashed_password) -def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str: +def _create_token(subject: str, token_type: str, expires_delta: timedelta, jti: str | None = None) -> str: + now = datetime.now(timezone.utc) expire = datetime.now(timezone.utc) + expires_delta - payload = {"sub": subject, "type": token_type, "exp": expire} + payload = {"sub": subject, "type": token_type, "exp": expire, "iat": now} + payload["jti"] = jti or str(uuid4()) return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm) @@ -32,11 +35,12 @@ def create_access_token(subject: str) -> str: ) -def create_refresh_token(subject: str) -> str: +def create_refresh_token(subject: str, *, jti: str | None = None) -> str: return _create_token( subject=subject, token_type="refresh", expires_delta=timedelta(days=settings.refresh_token_expire_days), + jti=jti, ) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..78c5011 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +asyncio_mode = auto +testpaths = tests diff --git a/requirements.txt b/requirements.txt index b4e1795..cbdfd53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ pydantic==2.11.7 pydantic-settings==2.10.1 python-jose[cryptography]==3.5.0 passlib[bcrypt]==1.7.4 +bcrypt==4.0.1 email-validator==2.2.0 python-multipart==0.0.20 redis==6.4.0 @@ -13,3 +14,7 @@ celery==5.5.3 boto3==1.40.31 aiosmtplib==4.0.2 alembic==1.16.5 +pytest==8.4.2 +pytest-asyncio==1.2.0 +httpx==0.28.1 +aiosqlite==0.21.0 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5a9eb61 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +import os +import sys +from pathlib import Path + +# Set test env before importing app modules. +os.environ["POSTGRES_DSN"] = "sqlite+aiosqlite:///./test.db" +os.environ["AUTO_CREATE_TABLES"] = "false" +os.environ["REDIS_URL"] = "redis://localhost:6399/15" +os.environ["SECRET_KEY"] = "test-secret-key-1234567890" +os.environ["CELERY_TASK_ALWAYS_EAGER"] = "true" +PROJECT_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PROJECT_ROOT)) + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.database.base import Base +from app.database.session import AsyncSessionLocal, engine +from app.main import app + + +@pytest.fixture(autouse=True) +async def reset_db() -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + yield + + +@pytest.fixture +async def client() -> AsyncClient: + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as ac: + yield ac + + +@pytest.fixture +async def db_session(): + async with AsyncSessionLocal() as session: + yield session diff --git a/tests/test_auth_flow.py b/tests/test_auth_flow.py new file mode 100644 index 0000000..fb80026 --- /dev/null +++ b/tests/test_auth_flow.py @@ -0,0 +1,69 @@ +from sqlalchemy import select + +from app.auth.models import EmailVerificationToken + + +async def test_register_verify_login_and_me(client, db_session): + register_payload = { + "email": "alice@example.com", + "username": "alice", + "password": "strongpass123", + } + register_response = await client.post("/api/v1/auth/register", json=register_payload) + assert register_response.status_code == 201 + + login_response_before_verify = await client.post( + "/api/v1/auth/login", + json={"email": register_payload["email"], "password": register_payload["password"]}, + ) + assert login_response_before_verify.status_code == 403 + + token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc())) + verify_token = token_row.scalar_one().token + + verify_response = await client.post("/api/v1/auth/verify-email", json={"token": verify_token}) + assert verify_response.status_code == 200 + + login_response = await client.post( + "/api/v1/auth/login", + json={"email": register_payload["email"], "password": register_payload["password"]}, + ) + assert login_response.status_code == 200 + token_data = login_response.json() + assert "access_token" in token_data + assert "refresh_token" in token_data + + me_response = await client.get( + "/api/v1/auth/me", + headers={"Authorization": f"Bearer {token_data['access_token']}"}, + ) + assert me_response.status_code == 200 + me_data = me_response.json() + assert me_data["email"] == "alice@example.com" + assert me_data["email_verified"] is True + + +async def test_refresh_token_rotation(client, db_session): + payload = { + "email": "bob@example.com", + "username": "bob", + "password": "strongpass123", + } + await client.post("/api/v1/auth/register", json=payload) + token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc())) + verify_token = token_row.scalar_one().token + await client.post("/api/v1/auth/verify-email", json={"token": verify_token}) + + login_response = await client.post( + "/api/v1/auth/login", + json={"email": payload["email"], "password": payload["password"]}, + ) + refresh_token = login_response.json()["refresh_token"] + + refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token}) + assert refresh_response.status_code == 200 + rotated_refresh_token = refresh_response.json()["refresh_token"] + assert rotated_refresh_token != refresh_token + + old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token}) + assert old_refresh_reuse.status_code == 401 diff --git a/tests/test_chat_message_flow.py b/tests/test_chat_message_flow.py new file mode 100644 index 0000000..4079cfc --- /dev/null +++ b/tests/test_chat_message_flow.py @@ -0,0 +1,61 @@ +from sqlalchemy import select + +from app.auth.models import EmailVerificationToken +from app.chats.models import ChatType + + +async def _create_verified_user(client, db_session, email: str, username: str, password: str) -> dict: + await client.post( + "/api/v1/auth/register", + json={"email": email, "username": username, "password": password}, + ) + token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc())) + verify_token = token_row.scalar_one().token + await client.post("/api/v1/auth/verify-email", json={"token": verify_token}) + login_response = await client.post("/api/v1/auth/login", json={"email": email, "password": password}) + return login_response.json() + + +async def test_private_chat_message_lifecycle(client, db_session): + u1 = await _create_verified_user(client, db_session, "u1@example.com", "user_one", "strongpass123") + u2 = await _create_verified_user(client, db_session, "u2@example.com", "user_two", "strongpass123") + + me_u2 = await client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {u2['access_token']}"}) + u2_id = me_u2.json()["id"] + + create_chat_response = await client.post( + "/api/v1/chats", + headers={"Authorization": f"Bearer {u1['access_token']}"}, + json={"type": ChatType.PRIVATE.value, "title": None, "member_ids": [u2_id]}, + ) + assert create_chat_response.status_code == 200 + chat_id = create_chat_response.json()["id"] + + send_message_response = await client.post( + "/api/v1/messages", + headers={"Authorization": f"Bearer {u1['access_token']}"}, + json={"chat_id": chat_id, "type": "text", "text": "hello @user_two"}, + ) + assert send_message_response.status_code == 201 + message_id = send_message_response.json()["id"] + + list_messages_response = await client.get( + f"/api/v1/messages/{chat_id}", + headers={"Authorization": f"Bearer {u2['access_token']}"}, + ) + assert list_messages_response.status_code == 200 + assert len(list_messages_response.json()) == 1 + + edit_message_response = await client.put( + f"/api/v1/messages/{message_id}", + headers={"Authorization": f"Bearer {u1['access_token']}"}, + json={"text": "edited text"}, + ) + assert edit_message_response.status_code == 200 + assert edit_message_response.json()["text"] == "edited text" + + delete_message_response = await client.delete( + f"/api/v1/messages/{message_id}", + headers={"Authorization": f"Bearer {u1['access_token']}"}, + ) + assert delete_message_response.status_code == 204