refactor: FastAPI best practices — return types, Pydantic schemas, middleware, code dedup

- Все 18 роутов получили return type annotations
- Создан schemas.py с Pydantic-моделями (ScanOut, PackageOut, FindingOut, ...)
- API-роуты: response_model на list/detail/export/stats
- 404 через HTTPException(404) вместо {'detail':'Not found'} (200)
- RequestLoggingMiddleware: method, path, status, duration_ms
- Глобальный exception handler: ловит необработанные исключения → 500
- _parse_flagged(): вынесен дублирующийся string→bool
- parse_package_path(): общий для web.py и api_packages.py
- selectinload: вынесены в top-level imports
- harvester: makedirs/mkdtemp/rmtree обёрнуты в asyncio.to_thread()
This commit is contained in:
Marker689
2026-05-10 12:53:33 +03:00
parent 935d96b35a
commit c1258dde99
11 changed files with 188 additions and 55 deletions

View File

@@ -13,11 +13,12 @@ from ..constants import (
)
from ..db.engine import get_session
from ..db.models import Finding
from ..schemas import FindingsListResponse
router = APIRouter(prefix="/api/v1/findings", tags=["findings"])
@router.get("")
@router.get("", response_model=FindingsListResponse)
async def list_findings(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -25,7 +26,7 @@ async def list_findings(
severity: str | None = Query(None),
scan_id: int | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> dict:
q = select(Finding)
if rule:
q = q.where(func.json_extract(Finding.data, JSON_PATH_RULE) == rule)

View File

@@ -2,9 +2,8 @@
import csv
import io
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -17,14 +16,16 @@ from ..constants import (
DEFAULT_SORT_DIR,
MAX_PAGE_SIZE,
)
from ..core.nexus import parse_package_path
from ..db.engine import get_session
from ..db.models import Scan
from ..db.queries import build_package_list_query
from ..schemas import PackageDetailOut, PackageListResponse
router = APIRouter(prefix="/api/v1/packages", tags=["packages"])
@router.get("")
@router.get("", response_model=PackageListResponse)
async def list_packages(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -35,7 +36,7 @@ async def list_packages(
sort_by: str = Query(DEFAULT_SORT_BY_PACKAGES),
sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session),
):
) -> dict:
rows_q, total_q = build_package_list_query(
flagged=flagged,
ecosystem=ecosystem,
@@ -74,7 +75,7 @@ async def export_packages_csv(
flagged: bool | None = Query(None),
search: str | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> Response:
rows_q, _total_q = build_package_list_query(
flagged=flagged,
search=search,
@@ -118,14 +119,12 @@ async def export_packages_csv(
)
@router.get("/{name:path}")
@router.get("/{name:path}", response_model=PackageDetailOut)
async def get_package(
name: str,
session: AsyncSession = Depends(get_session),
):
parts = name.rsplit("/", 1)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
) -> dict:
pkg_name, pkg_version = parse_package_path(name)
scans = (
(
@@ -141,7 +140,7 @@ async def get_package(
)
if not scans:
return {"detail": "Not found"}
raise HTTPException(status_code=404, detail="Package not found")
all_findings: list[dict] = []
for s in scans:

View File

@@ -3,7 +3,7 @@
import csv
import io
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -19,11 +19,12 @@ from ..constants import (
from ..db.engine import get_session
from ..db.models import Scan
from ..db.queries import build_scan_list_query, get_dashboard_stats
from ..schemas import ScanDetailOut, ScanListResponse, StatsResponse
router = APIRouter(prefix="/api/v1/scans", tags=["scans"])
@router.get("")
@router.get("", response_model=ScanListResponse)
async def list_scans(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -34,7 +35,7 @@ async def list_scans(
sort_by: str = Query(DEFAULT_SORT_BY_SCANS),
sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session),
):
) -> dict:
q, count_q = build_scan_list_query(
flagged=flagged,
status=status,
@@ -77,7 +78,7 @@ async def export_scans_csv(
search: str | None = Query(None),
status: str | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> Response:
q, _count_q = build_scan_list_query(
flagged=flagged,
status=status,
@@ -132,8 +133,8 @@ async def export_scans_csv(
)
@router.get("/stats")
async def scan_stats(session: AsyncSession = Depends(get_session)):
@router.get("/stats", response_model=StatsResponse)
async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
dashboard = await get_dashboard_stats(session)
return {
"total_scans": dashboard["total_scans"],
@@ -147,13 +148,13 @@ async def scan_stats(session: AsyncSession = Depends(get_session)):
}
@router.get("/{scan_id}")
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)):
@router.get("/{scan_id}", response_model=ScanDetailOut)
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)) -> dict:
scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
)
if not scan:
return {"detail": "Not found"}
raise HTTPException(status_code=404, detail="Scan not found")
return {
"id": scan.id,
"package_name": scan.package_name,

View File

@@ -13,7 +13,7 @@ router = APIRouter(tags=["metrics"])
@router.get("/metrics")
async def metrics(session: AsyncSession = Depends(get_session)):
async def metrics(session: AsyncSession = Depends(get_session)) -> Response:
total = await session.scalar(select(func.count(Scan.id))) or 0
flagged = await session.scalar(select(func.count(Scan.id)).where(Scan.flagged == True)) or 0
findings_total = await session.scalar(select(func.count(Finding.id))) or 0

View File

@@ -1,13 +1,13 @@
"""Web UI routes — Jinja2 + htmx pages."""
import asyncio
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from jinja2 import Environment, PackageLoader, select_autoescape
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ..config import config
from ..constants import (
@@ -17,6 +17,7 @@ from ..constants import (
DEFAULT_SORT_DIR,
WEB_PER_PAGE,
)
from ..core.nexus import parse_package_path
from ..db.engine import get_session
from ..db.models import Finding, Scan
from ..db.queries import (
@@ -45,14 +46,20 @@ def _render(name: str, **context) -> HTMLResponse:
return HTMLResponse(template.render(**context), status_code=status_code)
def _parse_flagged(value: str) -> bool | None:
return True if value == "1" else None
@router.get("/", response_class=HTMLResponse)
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)):
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)) -> HTMLResponse:
ctx = await get_dashboard_stats(session)
return _render("dashboard.html", **ctx, request=request)
@router.get("/dashboard/stats", response_class=HTMLResponse)
async def dashboard_stats_fragment(request: Request, session: AsyncSession = Depends(get_session)):
async def dashboard_stats_fragment(
request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
ctx = await get_dashboard_stats(session)
return _render("dashboard_stats.html", request=request, **ctx)
@@ -67,13 +74,11 @@ async def scans_list(
sort_by: str = DEFAULT_SORT_BY_SCANS,
sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
per_page = WEB_PER_PAGE
offset = (page - 1) * per_page
flagged_bool = None
if flagged == "1":
flagged_bool = True
flagged_bool = _parse_flagged(flagged)
q, count_q = build_scan_list_query(
flagged=flagged_bool,
@@ -105,9 +110,9 @@ async def scans_list(
@router.get("/scans/{scan_id}", response_class=HTMLResponse)
async def scan_detail(scan_id: int, request: Request, session: AsyncSession = Depends(get_session)):
from sqlalchemy.orm import selectinload
async def scan_detail(
scan_id: int, request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
)
@@ -126,13 +131,11 @@ async def packages_list(
sort_by: str = DEFAULT_SORT_BY_PACKAGES,
sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
per_page = WEB_PER_PAGE
offset = (page - 1) * per_page
flagged_bool = None
if flagged == "1":
flagged_bool = True
flagged_bool = _parse_flagged(flagged)
rows_q, total_q = build_package_list_query(
flagged=flagged_bool,
@@ -166,14 +169,8 @@ async def package_detail(
name: str,
request: Request,
session: AsyncSession = Depends(get_session),
):
# name:path captures the entire path after /packages/
# e.g. "eviltest/0.1.0" or "github.com/attacker/evilmodule/v0.1.0"
parts = name.rsplit("/", 1)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
from sqlalchemy.orm import selectinload
) -> HTMLResponse:
pkg_name, pkg_version = parse_package_path(name)
scans = (
(
@@ -211,7 +208,7 @@ async def analyze_finding_htmx(
request: Request,
retry: bool = False,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
"""HTMX fragment: trigger LLM analysis and return styled result HTML."""
from ..config import config
from ..core.llm import analyze_finding

View File

@@ -67,7 +67,7 @@ async def nexus_webhook(
request: Request,
background_tasks: BackgroundTasks,
x_nexus_webhook_signature: str | None = Header(None, alias="X-Nexus-Webhook-Signature"),
):
) -> dict:
payload = await request.body()
if config.webhook_secret: