feat: LLM response validation with defaults, security headers middleware, cleaner stats endpoint

This commit is contained in:
Marker689
2026-05-11 20:11:47 +03:00
parent 2d9ab9f436
commit fe384aed17
3 changed files with 35 additions and 5 deletions

View File

@@ -14,6 +14,22 @@ from ..logging_setup import log
_llm_semaphore = asyncio.Semaphore(config.llm_max_concurrent)
_REPORT_DEFAULTS = {
"verdict": "unknown",
"summary": "No summary provided",
"analysis": "No analysis provided",
"severity_rating": "unknown",
}
def _validate_report(report: dict) -> dict:
for field, default in _REPORT_DEFAULTS.items():
if not report.get(field):
report[field] = default
if report["verdict"] not in ("safe", "suspicious", "malicious", "unknown"):
report["verdict"] = "unknown"
return report
def _build_user_message(finding: dict) -> str:
"""Build a concise prompt from a finding's data."""
@@ -78,7 +94,8 @@ async def _attempt_llm_call(finding_data: dict) -> dict | None:
content = message.get("content", "")
if not content:
raise ValueError("Empty message content")
return json.loads(content)
parsed = json.loads(content)
return _validate_report(parsed)
except (ValueError, json.JSONDecodeError) as e:
raw = ""
try:
@@ -94,7 +111,7 @@ async def _attempt_llm_call(finding_data: dict) -> dict | None:
stripped = raw.strip().strip("`").strip()
if stripped.startswith("json\n"):
stripped = stripped[5:]
return json.loads(stripped)
return _validate_report(json.loads(stripped))
except json.JSONDecodeError:
pass
log.warning(

View File

@@ -85,6 +85,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
return response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "no-referrer"
response.headers["Permissions-Policy"] = "geolocation=(), microphone=()"
return response
app = FastAPI(
title=APP_NAME,
version=APP_VERSION,
@@ -92,6 +103,7 @@ app = FastAPI(
lifespan=lifespan,
)
app.add_middleware(LangMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestLoggingMiddleware)

View File

@@ -136,15 +136,16 @@ async def export_scans_csv(
@router.get("/stats", response_model=StatsResponse)
async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
dashboard = await get_dashboard_stats(session)
latest = dashboard["latest_flagged"]
return {
"total_scans": dashboard["total_scans"],
"flagged_scans": dashboard["flagged_scans"],
"recent_flagged": dashboard["recent_flagged"],
"total_findings": dashboard["total_findings"],
"top_rules": dashboard["top_rules"],
"latest_scan_at": dashboard["latest_flagged"][0].started_at.isoformat()
if dashboard["latest_flagged"] and dashboard["latest_flagged"][0].started_at
else None,
"latest_scan_at": (
latest[0].started_at.isoformat() if latest and latest[0].started_at else None
),
}