Files
Messenger/app/media/service.py
benya 21c8f57169
Some checks failed
CI / test (push) Failing after 1m31s
fix(media): allow mp4/m4a audio uploads for voice recordings
2026-03-08 20:48:36 +03:00

212 lines
6.9 KiB
Python

import json
import mimetypes
from urllib.parse import quote, unquote
from uuid import uuid4

import boto3
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.config.settings import settings
from app.media import repository
from app.media.schemas import (
    AttachmentCreateRequest,
    AttachmentRead,
    ChatAttachmentRead,
    UploadUrlRequest,
    UploadUrlResponse,
)
from app.chats.service import ensure_chat_membership
from app.messages.repository import get_message_by_id
# Media types accepted for upload. Entries are bare, lower-cased types with
# no parameters, because _validate_media compares against the output of
# _normalize_mime.
ALLOWED_MIME_TYPES = {
    # Images
    "image/jpeg",
    "image/png",
    "image/webp",
    # Video
    "video/mp4",
    "video/webm",
    # Audio (audio/mp4 and audio/x-m4a allow mp4/m4a voice recordings)
    "audio/mpeg",
    "audio/ogg",
    "audio/mp4",
    "audio/x-m4a",
    "audio/webm",
    "audio/wav",
    # Documents and archives
    "application/pdf",
    "application/zip",
    "text/plain",
}
def _normalize_waveform(points: list[int] | None) -> list[int] | None:
if points is None:
return None
normalized = [max(0, min(31, int(value))) for value in points]
if len(normalized) < 8:
return None
return normalized
def _decode_waveform(raw: str | None) -> list[int] | None:
if not raw:
return None
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return None
if not isinstance(parsed, list):
return None
result: list[int] = []
for value in parsed[:256]:
if isinstance(value, int):
result.append(max(0, min(31, value)))
elif isinstance(value, float):
result.append(max(0, min(31, int(value))))
else:
return None
return result or None
def _normalize_mime(file_type: str) -> str:
return file_type.split(";", maxsplit=1)[0].strip().lower()
def _extension_from_mime(file_type: str) -> str:
    """Choose a file extension (including the leading dot) for an object key.

    The normalized media type is looked up in the stdlib ``mimetypes``
    registry, whose contents can vary with the host's mime.types files.
    ``.jpe`` is rewritten to ``.jpg``, and ``.bin`` is used when no
    extension is registered.
    """
    guessed = mimetypes.guess_extension(_normalize_mime(file_type))
    if guessed == ".jpe":
        return ".jpg"
    return guessed or ".bin"
def _build_file_url(bucket: str, object_key: str) -> str:
    """Return the path-style URL ``<endpoint>/<bucket>/<key>`` for an object.

    Uses the public endpoint when one is configured, otherwise the internal
    endpoint. The key is percent-encoded with ``/`` left intact.
    """
    endpoint = settings.s3_public_endpoint_url or settings.s3_endpoint_url
    return "/".join((endpoint.rstrip("/"), bucket, quote(object_key)))
def _allowed_file_url_prefixes() -> tuple[str, ...]:
    """Return the URL prefixes an attachment ``file_url`` may start with.

    There is one ``<endpoint>/<bucket>/`` prefix for the internal endpoint,
    followed by one for the public endpoint when that is configured.
    """
    bucket = settings.s3_bucket_name
    public_endpoint = settings.s3_public_endpoint_url
    if public_endpoint:
        sources = (settings.s3_endpoint_url, public_endpoint)
    else:
        sources = (settings.s3_endpoint_url,)
    return tuple(f"{source.rstrip('/')}/{bucket}/" for source in sources)
def _validate_media(file_type: str, file_size: int) -> None:
    """Reject uploads whose declared type or size is not acceptable.

    Args:
        file_type: Client-declared Content-Type. Parameters after ``;`` are
            ignored when checking the allow-list.
        file_size: Client-declared size in bytes.

    Raises:
        HTTPException: 422 in any of these cases:

            - the media type is not in ``ALLOWED_MIME_TYPES``;
            - ``file_size`` is negative;
            - ``file_size`` exceeds ``settings.max_upload_size_bytes``.
    """
    if _normalize_mime(file_type) not in ALLOWED_MIME_TYPES:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file type")
    # A negative byte count cannot describe a real file. Without this check
    # it would pass the upper-bound test and be stored as attachment metadata.
    if file_size < 0:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file size")
    if file_size > settings.max_upload_size_bytes:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="File size exceeds limit")
def _get_s3_client(endpoint_url: str):
    """Create a boto3 S3 client for ``endpoint_url`` using configured credentials.

    The client signs with SigV4 and uses path-style addressing
    (``<endpoint>/<bucket>/<key>``). A new client is built on every call.
    """
    client_config = Config(
        signature_version="s3v4",
        s3={"addressing_style": "path"},
    )
    credentials = {
        "aws_access_key_id": settings.s3_access_key,
        "aws_secret_access_key": settings.s3_secret_key,
    }
    return boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        region_name=settings.s3_region,
        config=client_config,
        **credentials,
    )
async def generate_upload_url(payload: UploadUrlRequest) -> UploadUrlResponse:
    """Validate an upload request and return a presigned PUT URL for it.

    The object key is ``uploads/<random hex><extension>``, where the
    extension is derived from the declared Content-Type.

    Args:
        payload: Client-declared ``file_type`` and ``file_size``.

    Returns:
        An ``UploadUrlResponse`` containing:

        - the presigned ``upload_url``;
        - the resulting ``file_url`` and ``object_key``;
        - the expiry in seconds;
        - the headers the client must send with the PUT (the same
          ``Content-Type`` passed to the presigner).

    Raises:
        HTTPException: 422 from ``_validate_media`` for a disallowed type
            or size; 503 when boto3 fails to produce the presigned URL.
    """
    _validate_media(payload.file_type, payload.file_size)

    extension = _extension_from_mime(payload.file_type)
    object_key = "uploads/" + uuid4().hex + extension
    bucket = settings.s3_bucket_name
    presign_params = {
        "Bucket": bucket,
        "Key": object_key,
        "ContentType": payload.file_type,
    }
    # NOTE(review): only Bucket/Key/ContentType go into the presigned
    # request, so file_size is checked as declared but is not bound to the
    # URL. TODO: confirm whether object size is limited elsewhere.
    try:
        # Presumably the public endpoint is preferred so the signed URL is
        # reachable by clients; fall back to the internal endpoint otherwise.
        endpoint = settings.s3_public_endpoint_url or settings.s3_endpoint_url
        upload_url = _get_s3_client(endpoint).generate_presigned_url(
            "put_object",
            Params=presign_params,
            ExpiresIn=settings.s3_presign_expire_seconds,
            HttpMethod="PUT",
        )
    except (BotoCoreError, ClientError) as exc:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Storage service unavailable",
        ) from exc

    return UploadUrlResponse(
        upload_url=upload_url,
        file_url=_build_file_url(bucket, object_key),
        object_key=object_key,
        expires_in=settings.s3_presign_expire_seconds,
        required_headers={"Content-Type": payload.file_type},
    )
def _has_dot_segments(url: str) -> bool:
    """Return True if ``url`` contains a ``.`` or ``..`` path segment.

    The URL is percent-decoded and ``\\`` is treated as ``/`` first. HTTP(S)
    URL parsers resolve ``%2e%2e`` and backslash-separated ``..`` the same
    way as a literal ``/../``.
    """
    decoded = unquote(url).replace("\\", "/")
    return any(segment in {".", ".."} for segment in decoded.split("/"))


async def store_attachment_metadata(
    db: AsyncSession,
    *,
    user_id: int,
    payload: AttachmentCreateRequest,
) -> AttachmentRead:
    """Persist attachment metadata for a message sent by ``user_id``.

    Args:
        db: Open async database session; committed on success.
        user_id: The requesting user; must be the message's sender.
        payload: Target ``message_id``, ``file_url``, ``file_type``,
            ``file_size`` and optional ``waveform_points``.

    Returns:
        The stored attachment. Its waveform is decoded back from the
        persisted JSON.

    Raises:
        HTTPException: 422 for a disallowed type or size, or for a
            ``file_url`` outside the configured bucket; 404 if the message
            does not exist; 403 if ``user_id`` is not the message sender.
    """
    _validate_media(payload.file_type, payload.file_size)
    # A plain prefix match would accept "<endpoint>/<bucket>/../<other>/key".
    # URL clients normalise that to another bucket on the same host (the
    # storage URLs are path-style), so dot segments are refused as well.
    if not payload.file_url.startswith(_allowed_file_url_prefixes()) or _has_dot_segments(payload.file_url):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file URL")

    message = await get_message_by_id(db, payload.message_id)
    if not message:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
    if message.sender_id != user_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only the message sender can attach files",
        )

    # _normalize_waveform returns None or a non-empty list, so truthiness
    # selects between storing JSON and storing NULL.
    normalized_waveform = _normalize_waveform(payload.waveform_points)
    attachment = await repository.create_attachment(
        db,
        message_id=payload.message_id,
        file_url=payload.file_url,
        file_type=payload.file_type,
        file_size=payload.file_size,
        waveform_data=json.dumps(normalized_waveform, ensure_ascii=True) if normalized_waveform else None,
    )
    await db.commit()
    await db.refresh(attachment)
    return AttachmentRead.model_validate(
        {
            "id": attachment.id,
            "message_id": attachment.message_id,
            "file_url": attachment.file_url,
            "file_type": attachment.file_type,
            "file_size": attachment.file_size,
            "waveform_points": _decode_waveform(attachment.waveform_data),
        }
    )
async def list_attachments_for_chat(
    db: AsyncSession,
    *,
    user_id: int,
    chat_id: int,
    limit: int = 100,
    before_id: int | None = None,
) -> list[ChatAttachmentRead]:
    """Return a page of attachments from a chat the user is a member of.

    Args:
        db: Open async database session.
        user_id: The requesting user; membership in ``chat_id`` is checked
            first.
        chat_id: The chat whose attachments are listed.
        limit: Requested page size, clamped to 1..200.
        before_id: Optional cursor passed through to the repository query.

    Returns:
        One ``ChatAttachmentRead`` per ``(attachment, message)`` row, in
        repository order. Each item carries the parent message's sender,
        type and creation time.

    Raises:
        Any error raised by ``ensure_chat_membership`` propagates unchanged.
    """
    await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)

    page_size = max(1, min(limit, 200))
    rows = await repository.list_chat_attachments(
        db,
        chat_id=chat_id,
        limit=page_size,
        before_id=before_id,
    )

    page: list[ChatAttachmentRead] = []
    for attachment, message in rows:
        # message.type may be a value-bearing object (e.g. an Enum member)
        # or a plain value; expose it as text either way.
        raw_kind = message.type
        message_type = raw_kind.value if hasattr(raw_kind, "value") else str(raw_kind)
        page.append(
            ChatAttachmentRead(
                id=attachment.id,
                message_id=attachment.message_id,
                sender_id=message.sender_id,
                message_type=message_type,
                message_created_at=message.created_at,
                file_url=attachment.file_url,
                file_type=attachment.file_type,
                file_size=attachment.file_size,
                waveform_points=_decode_waveform(attachment.waveform_data),
            )
        )
    return page