Implement security hardening, notification pipeline, and CI test suite
All checks were successful
CI / test (push) Successful in 9m2s
All checks were successful
CI / test (push) Successful in 9m2s
Security hardening: - Added IP/user rate limiting with Redis-backed counters and fail-open behavior. - Added message anti-spam controls (per-chat rate + duplicate cooldown). - Implemented refresh token rotation with JTI tracking and revoke support. Notification pipeline: - Added Celery app and async notification tasks for mention/offline delivery. - Added Redis-based presence tracking and integrated it into realtime connect/disconnect. - Added notification dispatch from message flow and notifications listing endpoint. Quality gates and CI: - Added pytest async integration tests for auth and chat/message lifecycle. - Added pytest config, test fixtures, and GitHub Actions CI workflow. - Fixed bcrypt/passlib compatibility by pinning bcrypt version. - Documented worker and quality-gate commands in README.
This commit is contained in:
@@ -30,3 +30,11 @@ SMTP_USERNAME=
|
|||||||
SMTP_PASSWORD=
|
SMTP_PASSWORD=
|
||||||
SMTP_USE_TLS=false
|
SMTP_USE_TLS=false
|
||||||
SMTP_FROM_EMAIL=no-reply@benyamessenger.local
|
SMTP_FROM_EMAIL=no-reply@benyamessenger.local
|
||||||
|
|
||||||
|
LOGIN_RATE_LIMIT_PER_MINUTE=10
|
||||||
|
REGISTER_RATE_LIMIT_PER_MINUTE=5
|
||||||
|
RESET_RATE_LIMIT_PER_MINUTE=5
|
||||||
|
REFRESH_RATE_LIMIT_PER_MINUTE=30
|
||||||
|
MESSAGE_RATE_LIMIT_PER_MINUTE=30
|
||||||
|
DUPLICATE_MESSAGE_COOLDOWN_SECONDS=10
|
||||||
|
CELERY_TASK_ALWAYS_EAGER=false
|
||||||
|
|||||||
32
.github/workflows/ci.yml
vendored
Normal file
32
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: ["main"]
|
||||||
|
pull_request:
|
||||||
|
branches: ["main"]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Compile check
|
||||||
|
run: |
|
||||||
|
python -m compileall app main.py
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
pytest -q
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,3 +4,4 @@ __pycache__/
|
|||||||
*.pyc
|
*.pyc
|
||||||
.idea/
|
.idea/
|
||||||
.env
|
.env
|
||||||
|
test.db
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -10,3 +10,16 @@ Backend foundation for a Telegram-like real-time messaging platform.
|
|||||||
3. Configure environment from `.env.example`.
|
3. Configure environment from `.env.example`.
|
||||||
4. Start API:
|
4. Start API:
|
||||||
uvicorn app.main:app --reload --port 8000
|
uvicorn app.main:app --reload --port 8000
|
||||||
|
|
||||||
|
## Celery Worker
|
||||||
|
|
||||||
|
Run worker for async notification jobs:
|
||||||
|
|
||||||
|
celery -A app.celery_app:celery_app worker --loglevel=info
|
||||||
|
|
||||||
|
## Quality Gates
|
||||||
|
|
||||||
|
- Compile check:
|
||||||
|
python -m compileall app main.py
|
||||||
|
- Tests:
|
||||||
|
pytest -q
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, Request, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.auth.schemas import (
|
from app.auth.schemas import (
|
||||||
AuthUserResponse,
|
AuthUserResponse,
|
||||||
LoginRequest,
|
LoginRequest,
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
|
RefreshTokenRequest,
|
||||||
RegisterRequest,
|
RegisterRequest,
|
||||||
RequestPasswordResetRequest,
|
RequestPasswordResetRequest,
|
||||||
ResendVerificationRequest,
|
ResendVerificationRequest,
|
||||||
@@ -16,6 +17,7 @@ from app.auth.service import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
get_email_sender,
|
get_email_sender,
|
||||||
login_user,
|
login_user,
|
||||||
|
refresh_tokens,
|
||||||
register_user,
|
register_user,
|
||||||
request_password_reset,
|
request_password_reset,
|
||||||
resend_verification_email,
|
resend_verification_email,
|
||||||
@@ -24,6 +26,8 @@ from app.auth.service import (
|
|||||||
)
|
)
|
||||||
from app.database.session import get_db
|
from app.database.session import get_db
|
||||||
from app.email.service import EmailService
|
from app.email.service import EmailService
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.utils.rate_limit import enforce_ip_rate_limit
|
||||||
from app.users.models import User
|
from app.users.models import User
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
@@ -32,18 +36,43 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
|||||||
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=MessageResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(
|
async def register(
|
||||||
payload: RegisterRequest,
|
payload: RegisterRequest,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
email_service: EmailService = Depends(get_email_sender),
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
) -> MessageResponse:
|
) -> MessageResponse:
|
||||||
|
await enforce_ip_rate_limit(
|
||||||
|
request,
|
||||||
|
scope="auth_register",
|
||||||
|
limit=settings.register_rate_limit_per_minute,
|
||||||
|
)
|
||||||
await register_user(db, payload, email_service)
|
await register_user(db, payload, email_service)
|
||||||
return MessageResponse(message="Registration successful. Verification email sent.")
|
return MessageResponse(message="Registration successful. Verification email sent.")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
|
async def login(payload: LoginRequest, request: Request, db: AsyncSession = Depends(get_db)) -> TokenResponse:
|
||||||
|
await enforce_ip_rate_limit(
|
||||||
|
request,
|
||||||
|
scope="auth_login",
|
||||||
|
limit=settings.login_rate_limit_per_minute,
|
||||||
|
)
|
||||||
return await login_user(db, payload)
|
return await login_user(db, payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def refresh(
|
||||||
|
payload: RefreshTokenRequest,
|
||||||
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> TokenResponse:
|
||||||
|
await enforce_ip_rate_limit(
|
||||||
|
request,
|
||||||
|
scope="auth_refresh",
|
||||||
|
limit=settings.refresh_rate_limit_per_minute,
|
||||||
|
)
|
||||||
|
return await refresh_tokens(db, payload)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify-email", response_model=MessageResponse)
|
@router.post("/verify-email", response_model=MessageResponse)
|
||||||
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession = Depends(get_db)) -> MessageResponse:
|
||||||
await verify_email(db, payload)
|
await verify_email(db, payload)
|
||||||
@@ -53,9 +82,15 @@ async def verify_email_endpoint(payload: VerifyEmailRequest, db: AsyncSession =
|
|||||||
@router.post("/resend-verification", response_model=MessageResponse)
|
@router.post("/resend-verification", response_model=MessageResponse)
|
||||||
async def resend_verification(
|
async def resend_verification(
|
||||||
payload: ResendVerificationRequest,
|
payload: ResendVerificationRequest,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
email_service: EmailService = Depends(get_email_sender),
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
) -> MessageResponse:
|
) -> MessageResponse:
|
||||||
|
await enforce_ip_rate_limit(
|
||||||
|
request,
|
||||||
|
scope="auth_resend_verification",
|
||||||
|
limit=settings.reset_rate_limit_per_minute,
|
||||||
|
)
|
||||||
await resend_verification_email(db, payload, email_service)
|
await resend_verification_email(db, payload, email_service)
|
||||||
return MessageResponse(message="If the account exists, a verification email was sent.")
|
return MessageResponse(message="If the account exists, a verification email was sent.")
|
||||||
|
|
||||||
@@ -63,9 +98,15 @@ async def resend_verification(
|
|||||||
@router.post("/request-password-reset", response_model=MessageResponse)
|
@router.post("/request-password-reset", response_model=MessageResponse)
|
||||||
async def request_password_reset_endpoint(
|
async def request_password_reset_endpoint(
|
||||||
payload: RequestPasswordResetRequest,
|
payload: RequestPasswordResetRequest,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
email_service: EmailService = Depends(get_email_sender),
|
email_service: EmailService = Depends(get_email_sender),
|
||||||
) -> MessageResponse:
|
) -> MessageResponse:
|
||||||
|
await enforce_ip_rate_limit(
|
||||||
|
request,
|
||||||
|
scope="auth_request_reset",
|
||||||
|
limit=settings.reset_rate_limit_per_minute,
|
||||||
|
)
|
||||||
await request_password_reset(db, payload, email_service)
|
await request_password_reset(db, payload, email_service)
|
||||||
return MessageResponse(message="If the account exists, a reset email was sent.")
|
return MessageResponse(message="If the account exists, a reset email was sent.")
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ class LoginRequest(BaseModel):
|
|||||||
password: str = Field(min_length=8, max_length=128)
|
password: str = Field(min_length=8, max_length=128)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
refresh_token: str = Field(min_length=16)
|
||||||
|
|
||||||
|
|
||||||
class VerifyEmailRequest(BaseModel):
|
class VerifyEmailRequest(BaseModel):
|
||||||
token: str = Field(min_length=16, max_length=512)
|
token: str = Field(min_length=16, max_length=512)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.auth import repository as auth_repository
|
from app.auth import repository as auth_repository
|
||||||
|
from app.auth.token_store import get_refresh_token_user_id, revoke_refresh_token_jti, store_refresh_token_jti
|
||||||
from app.auth.schemas import (
|
from app.auth.schemas import (
|
||||||
LoginRequest,
|
LoginRequest,
|
||||||
|
RefreshTokenRequest,
|
||||||
RegisterRequest,
|
RegisterRequest,
|
||||||
RequestPasswordResetRequest,
|
RequestPasswordResetRequest,
|
||||||
ResendVerificationRequest,
|
ResendVerificationRequest,
|
||||||
@@ -31,6 +34,10 @@ from app.utils.security import (
|
|||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_ttl_seconds() -> int:
|
||||||
|
return settings.refresh_token_expire_days * 24 * 60 * 60
|
||||||
|
|
||||||
|
|
||||||
async def register_user(
|
async def register_user(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
payload: RegisterRequest,
|
payload: RegisterRequest,
|
||||||
@@ -105,9 +112,46 @@ async def login_user(db: AsyncSession, payload: LoginRequest) -> TokenResponse:
|
|||||||
if not user.email_verified:
|
if not user.email_verified:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Email not verified")
|
||||||
|
|
||||||
|
refresh_jti = str(uuid4())
|
||||||
|
refresh_token = create_refresh_token(str(user.id), jti=refresh_jti)
|
||||||
|
await store_refresh_token_jti(user_id=user.id, jti=refresh_jti, ttl_seconds=_refresh_ttl_seconds())
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
access_token=create_access_token(str(user.id)),
|
access_token=create_access_token(str(user.id)),
|
||||||
refresh_token=create_refresh_token(str(user.id)),
|
refresh_token=refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_tokens(db: AsyncSession, payload: RefreshTokenRequest) -> TokenResponse:
|
||||||
|
credentials_error = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
token_payload = decode_token(payload.refresh_token)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise credentials_error from exc
|
||||||
|
|
||||||
|
if token_payload.get("type") != "refresh":
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
user_id = token_payload.get("sub")
|
||||||
|
refresh_jti = token_payload.get("jti")
|
||||||
|
if not user_id or not str(user_id).isdigit() or not refresh_jti:
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
active_user_id = await get_refresh_token_user_id(jti=refresh_jti)
|
||||||
|
if active_user_id is None or active_user_id != int(user_id):
|
||||||
|
raise credentials_error
|
||||||
|
user = await get_user_by_id(db, int(user_id))
|
||||||
|
if not user:
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
await revoke_refresh_token_jti(jti=refresh_jti)
|
||||||
|
new_jti = str(uuid4())
|
||||||
|
await store_refresh_token_jti(user_id=int(user_id), jti=new_jti, ttl_seconds=_refresh_ttl_seconds())
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=create_access_token(str(user_id)),
|
||||||
|
refresh_token=create_refresh_token(str(user_id), jti=new_jti),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
46
app/auth/token_store.py
Normal file
46
app/auth/token_store.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from app.utils.redis_client import get_redis_client
|
||||||
|
|
||||||
|
_fallback_tokens: dict[str, tuple[int, float]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_fallback() -> None:
|
||||||
|
now = time.time()
|
||||||
|
expired = [jti for jti, (_, exp_at) in _fallback_tokens.items() if exp_at <= now]
|
||||||
|
for jti in expired:
|
||||||
|
_fallback_tokens.pop(jti, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def store_refresh_token_jti(*, user_id: int, jti: str, ttl_seconds: int) -> None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
await redis.set(f"auth:refresh:{jti}", str(user_id), ex=ttl_seconds)
|
||||||
|
except RedisError:
|
||||||
|
_cleanup_fallback()
|
||||||
|
_fallback_tokens[jti] = (user_id, time.time() + ttl_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_refresh_token_user_id(*, jti: str) -> int | None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
value = await redis.get(f"auth:refresh:{jti}")
|
||||||
|
if not value or not str(value).isdigit():
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except RedisError:
|
||||||
|
_cleanup_fallback()
|
||||||
|
data = _fallback_tokens.get(jti)
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
return data[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_refresh_token_jti(*, jti: str) -> None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
await redis.delete(f"auth:refresh:{jti}")
|
||||||
|
except RedisError:
|
||||||
|
_fallback_tokens.pop(jti, None)
|
||||||
21
app/celery_app.py
Normal file
21
app/celery_app.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
celery_app = Celery(
|
||||||
|
"benya_messenger",
|
||||||
|
broker=settings.redis_url,
|
||||||
|
backend=settings.redis_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
celery_app.conf.update(
|
||||||
|
task_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
result_serializer="json",
|
||||||
|
timezone="UTC",
|
||||||
|
enable_utc=True,
|
||||||
|
task_always_eager=settings.celery_task_always_eager,
|
||||||
|
task_eager_propagates=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
celery_app.autodiscover_tasks(["app.notifications"])
|
||||||
@@ -35,6 +35,14 @@ class Settings(BaseSettings):
|
|||||||
smtp_use_tls: bool = False
|
smtp_use_tls: bool = False
|
||||||
smtp_from_email: str = "no-reply@benyamessenger.local"
|
smtp_from_email: str = "no-reply@benyamessenger.local"
|
||||||
|
|
||||||
|
login_rate_limit_per_minute: int = 10
|
||||||
|
register_rate_limit_per_minute: int = 5
|
||||||
|
reset_rate_limit_per_minute: int = 5
|
||||||
|
refresh_rate_limit_per_minute: int = 30
|
||||||
|
message_rate_limit_per_minute: int = 30
|
||||||
|
duplicate_message_cooldown_seconds: int = 10
|
||||||
|
celery_task_always_eager: bool = False
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.notifications.router import router as notifications_router
|
|||||||
from app.realtime.router import router as realtime_router
|
from app.realtime.router import router as realtime_router
|
||||||
from app.realtime.service import realtime_gateway
|
from app.realtime.service import realtime_gateway
|
||||||
from app.users.router import router as users_router
|
from app.users.router import router as users_router
|
||||||
|
from app.utils.redis_client import close_redis_client
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -24,6 +25,7 @@ async def lifespan(_app: FastAPI):
|
|||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
yield
|
yield
|
||||||
await realtime_gateway.stop()
|
await realtime_gateway.stop()
|
||||||
|
await close_redis_client()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
|
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.chats.service import ensure_chat_membership
|
from app.chats.service import ensure_chat_membership
|
||||||
from app.messages import repository
|
from app.messages import repository
|
||||||
from app.messages.models import Message
|
from app.messages.models import Message
|
||||||
|
from app.messages.spam_guard import enforce_message_spam_policy
|
||||||
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
|
from app.messages.schemas import MessageCreateRequest, MessageUpdateRequest
|
||||||
|
from app.notifications.service import dispatch_message_notifications
|
||||||
|
|
||||||
|
|
||||||
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
|
async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: MessageCreateRequest) -> Message:
|
||||||
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
|
await ensure_chat_membership(db, chat_id=payload.chat_id, user_id=sender_id)
|
||||||
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
|
if payload.type.value == "text" and not (payload.text and payload.text.strip()):
|
||||||
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Text message cannot be empty")
|
||||||
|
await enforce_message_spam_policy(user_id=sender_id, chat_id=payload.chat_id, text=payload.text)
|
||||||
|
|
||||||
message = await repository.create_message(
|
message = await repository.create_message(
|
||||||
db,
|
db,
|
||||||
@@ -21,6 +24,11 @@ async def create_chat_message(db: AsyncSession, *, sender_id: int, payload: Mess
|
|||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(message)
|
await db.refresh(message)
|
||||||
|
try:
|
||||||
|
await dispatch_message_notifications(db, message)
|
||||||
|
except Exception:
|
||||||
|
# Notifications should not block message delivery.
|
||||||
|
pass
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
37
app/messages/spam_guard.py
Normal file
37
app/messages/spam_guard.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import hashlib
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.utils.redis_client import get_redis_client
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_text(text: str) -> str:
|
||||||
|
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def enforce_message_spam_policy(*, user_id: int, chat_id: int, text: str | None) -> None:
|
||||||
|
redis = get_redis_client()
|
||||||
|
rate_key = f"spam:msg_rate:{user_id}:{chat_id}"
|
||||||
|
try:
|
||||||
|
count = await redis.incr(rate_key)
|
||||||
|
if count == 1:
|
||||||
|
await redis.expire(rate_key, 60)
|
||||||
|
if count > settings.message_rate_limit_per_minute:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Message rate limit exceeded for this chat.",
|
||||||
|
)
|
||||||
|
|
||||||
|
normalized = (text or "").strip()
|
||||||
|
if normalized:
|
||||||
|
dup_key = f"spam:dup:{user_id}:{chat_id}:{_hash_text(normalized)}"
|
||||||
|
if await redis.exists(dup_key):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Duplicate message cooldown is active.",
|
||||||
|
)
|
||||||
|
await redis.set(dup_key, "1", ex=settings.duplicate_message_cooldown_seconds)
|
||||||
|
except RedisError:
|
||||||
|
return
|
||||||
@@ -5,3 +5,15 @@ from app.notifications.models import NotificationLog
|
|||||||
|
|
||||||
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
|
async def create_notification_log(db: AsyncSession, *, user_id: int, event_type: str, payload: str) -> None:
|
||||||
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))
|
db.add(NotificationLog(user_id=user_id, event_type=event_type, payload=payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def list_user_notifications(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationLog]:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(NotificationLog)
|
||||||
|
.where(NotificationLog.user_id == user_id)
|
||||||
|
.order_by(NotificationLog.id.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|||||||
@@ -1,3 +1,19 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.service import get_current_user
|
||||||
|
from app.database.session import get_db
|
||||||
|
from app.notifications.schemas import NotificationRead
|
||||||
|
from app.notifications.service import get_notifications_for_user
|
||||||
|
from app.users.models import User
|
||||||
|
|
||||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=list[NotificationRead])
|
||||||
|
async def list_my_notifications(
|
||||||
|
limit: int = 50,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> list[NotificationRead]:
|
||||||
|
return await get_notifications_for_user(db, user_id=current_user.id, limit=limit)
|
||||||
|
|||||||
@@ -1,7 +1,27 @@
|
|||||||
from pydantic import BaseModel
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
class NotificationRequest(BaseModel):
|
class NotificationRequest(BaseModel):
|
||||||
user_id: int
|
user_id: int
|
||||||
event_type: str
|
event_type: str
|
||||||
payload: dict
|
payload: dict
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationRead(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
event_type: str
|
||||||
|
payload: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class PushTaskPayload(BaseModel):
|
||||||
|
user_id: int
|
||||||
|
title: str
|
||||||
|
body: str
|
||||||
|
data: dict[str, Any]
|
||||||
|
|||||||
@@ -1,7 +1,23 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.notifications.repository import create_notification_log
|
from app.chats.repository import list_chat_members
|
||||||
from app.notifications.schemas import NotificationRequest
|
from app.messages.models import Message
|
||||||
|
from app.notifications.repository import create_notification_log, list_user_notifications
|
||||||
|
from app.notifications.schemas import NotificationRead, NotificationRequest
|
||||||
|
from app.notifications.tasks import send_mention_notification_task, send_push_notification_task
|
||||||
|
from app.realtime.presence import is_user_online
|
||||||
|
from app.users.repository import list_users_by_ids
|
||||||
|
|
||||||
|
_MENTION_RE = re.compile(r"@([A-Za-z0-9_]{3,50})")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_mentions(text: str | None) -> set[str]:
|
||||||
|
if not text:
|
||||||
|
return set()
|
||||||
|
return {match.group(1).lower() for match in _MENTION_RE.finditer(text)}
|
||||||
|
|
||||||
|
|
||||||
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
|
async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -> None:
|
||||||
@@ -9,5 +25,73 @@ async def enqueue_notification(db: AsyncSession, payload: NotificationRequest) -
|
|||||||
db,
|
db,
|
||||||
user_id=payload.user_id,
|
user_id=payload.user_id,
|
||||||
event_type=payload.event_type,
|
event_type=payload.event_type,
|
||||||
payload=payload.payload.__repr__(),
|
payload=json.dumps(payload.payload, ensure_ascii=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_message_notifications(db: AsyncSession, message: Message) -> None:
|
||||||
|
members = await list_chat_members(db, chat_id=message.chat_id)
|
||||||
|
recipient_ids = [m.user_id for m in members if m.user_id != message.sender_id]
|
||||||
|
if not recipient_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
users = await list_users_by_ids(db, recipient_ids)
|
||||||
|
user_by_username = {user.username.lower(): user for user in users}
|
||||||
|
mentioned_usernames = _extract_mentions(message.text)
|
||||||
|
mentioned_user_ids = {user_by_username[name].id for name in mentioned_usernames if name in user_by_username}
|
||||||
|
|
||||||
|
sender_users = await list_users_by_ids(db, [message.sender_id])
|
||||||
|
sender_name = sender_users[0].username if sender_users else "Someone"
|
||||||
|
|
||||||
|
for recipient in users:
|
||||||
|
base_payload = {
|
||||||
|
"chat_id": message.chat_id,
|
||||||
|
"message_id": message.id,
|
||||||
|
"sender_id": message.sender_id,
|
||||||
|
}
|
||||||
|
if recipient.id in mentioned_user_ids:
|
||||||
|
payload = {
|
||||||
|
**base_payload,
|
||||||
|
"type": "mention",
|
||||||
|
"text_preview": (message.text or "")[:120],
|
||||||
|
}
|
||||||
|
await create_notification_log(
|
||||||
|
db,
|
||||||
|
user_id=recipient.id,
|
||||||
|
event_type="mention",
|
||||||
|
payload=json.dumps(payload, ensure_ascii=True),
|
||||||
|
)
|
||||||
|
send_mention_notification_task.delay(
|
||||||
|
recipient.id,
|
||||||
|
f"{sender_name} mentioned you",
|
||||||
|
(message.text or "")[:120],
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not await is_user_online(recipient.id):
|
||||||
|
payload = {
|
||||||
|
**base_payload,
|
||||||
|
"type": "offline_message",
|
||||||
|
"text_preview": (message.text or "")[:120],
|
||||||
|
}
|
||||||
|
await create_notification_log(
|
||||||
|
db,
|
||||||
|
user_id=recipient.id,
|
||||||
|
event_type="offline_message",
|
||||||
|
payload=json.dumps(payload, ensure_ascii=True),
|
||||||
|
)
|
||||||
|
send_push_notification_task.delay(
|
||||||
|
recipient.id,
|
||||||
|
f"New message from {sender_name}",
|
||||||
|
(message.text or "")[:120],
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_notifications_for_user(db: AsyncSession, *, user_id: int, limit: int = 50) -> list[NotificationRead]:
|
||||||
|
safe_limit = max(1, min(limit, 100))
|
||||||
|
rows = await list_user_notifications(db, user_id=user_id, limit=safe_limit)
|
||||||
|
return [NotificationRead.model_validate(item) for item in rows]
|
||||||
|
|||||||
15
app/notifications/tasks.py
Normal file
15
app/notifications/tasks.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="notifications.send_push")
|
||||||
|
def send_push_notification_task(user_id: int, title: str, body: str, data: dict) -> None:
|
||||||
|
logger.info("PUSH user=%s title=%s body=%s data=%s", user_id, title, body, data)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="notifications.send_mention")
|
||||||
|
def send_mention_notification_task(user_id: int, title: str, body: str, data: dict) -> None:
|
||||||
|
logger.info("MENTION user=%s title=%s body=%s data=%s", user_id, title, body, data)
|
||||||
34
app/realtime/presence.py
Normal file
34
app/realtime/presence.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from app.utils.redis_client import get_redis_client
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_user_online(user_id: int) -> None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
key = f"presence:user:{user_id}"
|
||||||
|
count = await redis.incr(key)
|
||||||
|
if count == 1:
|
||||||
|
await redis.expire(key, 3600)
|
||||||
|
except RedisError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_user_offline(user_id: int) -> None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
key = f"presence:user:{user_id}"
|
||||||
|
value = await redis.decr(key)
|
||||||
|
if value <= 0:
|
||||||
|
await redis.delete(key)
|
||||||
|
except RedisError:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def is_user_online(user_id: int) -> bool:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
value = await redis.get(f"presence:user:{user_id}")
|
||||||
|
return bool(value and str(value).isdigit() and int(value) > 0)
|
||||||
|
except RedisError:
|
||||||
|
return False
|
||||||
@@ -12,6 +12,7 @@ from app.chats.service import ensure_chat_membership
|
|||||||
from app.messages.schemas import MessageCreateRequest, MessageRead
|
from app.messages.schemas import MessageCreateRequest, MessageRead
|
||||||
from app.messages.service import create_chat_message
|
from app.messages.service import create_chat_message
|
||||||
from app.realtime.models import ConnectionContext
|
from app.realtime.models import ConnectionContext
|
||||||
|
from app.realtime.presence import mark_user_offline, mark_user_online
|
||||||
from app.realtime.repository import RedisRealtimeRepository
|
from app.realtime.repository import RedisRealtimeRepository
|
||||||
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
|
from app.realtime.schemas import ChatEventPayload, MessageStatusPayload, OutgoingRealtimeEvent, SendMessagePayload
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ class RealtimeGateway:
|
|||||||
)
|
)
|
||||||
for chat_id in user_chat_ids:
|
for chat_id in user_chat_ids:
|
||||||
self._chat_subscribers[chat_id].add(user_id)
|
self._chat_subscribers[chat_id].add(user_id)
|
||||||
|
await mark_user_online(user_id)
|
||||||
await self._send_user_event(
|
await self._send_user_event(
|
||||||
user_id,
|
user_id,
|
||||||
OutgoingRealtimeEvent(
|
OutgoingRealtimeEvent(
|
||||||
@@ -73,6 +75,7 @@ class RealtimeGateway:
|
|||||||
subscribers.discard(user_id)
|
subscribers.discard(user_id)
|
||||||
if not subscribers:
|
if not subscribers:
|
||||||
self._chat_subscribers.pop(chat_id, None)
|
self._chat_subscribers.pop(chat_id, None)
|
||||||
|
await mark_user_offline(user_id)
|
||||||
|
|
||||||
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
|
async def handle_send_message(self, db: AsyncSession, user_id: int, payload: SendMessagePayload) -> None:
|
||||||
message = await create_chat_message(
|
message = await create_chat_message(
|
||||||
@@ -164,6 +167,7 @@ class RealtimeGateway:
|
|||||||
subscribers.discard(user_id)
|
subscribers.discard(user_id)
|
||||||
if not subscribers:
|
if not subscribers:
|
||||||
self._chat_subscribers.pop(chat_id, None)
|
self._chat_subscribers.pop(chat_id, None)
|
||||||
|
await mark_user_offline(user_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_chat_id(channel: str) -> int | None:
|
def _extract_chat_id(channel: str) -> int | None:
|
||||||
|
|||||||
@@ -24,3 +24,10 @@ async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
|||||||
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
||||||
result = await db.execute(select(User).where(User.username == username))
|
result = await db.execute(select(User).where(User.username == username))
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def list_users_by_ids(db: AsyncSession, user_ids: list[int]) -> list[User]:
|
||||||
|
if not user_ids:
|
||||||
|
return []
|
||||||
|
result = await db.execute(select(User).where(User.id.in_(user_ids)))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|||||||
54
app/utils/rate_limit.py
Normal file
54
app/utils/rate_limit.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from app.utils.redis_client import get_redis_client
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_ip(request: Request) -> str:
|
||||||
|
if not request.client or not request.client.host:
|
||||||
|
return "unknown"
|
||||||
|
return request.client.host
|
||||||
|
|
||||||
|
|
||||||
|
async def enforce_ip_rate_limit(
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
scope: str,
|
||||||
|
limit: int,
|
||||||
|
window_seconds: int = 60,
|
||||||
|
) -> None:
|
||||||
|
if limit <= 0:
|
||||||
|
return
|
||||||
|
key = f"rl:{scope}:ip:{_safe_ip(request)}"
|
||||||
|
await _enforce(key=key, limit=limit, window_seconds=window_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
async def enforce_user_rate_limit(
|
||||||
|
user_id: int,
|
||||||
|
*,
|
||||||
|
scope: str,
|
||||||
|
limit: int,
|
||||||
|
window_seconds: int = 60,
|
||||||
|
) -> None:
|
||||||
|
if limit <= 0:
|
||||||
|
return
|
||||||
|
key = f"rl:{scope}:user:{user_id}"
|
||||||
|
await _enforce(key=key, limit=limit, window_seconds=window_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce(*, key: str, limit: int, window_seconds: int) -> None:
|
||||||
|
try:
|
||||||
|
redis = get_redis_client()
|
||||||
|
current = await redis.incr(key)
|
||||||
|
if current == 1:
|
||||||
|
await redis.expire(key, window_seconds)
|
||||||
|
if current > limit:
|
||||||
|
ttl = await redis.ttl(key)
|
||||||
|
retry_after = max(1, ttl if ttl and ttl > 0 else window_seconds)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail=f"Rate limit exceeded. Retry in {retry_after} seconds.",
|
||||||
|
)
|
||||||
|
except RedisError:
|
||||||
|
# Fail-open in case of Redis outage to keep core API available.
|
||||||
|
return
|
||||||
19
app/utils/redis_client.py
Normal file
19
app/utils/redis_client.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
_redis_client: Redis | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_client() -> Redis:
|
||||||
|
global _redis_client
|
||||||
|
if _redis_client is None:
|
||||||
|
_redis_client = Redis.from_url(settings.redis_url, decode_responses=True)
|
||||||
|
return _redis_client
|
||||||
|
|
||||||
|
|
||||||
|
async def close_redis_client() -> None:
|
||||||
|
global _redis_client
|
||||||
|
if _redis_client is not None:
|
||||||
|
await _redis_client.aclose()
|
||||||
|
_redis_client = None
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from secrets import token_urlsafe
|
from secrets import token_urlsafe
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
@@ -18,9 +19,11 @@ def verify_password(password: str, hashed_password: str) -> bool:
|
|||||||
return pwd_context.verify(password, hashed_password)
|
return pwd_context.verify(password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
def _create_token(subject: str, token_type: str, expires_delta: timedelta) -> str:
|
def _create_token(subject: str, token_type: str, expires_delta: timedelta, jti: str | None = None) -> str:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
expire = datetime.now(timezone.utc) + expires_delta
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
payload = {"sub": subject, "type": token_type, "exp": expire}
|
payload = {"sub": subject, "type": token_type, "exp": expire, "iat": now}
|
||||||
|
payload["jti"] = jti or str(uuid4())
|
||||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
|
return jwt.encode(payload, settings.secret_key, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
@@ -32,11 +35,12 @@ def create_access_token(subject: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token(subject: str) -> str:
|
def create_refresh_token(subject: str, *, jti: str | None = None) -> str:
|
||||||
return _create_token(
|
return _create_token(
|
||||||
subject=subject,
|
subject=subject,
|
||||||
token_type="refresh",
|
token_type="refresh",
|
||||||
expires_delta=timedelta(days=settings.refresh_token_expire_days),
|
expires_delta=timedelta(days=settings.refresh_token_expire_days),
|
||||||
|
jti=jti,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
testpaths = tests
|
||||||
@@ -6,6 +6,7 @@ pydantic==2.11.7
|
|||||||
pydantic-settings==2.10.1
|
pydantic-settings==2.10.1
|
||||||
python-jose[cryptography]==3.5.0
|
python-jose[cryptography]==3.5.0
|
||||||
passlib[bcrypt]==1.7.4
|
passlib[bcrypt]==1.7.4
|
||||||
|
bcrypt==4.0.1
|
||||||
email-validator==2.2.0
|
email-validator==2.2.0
|
||||||
python-multipart==0.0.20
|
python-multipart==0.0.20
|
||||||
redis==6.4.0
|
redis==6.4.0
|
||||||
@@ -13,3 +14,7 @@ celery==5.5.3
|
|||||||
boto3==1.40.31
|
boto3==1.40.31
|
||||||
aiosmtplib==4.0.2
|
aiosmtplib==4.0.2
|
||||||
alembic==1.16.5
|
alembic==1.16.5
|
||||||
|
pytest==8.4.2
|
||||||
|
pytest-asyncio==1.2.0
|
||||||
|
httpx==0.28.1
|
||||||
|
aiosqlite==0.21.0
|
||||||
|
|||||||
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Set test env before importing app modules.
|
||||||
|
os.environ["POSTGRES_DSN"] = "sqlite+aiosqlite:///./test.db"
|
||||||
|
os.environ["AUTO_CREATE_TABLES"] = "false"
|
||||||
|
os.environ["REDIS_URL"] = "redis://localhost:6399/15"
|
||||||
|
os.environ["SECRET_KEY"] = "test-secret-key-1234567890"
|
||||||
|
os.environ["CELERY_TASK_ALWAYS_EAGER"] = "true"
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.database.base import Base
|
||||||
|
from app.database.session import AsyncSessionLocal, engine
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def reset_db() -> None:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client() -> AsyncClient:
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
yield session
|
||||||
69
tests/test_auth_flow.py
Normal file
69
tests/test_auth_flow.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.auth.models import EmailVerificationToken
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_verify_login_and_me(client, db_session):
|
||||||
|
register_payload = {
|
||||||
|
"email": "alice@example.com",
|
||||||
|
"username": "alice",
|
||||||
|
"password": "strongpass123",
|
||||||
|
}
|
||||||
|
register_response = await client.post("/api/v1/auth/register", json=register_payload)
|
||||||
|
assert register_response.status_code == 201
|
||||||
|
|
||||||
|
login_response_before_verify = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": register_payload["email"], "password": register_payload["password"]},
|
||||||
|
)
|
||||||
|
assert login_response_before_verify.status_code == 403
|
||||||
|
|
||||||
|
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
|
||||||
|
verify_token = token_row.scalar_one().token
|
||||||
|
|
||||||
|
verify_response = await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
|
||||||
|
assert verify_response.status_code == 200
|
||||||
|
|
||||||
|
login_response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": register_payload["email"], "password": register_payload["password"]},
|
||||||
|
)
|
||||||
|
assert login_response.status_code == 200
|
||||||
|
token_data = login_response.json()
|
||||||
|
assert "access_token" in token_data
|
||||||
|
assert "refresh_token" in token_data
|
||||||
|
|
||||||
|
me_response = await client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {token_data['access_token']}"},
|
||||||
|
)
|
||||||
|
assert me_response.status_code == 200
|
||||||
|
me_data = me_response.json()
|
||||||
|
assert me_data["email"] == "alice@example.com"
|
||||||
|
assert me_data["email_verified"] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_rotation(client, db_session):
|
||||||
|
payload = {
|
||||||
|
"email": "bob@example.com",
|
||||||
|
"username": "bob",
|
||||||
|
"password": "strongpass123",
|
||||||
|
}
|
||||||
|
await client.post("/api/v1/auth/register", json=payload)
|
||||||
|
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
|
||||||
|
verify_token = token_row.scalar_one().token
|
||||||
|
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
|
||||||
|
|
||||||
|
login_response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": payload["email"], "password": payload["password"]},
|
||||||
|
)
|
||||||
|
refresh_token = login_response.json()["refresh_token"]
|
||||||
|
|
||||||
|
refresh_response = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
|
||||||
|
assert refresh_response.status_code == 200
|
||||||
|
rotated_refresh_token = refresh_response.json()["refresh_token"]
|
||||||
|
assert rotated_refresh_token != refresh_token
|
||||||
|
|
||||||
|
old_refresh_reuse = await client.post("/api/v1/auth/refresh", json={"refresh_token": refresh_token})
|
||||||
|
assert old_refresh_reuse.status_code == 401
|
||||||
61
tests/test_chat_message_flow.py
Normal file
61
tests/test_chat_message_flow.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.auth.models import EmailVerificationToken
|
||||||
|
from app.chats.models import ChatType
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_verified_user(client, db_session, email: str, username: str, password: str) -> dict:
|
||||||
|
await client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "username": username, "password": password},
|
||||||
|
)
|
||||||
|
token_row = await db_session.execute(select(EmailVerificationToken).order_by(EmailVerificationToken.id.desc()))
|
||||||
|
verify_token = token_row.scalar_one().token
|
||||||
|
await client.post("/api/v1/auth/verify-email", json={"token": verify_token})
|
||||||
|
login_response = await client.post("/api/v1/auth/login", json={"email": email, "password": password})
|
||||||
|
return login_response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_private_chat_message_lifecycle(client, db_session):
|
||||||
|
u1 = await _create_verified_user(client, db_session, "u1@example.com", "user_one", "strongpass123")
|
||||||
|
u2 = await _create_verified_user(client, db_session, "u2@example.com", "user_two", "strongpass123")
|
||||||
|
|
||||||
|
me_u2 = await client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {u2['access_token']}"})
|
||||||
|
u2_id = me_u2.json()["id"]
|
||||||
|
|
||||||
|
create_chat_response = await client.post(
|
||||||
|
"/api/v1/chats",
|
||||||
|
headers={"Authorization": f"Bearer {u1['access_token']}"},
|
||||||
|
json={"type": ChatType.PRIVATE.value, "title": None, "member_ids": [u2_id]},
|
||||||
|
)
|
||||||
|
assert create_chat_response.status_code == 200
|
||||||
|
chat_id = create_chat_response.json()["id"]
|
||||||
|
|
||||||
|
send_message_response = await client.post(
|
||||||
|
"/api/v1/messages",
|
||||||
|
headers={"Authorization": f"Bearer {u1['access_token']}"},
|
||||||
|
json={"chat_id": chat_id, "type": "text", "text": "hello @user_two"},
|
||||||
|
)
|
||||||
|
assert send_message_response.status_code == 201
|
||||||
|
message_id = send_message_response.json()["id"]
|
||||||
|
|
||||||
|
list_messages_response = await client.get(
|
||||||
|
f"/api/v1/messages/{chat_id}",
|
||||||
|
headers={"Authorization": f"Bearer {u2['access_token']}"},
|
||||||
|
)
|
||||||
|
assert list_messages_response.status_code == 200
|
||||||
|
assert len(list_messages_response.json()) == 1
|
||||||
|
|
||||||
|
edit_message_response = await client.put(
|
||||||
|
f"/api/v1/messages/{message_id}",
|
||||||
|
headers={"Authorization": f"Bearer {u1['access_token']}"},
|
||||||
|
json={"text": "edited text"},
|
||||||
|
)
|
||||||
|
assert edit_message_response.status_code == 200
|
||||||
|
assert edit_message_response.json()["text"] == "edited text"
|
||||||
|
|
||||||
|
delete_message_response = await client.delete(
|
||||||
|
f"/api/v1/messages/{message_id}",
|
||||||
|
headers={"Authorization": f"Bearer {u1['access_token']}"},
|
||||||
|
)
|
||||||
|
assert delete_message_response.status_code == 204
|
||||||
Reference in New Issue
Block a user