139 lines
4.2 KiB
Python
139 lines
4.2 KiB
Python
import logging
|
||
import os
|
||
from collections import deque
|
||
from datetime import datetime, timezone
|
||
from logging.handlers import TimedRotatingFileHandler
|
||
from typing import Any, Optional
|
||
|
||
from aiogram import BaseMiddleware
|
||
from aiogram.types import CallbackQuery, Message
|
||
|
||
|
||
def _get_audit_path(cfg: dict[str, Any]) -> str:
|
||
return cfg.get("audit", {}).get("path", "/var/server-bot/audit.log")
|
||
|
||
|
||
def get_audit_logger(cfg: dict[str, Any]) -> logging.Logger:
    """Return the shared ``"audit"`` logger, attaching its file handler on first use.

    On the first call (while the logger has no handlers) this:

    * creates the log file's parent directory, if the path has one;
    * attaches a ``TimedRotatingFileHandler`` configured from ``cfg["audit"]``:
      ``rotate_when`` (default ``"W0"``, i.e. weekly on Monday) and
      ``backup_count`` (default ``8`` rotated files kept);
    * formats records as ``asctime``, a tab, then the message;
    * sets the level to INFO and stops propagation to ancestor loggers.

    Once a handler is attached, later calls return the logger unchanged and
    ignore *cfg*.

    Args:
        cfg: Bot configuration mapping; only the ``audit`` section is read.

    Returns:
        The configured ``logging.Logger`` named ``"audit"``.
    """
    logger = logging.getLogger("audit")
    if logger.handlers:
        # Already configured (e.g. by an earlier middleware instance).
        return logger

    audit_cfg = cfg.get("audit") or {}
    path = _get_audit_path(cfg)

    # os.path.dirname("audit.log") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    handler = TimedRotatingFileHandler(
        path,
        when=audit_cfg.get("rotate_when", "W0"),
        interval=1,
        backupCount=int(audit_cfg.get("backup_count", 8)),
        encoding="utf-8",
        utc=True,
    )
    # NOTE(review): rotation boundaries are computed in UTC (utc=True), but
    # %(asctime)s uses the Formatter's default local-time converter; consider
    # formatter.converter = time.gmtime if timestamps should also be UTC.
    handler.setFormatter(logging.Formatter("%(asctime)s\t%(message)s"))

    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    # Keep audit records out of the application's general log output.
    logger.propagate = False
    return logger
|
||
|
||
|
||
def audit_health(cfg: dict[str, Any]) -> tuple[bool, str]:
    """Check that the audit log file exists (creating it if needed) and is writable.

    Side effects: creates the parent directory (when the path has one) and an
    empty log file if either is missing.

    Args:
        cfg: Bot configuration mapping; the path comes from ``_get_audit_path``.

    Returns:
        ``(True, path)`` when the file is present and writable. Otherwise
        ``(False, message)`` with a human-readable reason.
    """
    path = _get_audit_path(cfg)
    directory = os.path.dirname(path)
    try:
        # A bare filename has dirname "" and os.makedirs("") would raise,
        # wrongly reporting a healthy relative path as an error.
        if directory:
            os.makedirs(directory, exist_ok=True)
        if not os.path.exists(path):
            # Touch the file so the writability check below has a target.
            with open(path, "a", encoding="utf-8"):
                pass
        if not os.access(path, os.W_OK):
            return False, f"Audit log not writable: {path}"
    except (OSError, ValueError) as e:
        # OSError covers permission / missing-dir / read-only FS failures;
        # ValueError covers malformed paths (e.g. embedded NUL bytes).
        return False, f"Audit log error: {e}"
    return True, path
|
||
|
||
|
||
def audit_start(cfg: dict[str, Any]) -> None:
    """Write a single ``startup`` record to the audit log.

    The record carries ``status=ok`` plus the log path when ``audit_health``
    succeeds. Otherwise it carries ``status=error`` plus the health-check message.
    """
    audit_logger = get_audit_logger(cfg)
    healthy, detail = audit_health(cfg)
    audit_logger.info(
        "startup\tstatus=%s\tpath=%s",
        "ok" if healthy else "error",
        detail,
    )
|
||
|
||
|
||
def _user_label(message: Message | CallbackQuery) -> str:
    """Build a readable label for the sender of *message*.

    Joins the non-empty values of username, first name and last name with
    spaces. Falls back to the numeric user id when all three are empty, and
    returns ``"unknown"`` when the event has no ``from_user``.
    """
    sender = message.from_user
    if not sender:
        return "unknown"
    name_bits = filter(None, (sender.username, sender.first_name, sender.last_name))
    joined = " ".join(name_bits)
    return joined if joined else str(sender.id)
|
||
|
||
|
||
def _normalize_action(text: str, limit: int = 200) -> str:
|
||
cleaned = " ".join(text.split())
|
||
if len(cleaned) > limit:
|
||
return cleaned[:limit] + "…"
|
||
return cleaned
|
||
|
||
|
||
class AuditMiddleware(BaseMiddleware):
    """aiogram middleware that writes one audit line per incoming user action.

    For each ``Message`` (text, caption or ``<content_type>``) and each
    ``CallbackQuery`` (``callback:<data>``), a tab-separated record with the
    user id, user label, chat id and action is written to the ``audit``
    logger. Other event types are not logged. The event is always passed on
    to the next handler unchanged.

    Logging is skipped entirely when ``cfg["audit"]["enabled"]`` is falsy
    (default: enabled).
    """

    def __init__(self, cfg: dict[str, Any]) -> None:
        super().__init__()
        self.cfg = cfg
        self.logger = get_audit_logger(cfg)

    @staticmethod
    def _describe_action(event: Any) -> Optional[str]:
        """Return a normalised description of *event*, or None if it is not audited."""
        if isinstance(event, Message):
            if event.text:
                return _normalize_action(event.text)
            if event.caption:
                return _normalize_action(event.caption)
            # Media without caption, stickers, service messages, ...
            return f"<{event.content_type}>"
        if isinstance(event, CallbackQuery):
            if event.data:
                return _normalize_action(f"callback:{event.data}")
            return "callback:<empty>"
        return None

    @staticmethod
    def _chat_id(event: Any) -> Any:
        """Return the id of the chat *event* belongs to, or ``"unknown"`` if none is attached."""
        if isinstance(event, Message):
            return event.chat.id
        # CallbackQuery.message is optional in the Bot API (e.g. buttons on
        # inline-mode messages). Dereferencing it unconditionally would raise
        # AttributeError and stop the event from ever reaching its handler.
        chat = getattr(event.message, "chat", None)
        return chat.id if chat is not None else "unknown"

    async def __call__(self, handler: Any, event: Any, data: dict[str, Any]) -> Any:
        """Log *event* (when auditable and enabled), then await ``handler(event, data)``."""
        audit_cfg = self.cfg.get("audit") or {}
        if not audit_cfg.get("enabled", True):
            return await handler(event, data)

        action = self._describe_action(event)
        if action:
            user = event.from_user
            self.logger.info(
                "user_id=%s\tuser=%s\tchat_id=%s\taction=%s",
                user.id if user else "unknown",
                _user_label(event),
                self._chat_id(event),
                action,
            )

        return await handler(event, data)
|
||
|
||
|
||
def read_audit_tail(cfg: dict[str, Any], limit: int = 200, max_chars: int = 3500) -> str:
    """Return the last lines of the audit log formatted as a chat message.

    Args:
        cfg: Bot configuration mapping; the path comes from ``_get_audit_path``.
        limit: Maximum number of trailing log lines to include.
        max_chars: Maximum length of the log body (excluding header and code
            fences). The default 3500 leaves headroom under Telegram's
            4096-character message limit. When exceeded, the oldest text is
            dropped, the cut is moved forward to a line boundary where
            possible, and a ``...(truncated)`` marker is prepended.

    Returns:
        A header with the current UTC time followed by the lines in a fenced
        code block, or a short warning/info string when the log is missing
        or empty.
    """
    path = _get_audit_path(cfg)

    # deque(maxlen=...) keeps only the newest `limit` lines while streaming,
    # so memory stays bounded regardless of file size.
    tail: deque[str] = deque(maxlen=limit)
    try:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                tail.append(line.rstrip())
    except FileNotFoundError:
        # EAFP: an exists()-then-open() check can race with log rotation,
        # which renames the file before the handler recreates it.
        return "⚠️ Audit log not found"

    if not tail:
        return "ℹ️ Audit log is empty"

    header = f"🧾 Audit log ({datetime.now(timezone.utc):%Y-%m-%d %H:%M UTC})"
    body = "\n".join(tail)

    max_chars = max(max_chars, 0)
    if len(body) > max_chars:
        cut = len(body) - max_chars  # >= 1 here
        # Avoid showing a fragment of a line as the first entry: advance to
        # the next line start unless the cut already sits on one.
        if body[cut - 1] != "\n":
            newline = body.find("\n", cut)
            if newline != -1 and newline + 1 < len(body):
                cut = newline + 1
        body = "...(truncated)\n" + body[cut:]

    # NOTE(review): log lines embed user-supplied text; a literal ``` inside
    # them would close this fence early if the caller sends the result with a
    # Markdown parse mode. Consider escaping backticks for that parse mode.
    return f"{header}\n```\n{body}\n```"
|