refactor: FastAPI best practices — return types, Pydantic schemas, middleware, code dedup

- Все 18 роутов получили return type annotations
- Создан schemas.py с Pydantic-моделями (ScanOut, PackageOut, FindingOut, ...)
- API-роуты: response_model на list/detail/export/stats
- 404 через HTTPException(404) вместо {'detail':'Not found'} (200)
- RequestLoggingMiddleware: method, path, status, duration_ms
- Глобальный exception handler: ловит необработанные исключения → 500
- _parse_flagged(): вынесено дублирующееся преобразование string→bool
- parse_package_path(): общий для web.py и api_packages.py
- selectinload: импорт вынесен в top-level imports
- harvester: makedirs/mkdtemp/rmtree обёрнуты в asyncio.to_thread()
This commit is contained in:
Marker689
2026-05-10 12:53:33 +03:00
parent 935d96b35a
commit c1258dde99
11 changed files with 188 additions and 55 deletions

View File

@@ -95,8 +95,8 @@ async def harvest(
await session.refresh(scan)
try:
os.makedirs(config.temp_dir, exist_ok=True)
tmpdir = tempfile.mkdtemp(dir=config.temp_dir)
await asyncio.to_thread(os.makedirs, config.temp_dir, exist_ok=True)
tmpdir = await asyncio.to_thread(tempfile.mkdtemp, dir=config.temp_dir)
scan.status = ScanStatus.SCANNING.value
await session.commit()
@@ -201,7 +201,7 @@ async def harvest(
return scan
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
await asyncio.to_thread(shutil.rmtree, tmpdir, ignore_errors=True)
async def _run_llm_analysis(findings: list[Finding], session: AsyncSession) -> list[dict]:

View File

@@ -3,6 +3,7 @@
import asyncio
import hashlib
import os
from urllib.parse import unquote
import httpx
@@ -89,6 +90,15 @@ def extract_package_info(asset_path: str, ecosystem: str) -> tuple[str, str] | N
return None
def parse_package_path(path: str) -> tuple[str, str]:
    """Split a package URL path into its name and version components.

    The text after the last ``/`` is the version and everything before it is
    the name, so multi-segment names such as
    ``github.com/attacker/evilmodule/v0.1.0`` keep their inner slashes. Both
    parts are percent-decoded after the split. A path without any ``/`` is
    treated as a bare name with an empty version.

    Returns:
        ``(package_name, package_version)``.
    """
    head, slash, tail = path.rpartition("/")
    if not slash:
        # rpartition leaves the whole input in ``tail`` when no separator exists.
        return unquote(tail), ""
    return unquote(head), unquote(tail)
async def download_asset(download_url: str, dest_dir: str) -> str | None:
"""Download an asset from Nexus using async httpx."""
dest_path = os.path.join(dest_dir, os.path.basename(download_url.split("?")[0]))

View File

@@ -1,10 +1,12 @@
"""GuardDog Nexus — FastAPI application entry point."""
import os
import time
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
@@ -56,6 +58,21 @@ async def lifespan(app: FastAPI):
log.info("%s shutting down", APP_NAME)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log one line per HTTP request: method, path, status code, duration.

    Requests whose handling raises are logged too (with status 500) before
    the exception is re-raised, so failed requests still leave a timing line.
    """

    async def dispatch(self, request: Request, call_next):
        """Time the downstream handling of ``request`` and log the outcome.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request on and returns the
                response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Exception: Whatever ``call_next`` raised, re-raised unchanged after
                the request has been logged.
        """
        start = time.monotonic()
        try:
            response = await call_next(request)
        except Exception:
            # An exception escaping call_next is converted into a 500 response
            # by error handling that runs outside this middleware, so the
            # request has to be recorded here or it is never logged at all.
            self._log_request(request, 500, start)
            raise
        self._log_request(request, response.status_code, start)
        return response

    @staticmethod
    def _log_request(request: Request, status_code: int, start: float) -> None:
        """Emit the log line for a request that began at ``start`` (monotonic)."""
        duration = (time.monotonic() - start) * 1000
        log.info(
            "%s %s %s %.1fms",
            request.method,
            request.url.path,
            status_code,
            duration,
        )
app = FastAPI(
title=APP_NAME,
version=APP_VERSION,
@@ -63,6 +80,14 @@ app = FastAPI(
lifespan=lifespan,
)
app.add_middleware(LangMiddleware)
app.add_middleware(RequestLoggingMiddleware)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler for exceptions registered under ``Exception``.

    Records the request method and path via ``log.exception`` and answers with
    a generic JSON 500 body, so no exception details are sent to the client.
    """
    method, path = request.method, request.url.path
    log.exception("Unhandled exception on %s %s", method, path)
    error_body = {"detail": "Internal server error"}
    return JSONResponse(status_code=500, content=error_body)
app.include_router(webhook_router)
app.include_router(metrics_router)
@@ -76,7 +101,7 @@ if os.path.isdir(STATIC_DIR):
@app.get("/health")
async def health():
async def health() -> dict:
return {"status": "ok", "version": APP_VERSION}

View File

@@ -13,11 +13,12 @@ from ..constants import (
)
from ..db.engine import get_session
from ..db.models import Finding
from ..schemas import FindingsListResponse
router = APIRouter(prefix="/api/v1/findings", tags=["findings"])
@router.get("")
@router.get("", response_model=FindingsListResponse)
async def list_findings(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -25,7 +26,7 @@ async def list_findings(
severity: str | None = Query(None),
scan_id: int | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> dict:
q = select(Finding)
if rule:
q = q.where(func.json_extract(Finding.data, JSON_PATH_RULE) == rule)

View File

@@ -2,9 +2,8 @@
import csv
import io
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -17,14 +16,16 @@ from ..constants import (
DEFAULT_SORT_DIR,
MAX_PAGE_SIZE,
)
from ..core.nexus import parse_package_path
from ..db.engine import get_session
from ..db.models import Scan
from ..db.queries import build_package_list_query
from ..schemas import PackageDetailOut, PackageListResponse
router = APIRouter(prefix="/api/v1/packages", tags=["packages"])
@router.get("")
@router.get("", response_model=PackageListResponse)
async def list_packages(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -35,7 +36,7 @@ async def list_packages(
sort_by: str = Query(DEFAULT_SORT_BY_PACKAGES),
sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session),
):
) -> dict:
rows_q, total_q = build_package_list_query(
flagged=flagged,
ecosystem=ecosystem,
@@ -74,7 +75,7 @@ async def export_packages_csv(
flagged: bool | None = Query(None),
search: str | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> Response:
rows_q, _total_q = build_package_list_query(
flagged=flagged,
search=search,
@@ -118,14 +119,12 @@ async def export_packages_csv(
)
@router.get("/{name:path}")
@router.get("/{name:path}", response_model=PackageDetailOut)
async def get_package(
name: str,
session: AsyncSession = Depends(get_session),
):
parts = name.rsplit("/", 1)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
) -> dict:
pkg_name, pkg_version = parse_package_path(name)
scans = (
(
@@ -141,7 +140,7 @@ async def get_package(
)
if not scans:
return {"detail": "Not found"}
raise HTTPException(status_code=404, detail="Package not found")
all_findings: list[dict] = []
for s in scans:

View File

@@ -3,7 +3,7 @@
import csv
import io
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -19,11 +19,12 @@ from ..constants import (
from ..db.engine import get_session
from ..db.models import Scan
from ..db.queries import build_scan_list_query, get_dashboard_stats
from ..schemas import ScanDetailOut, ScanListResponse, StatsResponse
router = APIRouter(prefix="/api/v1/scans", tags=["scans"])
@router.get("")
@router.get("", response_model=ScanListResponse)
async def list_scans(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -34,7 +35,7 @@ async def list_scans(
sort_by: str = Query(DEFAULT_SORT_BY_SCANS),
sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session),
):
) -> dict:
q, count_q = build_scan_list_query(
flagged=flagged,
status=status,
@@ -77,7 +78,7 @@ async def export_scans_csv(
search: str | None = Query(None),
status: str | None = Query(None),
session: AsyncSession = Depends(get_session),
):
) -> Response:
q, _count_q = build_scan_list_query(
flagged=flagged,
status=status,
@@ -132,8 +133,8 @@ async def export_scans_csv(
)
@router.get("/stats")
async def scan_stats(session: AsyncSession = Depends(get_session)):
@router.get("/stats", response_model=StatsResponse)
async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
dashboard = await get_dashboard_stats(session)
return {
"total_scans": dashboard["total_scans"],
@@ -147,13 +148,13 @@ async def scan_stats(session: AsyncSession = Depends(get_session)):
}
@router.get("/{scan_id}")
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)):
@router.get("/{scan_id}", response_model=ScanDetailOut)
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)) -> dict:
scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
)
if not scan:
return {"detail": "Not found"}
raise HTTPException(status_code=404, detail="Scan not found")
return {
"id": scan.id,
"package_name": scan.package_name,

View File

@@ -13,7 +13,7 @@ router = APIRouter(tags=["metrics"])
@router.get("/metrics")
async def metrics(session: AsyncSession = Depends(get_session)):
async def metrics(session: AsyncSession = Depends(get_session)) -> Response:
total = await session.scalar(select(func.count(Scan.id))) or 0
flagged = await session.scalar(select(func.count(Scan.id)).where(Scan.flagged == True)) or 0
findings_total = await session.scalar(select(func.count(Finding.id))) or 0

View File

@@ -1,13 +1,13 @@
"""Web UI routes — Jinja2 + htmx pages."""
import asyncio
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from jinja2 import Environment, PackageLoader, select_autoescape
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ..config import config
from ..constants import (
@@ -17,6 +17,7 @@ from ..constants import (
DEFAULT_SORT_DIR,
WEB_PER_PAGE,
)
from ..core.nexus import parse_package_path
from ..db.engine import get_session
from ..db.models import Finding, Scan
from ..db.queries import (
@@ -45,14 +46,20 @@ def _render(name: str, **context) -> HTMLResponse:
return HTMLResponse(template.render(**context), status_code=status_code)
def _parse_flagged(value: str) -> bool | None:
return True if value == "1" else None
@router.get("/", response_class=HTMLResponse)
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)):
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)) -> HTMLResponse:
ctx = await get_dashboard_stats(session)
return _render("dashboard.html", **ctx, request=request)
@router.get("/dashboard/stats", response_class=HTMLResponse)
async def dashboard_stats_fragment(request: Request, session: AsyncSession = Depends(get_session)):
async def dashboard_stats_fragment(
request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
ctx = await get_dashboard_stats(session)
return _render("dashboard_stats.html", request=request, **ctx)
@@ -67,13 +74,11 @@ async def scans_list(
sort_by: str = DEFAULT_SORT_BY_SCANS,
sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
per_page = WEB_PER_PAGE
offset = (page - 1) * per_page
flagged_bool = None
if flagged == "1":
flagged_bool = True
flagged_bool = _parse_flagged(flagged)
q, count_q = build_scan_list_query(
flagged=flagged_bool,
@@ -105,9 +110,9 @@ async def scans_list(
@router.get("/scans/{scan_id}", response_class=HTMLResponse)
async def scan_detail(scan_id: int, request: Request, session: AsyncSession = Depends(get_session)):
from sqlalchemy.orm import selectinload
async def scan_detail(
scan_id: int, request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
)
@@ -126,13 +131,11 @@ async def packages_list(
sort_by: str = DEFAULT_SORT_BY_PACKAGES,
sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
per_page = WEB_PER_PAGE
offset = (page - 1) * per_page
flagged_bool = None
if flagged == "1":
flagged_bool = True
flagged_bool = _parse_flagged(flagged)
rows_q, total_q = build_package_list_query(
flagged=flagged_bool,
@@ -166,14 +169,8 @@ async def package_detail(
name: str,
request: Request,
session: AsyncSession = Depends(get_session),
):
# name:path captures the entire path after /packages/
# e.g. "eviltest/0.1.0" or "github.com/attacker/evilmodule/v0.1.0"
parts = name.rsplit("/", 1)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
from sqlalchemy.orm import selectinload
) -> HTMLResponse:
pkg_name, pkg_version = parse_package_path(name)
scans = (
(
@@ -211,7 +208,7 @@ async def analyze_finding_htmx(
request: Request,
retry: bool = False,
session: AsyncSession = Depends(get_session),
):
) -> HTMLResponse:
"""HTMX fragment: trigger LLM analysis and return styled result HTML."""
from ..config import config
from ..core.llm import analyze_finding

View File

@@ -67,7 +67,7 @@ async def nexus_webhook(
request: Request,
background_tasks: BackgroundTasks,
x_nexus_webhook_signature: str | None = Header(None, alias="X-Nexus-Webhook-Signature"),
):
) -> dict:
payload = await request.body()
if config.webhook_secret:

102
guarddog_nexus/schemas.py Normal file
View File

@@ -0,0 +1,102 @@
"""Pydantic schemas for API request/response models."""
from datetime import datetime
from pydantic import BaseModel
class ScanOut(BaseModel):
    """Summary view of a single scan; the item type of ``ScanListResponse.scans``."""

    id: int
    package_name: str
    package_version: str
    ecosystem: str
    repository: str
    status: str
    total_findings: int
    flagged: bool
    # The three fields below are optional and default to None when absent.
    started_at: datetime | None = None
    finished_at: datetime | None = None
    error_message: str | None = None
    # Lets pydantic read field values from object attributes, not only mappings.
    model_config = {"from_attributes": True}
class ScanListResponse(BaseModel):
    """Response body of ``GET /api/v1/scans``: one page of scans plus paging data."""

    # NOTE(review): presumably the number of matching scans before paging —
    # confirm against list_scans.
    total: int
    limit: int
    offset: int
    scans: list[ScanOut]
class ScanDetailOut(ScanOut):
    """Response body of ``GET /api/v1/scans/{scan_id}``.

    Extends ``ScanOut`` with optional per-scan metadata and the scan's findings.
    """

    nexus_asset_url: str | None = None
    sha256: str | None = None
    initiator: str | None = None
    source_ip: str | None = None
    # Findings are passed through as untyped dicts; empty list when omitted.
    findings: list[dict] = []
class FindingOut(BaseModel):
    """A single finding; the item type of ``FindingsListResponse.findings``."""

    id: int
    scan_id: int
    # The text fields default to an empty string when not supplied.
    rule: str = ""
    severity: str = ""
    message: str = ""
    location: str = ""
    code: str = ""
    report: dict | None = None
    created_at: datetime | None = None
    # Lets pydantic read field values from object attributes, not only mappings.
    model_config = {"from_attributes": True}
class FindingsListResponse(BaseModel):
    """Response body of ``GET /api/v1/findings``: one page of findings plus paging data."""

    # NOTE(review): presumably the number of matching findings before paging —
    # confirm against list_findings.
    total: int
    limit: int
    offset: int
    findings: list[FindingOut]
class PackageOut(BaseModel):
    """Summary view of a package; the item type of ``PackageListResponse.packages``."""

    name: str
    version: str
    ecosystem: str
    repository: str
    # Optional; None when no timestamp is supplied.
    last_scanned_at: datetime | None = None
    flagged: bool
    total_findings: int
    latest_scan_id: int
class PackageListResponse(BaseModel):
    """Response body of ``GET /api/v1/packages``: one page of packages plus paging data."""

    # NOTE(review): presumably the number of matching packages before paging —
    # confirm against list_packages.
    total: int
    limit: int
    offset: int
    packages: list[PackageOut]
class PackageScanOut(BaseModel):
    """Condensed scan entry; the item type of ``PackageDetailOut.scans``."""

    id: int
    status: str
    total_findings: int
    flagged: bool
    started_at: datetime | None = None
class PackageDetailOut(BaseModel):
    """Response body of ``GET /api/v1/packages/{name:path}``."""

    name: str
    version: str
    ecosystem: str
    repository: str
    flagged: bool
    scans: list[PackageScanOut]
    # Untyped finding dicts. NOTE(review): presumably collected across the
    # scans listed above — confirm in get_package.
    findings: list[dict]
class StatsResponse(BaseModel):
    """Response body of ``GET /api/v1/scans/stats``."""

    total_scans: int
    flagged_scans: int
    recent_flagged: int
    total_findings: int
    # Untyped dicts; their keys are not visible here — see get_dashboard_stats.
    top_rules: list[dict]
    latest_scan_at: datetime | None = None

View File

@@ -34,8 +34,7 @@ async def test_scan_stats_empty(client):
@pytest.mark.asyncio
async def test_scan_not_found(client):
resp = await client.get("/api/v1/scans/99999")
assert resp.status_code == 200
assert "detail" in resp.json()
assert resp.status_code == 404
@pytest.mark.asyncio
@@ -129,8 +128,7 @@ async def test_package_with_data(client, sample_flagged_scan):
@pytest.mark.asyncio
async def test_package_not_found(client):
resp = await client.get("/api/v1/packages/nonexistent/1.0")
assert resp.status_code == 200
assert "detail" in resp.json()
assert resp.status_code == 404
# --- Findings ---