210 lines
6.9 KiB
Python
210 lines
6.9 KiB
Python
import mimetypes
|
|
import json
|
|
from urllib.parse import quote
|
|
from uuid import uuid4
|
|
|
|
import boto3
|
|
from botocore.client import Config
|
|
from botocore.exceptions import BotoCoreError, ClientError
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config.settings import settings
|
|
from app.media import repository
|
|
from app.media.schemas import AttachmentCreateRequest, AttachmentRead, ChatAttachmentRead, UploadUrlRequest, UploadUrlResponse
|
|
from app.chats.service import ensure_chat_membership
|
|
from app.messages.repository import get_message_by_id
|
|
|
|
# Media types accepted by _validate_media, which compares against the
# parameter-stripped, lower-cased form of the declared Content-Type.
ALLOWED_MIME_TYPES = {
    # images
    "image/jpeg", "image/png", "image/webp",
    # video
    "video/mp4", "video/webm",
    # audio
    "audio/mpeg", "audio/ogg", "audio/webm", "audio/wav",
    # documents and archives
    "application/pdf", "application/zip", "text/plain",
}
|
|
|
|
|
|
def _normalize_waveform(points: list[int] | None) -> list[int] | None:
|
|
if points is None:
|
|
return None
|
|
normalized = [max(0, min(31, int(value))) for value in points]
|
|
if len(normalized) < 8:
|
|
return None
|
|
return normalized
|
|
|
|
|
|
def _decode_waveform(raw: str | None) -> list[int] | None:
|
|
if not raw:
|
|
return None
|
|
try:
|
|
parsed = json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
return None
|
|
if not isinstance(parsed, list):
|
|
return None
|
|
result: list[int] = []
|
|
for value in parsed[:256]:
|
|
if isinstance(value, int):
|
|
result.append(max(0, min(31, value)))
|
|
elif isinstance(value, float):
|
|
result.append(max(0, min(31, int(value))))
|
|
else:
|
|
return None
|
|
return result or None
|
|
|
|
def _normalize_mime(file_type: str) -> str:
|
|
return file_type.split(";", maxsplit=1)[0].strip().lower()
|
|
|
|
|
|
# Fixed extensions for the upload types this module accepts. The answer from
# ``mimetypes.guess_extension`` depends on the mime.types files present on the
# host, so it may differ between machines or be missing for some of these types.
_PREFERRED_EXTENSIONS: dict[str, str] = {
    "image/jpeg": ".jpg",
    "image/png": ".png",
    "image/webp": ".webp",
    "video/mp4": ".mp4",
    "video/webm": ".webm",
    "audio/mpeg": ".mp3",
    "audio/ogg": ".ogg",
    "audio/webm": ".webm",
    "audio/wav": ".wav",
    "application/pdf": ".pdf",
    "application/zip": ".zip",
    "text/plain": ".txt",
}


def _extension_from_mime(file_type: str) -> str:
    """Pick a file extension (including the dot) for an object key.

    The Content-Type is first normalized with ``_normalize_mime``. Known upload
    types resolve through a fixed table, so object keys are stable across
    hosts. Other types fall back to ``mimetypes.guess_extension``, with
    ``.jpe`` rewritten to ``.jpg``. If that also finds nothing, ``.bin`` is
    returned.
    """
    media_type = _normalize_mime(file_type)
    preferred = _PREFERRED_EXTENSIONS.get(media_type)
    if preferred is not None:
        return preferred
    guessed = mimetypes.guess_extension(media_type)
    if not guessed:
        return ".bin"
    return ".jpg" if guessed == ".jpe" else guessed
|
|
|
|
|
|
def _build_file_url(bucket: str, object_key: str) -> str:
    """Build the path-style file URL ``<endpoint>/<bucket>/<key>`` for an object.

    The public endpoint is used when configured, otherwise the internal S3
    endpoint; a trailing ``/`` on the endpoint is dropped. The key is
    percent-encoded with ``quote``, whose default safe set keeps ``/`` as-is.
    """
    endpoint = settings.s3_public_endpoint_url or settings.s3_endpoint_url
    safe_key = quote(object_key)
    return f"{endpoint.rstrip('/')}/{bucket}/{safe_key}"
|
|
|
|
|
|
def _allowed_file_url_prefixes() -> tuple[str, ...]:
    """Return the ``<endpoint>/<bucket>/`` prefixes a stored file URL may use.

    The internal S3 endpoint is always included; the public endpoint is added
    only when it is set. The result is a tuple so it can be passed straight to
    ``str.startswith``.
    """
    bucket = settings.s3_bucket_name
    endpoints: tuple[str, ...] = (settings.s3_endpoint_url,)
    if settings.s3_public_endpoint_url:
        endpoints += (settings.s3_public_endpoint_url,)
    return tuple(f"{endpoint.rstrip('/')}/{bucket}/" for endpoint in endpoints)
|
|
|
|
|
|
def _validate_media(file_type: str, file_size: int) -> None:
    """Reject media whose declared type or size is not permitted.

    The type is checked first (after ``_normalize_mime``), then the size.

    Raises:
        HTTPException: 422 "Unsupported file type" when the normalized type is
            not in ``ALLOWED_MIME_TYPES``. 422 "File size exceeds limit" when
            *file_size* is greater than ``settings.max_upload_size_bytes``.
    """
    # NOTE(review): only the upper size bound is enforced here; zero or
    # negative sizes pass — confirm the request schemas reject them.
    if _normalize_mime(file_type) not in ALLOWED_MIME_TYPES:
        detail = "Unsupported file type"
    elif file_size > settings.max_upload_size_bytes:
        detail = "File size exceeds limit"
    else:
        return
    raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail)
|
|
|
|
|
|
def _get_s3_client(endpoint_url: str):
    """Create a boto3 S3 client for *endpoint_url* with the configured credentials.

    SigV4 signing and path-style addressing (``<endpoint>/<bucket>/<key>``) are
    forced through the botocore ``Config``. A new client is created on every call.
    """
    client_config = Config(signature_version="s3v4", s3={"addressing_style": "path"})
    credentials = {
        "aws_access_key_id": settings.s3_access_key,
        "aws_secret_access_key": settings.s3_secret_key,
    }
    return boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        region_name=settings.s3_region,
        config=client_config,
        **credentials,
    )
|
|
|
|
|
|
async def generate_upload_url(payload: UploadUrlRequest) -> UploadUrlResponse:
    """Issue a presigned PUT URL for uploading a new media object.

    The object key is ``uploads/<random hex><extension>``, with the extension
    derived from the declared MIME type. The URL is signed against the public
    endpoint when configured, otherwise the internal one.

    The response carries:
    - the presigned URL;
    - the file URL from ``_build_file_url``;
    - the object key and the expiry;
    - the ``Content-Type`` header, which must match the ContentType used when
      signing.

    Raises:
        HTTPException: 422 from ``_validate_media``. 503 "Storage service
            unavailable" when botocore fails while creating the client or
            signing.
    """
    _validate_media(payload.file_type, payload.file_size)

    bucket = settings.s3_bucket_name
    key_extension = _extension_from_mime(payload.file_type)
    object_key = f"uploads/{uuid4().hex}{key_extension}"
    expires_in = settings.s3_presign_expire_seconds
    signing_endpoint = settings.s3_public_endpoint_url or settings.s3_endpoint_url
    # NOTE(review): only Bucket/Key/ContentType are bound into the presigned
    # request; file_size is validated as declared, but the size actually
    # uploaded is not constrained here.
    put_params = {"Bucket": bucket, "Key": object_key, "ContentType": payload.file_type}

    try:
        upload_url = _get_s3_client(signing_endpoint).generate_presigned_url(
            "put_object",
            Params=put_params,
            ExpiresIn=expires_in,
            HttpMethod="PUT",
        )
    except (BotoCoreError, ClientError) as exc:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Storage service unavailable",
        ) from exc

    return UploadUrlResponse(
        upload_url=upload_url,
        file_url=_build_file_url(bucket, object_key),
        object_key=object_key,
        expires_in=expires_in,
        required_headers={"Content-Type": payload.file_type},
    )
|
|
|
|
|
|
def _is_allowed_file_url(file_url: str) -> bool:
    """Return True if *file_url* names a non-empty object key inside the bucket.

    A bare ``startswith`` check against ``_allowed_file_url_prefixes()`` is not
    enough. ``<prefix>../other-bucket/key`` matches the prefix, but URL parsers
    collapse the ``..`` segment, so the URL resolves outside the bucket.
    WHATWG-style parsers also decode ``%2e`` as ``.`` and, for http(s), treat
    backslashes as path separators. The key part is therefore percent-decoded,
    and backslashes are turned into forward slashes, before dot segments are
    rejected.
    """
    # Local import: the module only imports ``quote`` at top level.
    from urllib.parse import unquote

    for prefix in _allowed_file_url_prefixes():
        if file_url.startswith(prefix):
            key_part = unquote(file_url[len(prefix):]).replace("\\", "/")
            if not key_part:
                return False
            return all(segment not in {".", ".."} for segment in key_part.split("/"))
    return False


async def store_attachment_metadata(
    db: AsyncSession,
    *,
    user_id: int,
    payload: AttachmentCreateRequest,
) -> AttachmentRead:
    """Persist an attachment row for a message sent by *user_id* and commit.

    The function:
    - validates the declared MIME type and size;
    - requires ``payload.file_url`` to name an object inside the configured
      bucket;
    - allows only the message's sender to attach files;
    - runs the waveform (if any) through ``_normalize_waveform`` and stores it
      as JSON.

    Raises:
        HTTPException: 422 for an unsupported type, an oversized file, or a
            file URL outside the bucket. 404 if the message does not exist.
            403 if *user_id* is not the message's sender.
    """
    _validate_media(payload.file_type, payload.file_size)
    # payload.file_url is client-supplied; see _is_allowed_file_url for why a
    # plain prefix match is insufficient.
    if not _is_allowed_file_url(payload.file_url):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid file URL")

    message = await get_message_by_id(db, payload.message_id)
    if not message:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Message not found")
    if message.sender_id != user_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only the message sender can attach files",
        )

    normalized_waveform = _normalize_waveform(payload.waveform_points)
    attachment = await repository.create_attachment(
        db,
        message_id=payload.message_id,
        file_url=payload.file_url,
        file_type=payload.file_type,
        file_size=payload.file_size,
        waveform_data=json.dumps(normalized_waveform, ensure_ascii=True) if normalized_waveform else None,
    )
    await db.commit()
    await db.refresh(attachment)

    return AttachmentRead.model_validate(
        {
            "id": attachment.id,
            "message_id": attachment.message_id,
            "file_url": attachment.file_url,
            "file_type": attachment.file_type,
            "file_size": attachment.file_size,
            "waveform_points": _decode_waveform(attachment.waveform_data),
        }
    )
|
|
|
|
|
|
async def list_attachments_for_chat(
    db: AsyncSession,
    *,
    user_id: int,
    chat_id: int,
    limit: int = 100,
    before_id: int | None = None,
) -> list[ChatAttachmentRead]:
    """Return one page of attachments from *chat_id* for a member *user_id*.

    Membership is enforced first via ``ensure_chat_membership``. *limit* is
    clamped to 1..200 and *before_id* is passed through to
    ``repository.list_chat_attachments`` (presumably a pagination cursor —
    confirm in the repository). Each repository row is an
    ``(attachment, message)`` pair. The message supplies the sender id, the
    message type (its ``.value`` when present, else ``str()``) and the
    creation time.
    """
    await ensure_chat_membership(db, chat_id=chat_id, user_id=user_id)

    page_size = min(max(limit, 1), 200)
    rows = await repository.list_chat_attachments(
        db, chat_id=chat_id, limit=page_size, before_id=before_id
    )

    results: list[ChatAttachmentRead] = []
    for attachment, message in rows:
        raw_type = message.type
        type_label = raw_type.value if hasattr(raw_type, "value") else str(raw_type)
        results.append(
            ChatAttachmentRead(
                id=attachment.id,
                message_id=attachment.message_id,
                sender_id=message.sender_id,
                message_type=type_label,
                message_created_at=message.created_at,
                file_url=attachment.file_url,
                file_type=attachment.file_type,
                file_size=attachment.file_size,
                waveform_points=_decode_waveform(attachment.waveform_data),
            )
        )
    return results
|