feat: LLM response validation with defaults, security headers middleware, cleaner stats endpoint
This commit is contained in:
@@ -14,6 +14,22 @@ from ..logging_setup import log
|
||||
|
||||
# Module-level semaphore sized by config.llm_max_concurrent.
# NOTE(review): presumably used to cap the number of concurrent LLM calls;
# the acquiring code is outside this view — confirm against the call sites.
_llm_semaphore = asyncio.Semaphore(config.llm_max_concurrent)
|
||||
|
||||
_REPORT_DEFAULTS = {
|
||||
"verdict": "unknown",
|
||||
"summary": "No summary provided",
|
||||
"analysis": "No analysis provided",
|
||||
"severity_rating": "unknown",
|
||||
}
|
||||
|
||||
|
||||
def _validate_report(report: dict) -> dict:
|
||||
for field, default in _REPORT_DEFAULTS.items():
|
||||
if not report.get(field):
|
||||
report[field] = default
|
||||
if report["verdict"] not in ("safe", "suspicious", "malicious", "unknown"):
|
||||
report["verdict"] = "unknown"
|
||||
return report
|
||||
|
||||
|
||||
def _build_user_message(finding: dict) -> str:
|
||||
"""Build a concise prompt from a finding's data."""
|
||||
@@ -78,7 +94,8 @@ async def _attempt_llm_call(finding_data: dict) -> dict | None:
|
||||
content = message.get("content", "")
|
||||
if not content:
|
||||
raise ValueError("Empty message content")
|
||||
return json.loads(content)
|
||||
parsed = json.loads(content)
|
||||
return _validate_report(parsed)
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
raw = ""
|
||||
try:
|
||||
@@ -94,7 +111,7 @@ async def _attempt_llm_call(finding_data: dict) -> dict | None:
|
||||
stripped = raw.strip().strip("`").strip()
|
||||
if stripped.startswith("json\n"):
|
||||
stripped = stripped[5:]
|
||||
return json.loads(stripped)
|
||||
return _validate_report(json.loads(stripped))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
log.warning(
|
||||
|
||||
@@ -85,6 +85,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Set a fixed group of security headers on responses passing through."""

    # Header name -> value written onto each response, overwriting any value
    # the downstream handler may have set for the same header.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # "1; mode=block" enables the legacy browser XSS auditor, which is
        # deprecated and can itself introduce vulnerabilities; "0" turns it
        # off, as recommended by MDN and the OWASP secure-headers guidance.
        "X-XSS-Protection": "0",
        "Referrer-Policy": "no-referrer",
        "Permissions-Policy": "geolocation=(), microphone=()",
    }

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then add the security headers.

        Args:
            request: The incoming request, forwarded unchanged.
            call_next: Awaitable callable that produces the response.

        Returns:
            The response from ``call_next`` with the headers applied.
            If ``call_next`` raises, the exception propagates untouched.
        """
        response = await call_next(request)
        for name, value in self._SECURITY_HEADERS.items():
            response.headers[name] = value
        return response
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=APP_NAME,
|
||||
version=APP_VERSION,
|
||||
@@ -92,6 +103,7 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.add_middleware(LangMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
|
||||
|
||||
@@ -136,15 +136,16 @@ async def export_scans_csv(
|
||||
@router.get("/stats", response_model=StatsResponse)
async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
    """Return aggregate scan statistics taken from the dashboard stats."""
    dashboard = await get_dashboard_stats(session)
    latest = dashboard["latest_flagged"]

    # Derived from the first entry of "latest_flagged"; None when that list is
    # empty or the entry has no started_at value.
    # NOTE(review): despite the key name, this reflects the latest *flagged*
    # entry returned by get_dashboard_stats — confirm that is intended.
    latest_scan_at = (
        latest[0].started_at.isoformat() if latest and latest[0].started_at else None
    )

    # "latest_scan_at" appears exactly once; the earlier duplicate key that
    # computed the same value was removed.
    return {
        "total_scans": dashboard["total_scans"],
        "flagged_scans": dashboard["flagged_scans"],
        "recent_flagged": dashboard["recent_flagged"],
        "total_findings": dashboard["total_findings"],
        "top_rules": dashboard["top_rules"],
        "latest_scan_at": latest_scan_at,
    }
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user