refactor: FastAPI best practices — return types, Pydantic schemas, middleware, code dedup

- Все 18 роутов получили return type annotations
- Создан schemas.py с Pydantic-моделями (ScanOut, PackageOut, FindingOut, ...)
- API-роуты: response_model на list/detail/export/stats
- 404 через HTTPException(404) вместо {'detail':'Not found'} (200)
- RequestLoggingMiddleware: method, path, status, duration_ms
- Глобальный exception handler: ловит необработанные исключения → 500
- _parse_flagged(): вынесен дублирующийся string→bool
- parse_package_path(): общий для web.py и api_packages.py
- selectinload: вынесены в top-level imports
- harvester: makedirs/mkdtemp/rmtree обёрнуты в asyncio.to_thread()
This commit is contained in:
Marker689
2026-05-10 12:53:33 +03:00
parent 935d96b35a
commit c1258dde99
11 changed files with 188 additions and 55 deletions

View File

@@ -95,8 +95,8 @@ async def harvest(
await session.refresh(scan) await session.refresh(scan)
try: try:
os.makedirs(config.temp_dir, exist_ok=True) await asyncio.to_thread(os.makedirs, config.temp_dir, exist_ok=True)
tmpdir = tempfile.mkdtemp(dir=config.temp_dir) tmpdir = await asyncio.to_thread(tempfile.mkdtemp, dir=config.temp_dir)
scan.status = ScanStatus.SCANNING.value scan.status = ScanStatus.SCANNING.value
await session.commit() await session.commit()
@@ -201,7 +201,7 @@ async def harvest(
return scan return scan
finally: finally:
shutil.rmtree(tmpdir, ignore_errors=True) await asyncio.to_thread(shutil.rmtree, tmpdir, ignore_errors=True)
async def _run_llm_analysis(findings: list[Finding], session: AsyncSession) -> list[dict]: async def _run_llm_analysis(findings: list[Finding], session: AsyncSession) -> list[dict]:

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
import hashlib import hashlib
import os import os
from urllib.parse import unquote
import httpx import httpx
@@ -89,6 +90,15 @@ def extract_package_info(asset_path: str, ecosystem: str) -> tuple[str, str] | N
return None return None
def parse_package_path(path: str) -> tuple[str, str]:
    """Split a package URL path into ``(package_name, package_version)``.

    The version is whatever follows the last ``/``; everything before it is
    the name, so multi-segment names are kept intact, e.g.
    ``'eviltest/0.1.0'`` or ``'github.com/attacker/evilmodule/v0.1.0'``.
    Both parts are percent-decoded *after* splitting, so an encoded ``%2F``
    inside a segment never acts as a separator. A path without any ``/`` is
    returned as the name with an empty version.
    """
    head, separator, tail = path.rpartition("/")
    if not separator:
        # No slash at all: the whole path is the package name.
        return unquote(path), ""
    return unquote(head), unquote(tail)
async def download_asset(download_url: str, dest_dir: str) -> str | None: async def download_asset(download_url: str, dest_dir: str) -> str | None:
"""Download an asset from Nexus using async httpx.""" """Download an asset from Nexus using async httpx."""
dest_path = os.path.join(dest_dir, os.path.basename(download_url.split("?")[0])) dest_path = os.path.join(dest_dir, os.path.basename(download_url.split("?")[0]))

View File

@@ -1,10 +1,12 @@
"""GuardDog Nexus — FastAPI application entry point.""" """GuardDog Nexus — FastAPI application entry point."""
import os import os
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import uvicorn import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@@ -56,6 +58,21 @@ async def lifespan(app: FastAPI):
log.info("%s shutting down", APP_NAME) log.info("%s shutting down", APP_NAME)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and duration (ms) of every HTTP request."""

    async def dispatch(self, request: Request, call_next):
        """Time the downstream handling of ``request`` and log the outcome.

        Returns the downstream response unchanged. If the downstream app
        raises, the request is logged with status 500 and the exception is
        re-raised untouched.
        """
        start = time.monotonic()
        try:
            response = await call_next(request)
        except Exception:
            # Unhandled errors pass through user middleware before Starlette's
            # outermost ServerErrorMiddleware converts them into a 500, so
            # without this branch failed requests would never be logged here.
            self._log_request(request, 500, start)
            raise
        self._log_request(request, response.status_code, start)
        return response

    @staticmethod
    def _log_request(request: Request, status_code: int, start: float) -> None:
        """Emit one access-log line; ``start`` is a ``time.monotonic()`` reading."""
        duration = (time.monotonic() - start) * 1000
        log.info(
            "%s %s %s %.1fms",
            request.method,
            request.url.path,
            status_code,
            duration,
        )
app = FastAPI( app = FastAPI(
title=APP_NAME, title=APP_NAME,
version=APP_VERSION, version=APP_VERSION,
@@ -63,6 +80,14 @@ app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
) )
app.add_middleware(LangMiddleware) app.add_middleware(LangMiddleware)
app.add_middleware(RequestLoggingMiddleware)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Turn any otherwise-unhandled exception into a generic JSON 500 response.

    The traceback is logged server-side; the client only sees a fixed
    ``{"detail": "Internal server error"}`` body, so no internals leak.
    """
    # Pass the exception explicitly instead of relying on log.exception()
    # picking it up from sys.exc_info(), which only works while the caller
    # is still inside its ``except`` block.
    log.error(
        "Unhandled exception on %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return JSONResponse(status_code=500, content={"detail": "Internal server error"})
app.include_router(webhook_router) app.include_router(webhook_router)
app.include_router(metrics_router) app.include_router(metrics_router)
@@ -76,7 +101,7 @@ if os.path.isdir(STATIC_DIR):
@app.get("/health") @app.get("/health")
async def health(): async def health() -> dict:
return {"status": "ok", "version": APP_VERSION} return {"status": "ok", "version": APP_VERSION}

View File

@@ -13,11 +13,12 @@ from ..constants import (
) )
from ..db.engine import get_session from ..db.engine import get_session
from ..db.models import Finding from ..db.models import Finding
from ..schemas import FindingsListResponse
router = APIRouter(prefix="/api/v1/findings", tags=["findings"]) router = APIRouter(prefix="/api/v1/findings", tags=["findings"])
@router.get("") @router.get("", response_model=FindingsListResponse)
async def list_findings( async def list_findings(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE), limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0), offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -25,7 +26,7 @@ async def list_findings(
severity: str | None = Query(None), severity: str | None = Query(None),
scan_id: int | None = Query(None), scan_id: int | None = Query(None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> dict:
q = select(Finding) q = select(Finding)
if rule: if rule:
q = q.where(func.json_extract(Finding.data, JSON_PATH_RULE) == rule) q = q.where(func.json_extract(Finding.data, JSON_PATH_RULE) == rule)

View File

@@ -2,9 +2,8 @@
import csv import csv
import io import io
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Query, Response from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@@ -17,14 +16,16 @@ from ..constants import (
DEFAULT_SORT_DIR, DEFAULT_SORT_DIR,
MAX_PAGE_SIZE, MAX_PAGE_SIZE,
) )
from ..core.nexus import parse_package_path
from ..db.engine import get_session from ..db.engine import get_session
from ..db.models import Scan from ..db.models import Scan
from ..db.queries import build_package_list_query from ..db.queries import build_package_list_query
from ..schemas import PackageDetailOut, PackageListResponse
router = APIRouter(prefix="/api/v1/packages", tags=["packages"]) router = APIRouter(prefix="/api/v1/packages", tags=["packages"])
@router.get("") @router.get("", response_model=PackageListResponse)
async def list_packages( async def list_packages(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE), limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0), offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -35,7 +36,7 @@ async def list_packages(
sort_by: str = Query(DEFAULT_SORT_BY_PACKAGES), sort_by: str = Query(DEFAULT_SORT_BY_PACKAGES),
sort_dir: str = Query(DEFAULT_SORT_DIR), sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> dict:
rows_q, total_q = build_package_list_query( rows_q, total_q = build_package_list_query(
flagged=flagged, flagged=flagged,
ecosystem=ecosystem, ecosystem=ecosystem,
@@ -74,7 +75,7 @@ async def export_packages_csv(
flagged: bool | None = Query(None), flagged: bool | None = Query(None),
search: str | None = Query(None), search: str | None = Query(None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> Response:
rows_q, _total_q = build_package_list_query( rows_q, _total_q = build_package_list_query(
flagged=flagged, flagged=flagged,
search=search, search=search,
@@ -118,14 +119,12 @@ async def export_packages_csv(
) )
@router.get("/{name:path}") @router.get("/{name:path}", response_model=PackageDetailOut)
async def get_package( async def get_package(
name: str, name: str,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> dict:
parts = name.rsplit("/", 1) pkg_name, pkg_version = parse_package_path(name)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
scans = ( scans = (
( (
@@ -141,7 +140,7 @@ async def get_package(
) )
if not scans: if not scans:
return {"detail": "Not found"} raise HTTPException(status_code=404, detail="Package not found")
all_findings: list[dict] = [] all_findings: list[dict] = []
for s in scans: for s in scans:

View File

@@ -3,7 +3,7 @@
import csv import csv
import io import io
from fastapi import APIRouter, Depends, Query, Response from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@@ -19,11 +19,12 @@ from ..constants import (
from ..db.engine import get_session from ..db.engine import get_session
from ..db.models import Scan from ..db.models import Scan
from ..db.queries import build_scan_list_query, get_dashboard_stats from ..db.queries import build_scan_list_query, get_dashboard_stats
from ..schemas import ScanDetailOut, ScanListResponse, StatsResponse
router = APIRouter(prefix="/api/v1/scans", tags=["scans"]) router = APIRouter(prefix="/api/v1/scans", tags=["scans"])
@router.get("") @router.get("", response_model=ScanListResponse)
async def list_scans( async def list_scans(
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE), limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
offset: int = Query(DEFAULT_OFFSET, ge=0), offset: int = Query(DEFAULT_OFFSET, ge=0),
@@ -34,7 +35,7 @@ async def list_scans(
sort_by: str = Query(DEFAULT_SORT_BY_SCANS), sort_by: str = Query(DEFAULT_SORT_BY_SCANS),
sort_dir: str = Query(DEFAULT_SORT_DIR), sort_dir: str = Query(DEFAULT_SORT_DIR),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> dict:
q, count_q = build_scan_list_query( q, count_q = build_scan_list_query(
flagged=flagged, flagged=flagged,
status=status, status=status,
@@ -77,7 +78,7 @@ async def export_scans_csv(
search: str | None = Query(None), search: str | None = Query(None),
status: str | None = Query(None), status: str | None = Query(None),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> Response:
q, _count_q = build_scan_list_query( q, _count_q = build_scan_list_query(
flagged=flagged, flagged=flagged,
status=status, status=status,
@@ -132,8 +133,8 @@ async def export_scans_csv(
) )
@router.get("/stats") @router.get("/stats", response_model=StatsResponse)
async def scan_stats(session: AsyncSession = Depends(get_session)): async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
dashboard = await get_dashboard_stats(session) dashboard = await get_dashboard_stats(session)
return { return {
"total_scans": dashboard["total_scans"], "total_scans": dashboard["total_scans"],
@@ -147,13 +148,13 @@ async def scan_stats(session: AsyncSession = Depends(get_session)):
} }
@router.get("/{scan_id}") @router.get("/{scan_id}", response_model=ScanDetailOut)
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)): async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)) -> dict:
scan = await session.scalar( scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings)) select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
) )
if not scan: if not scan:
return {"detail": "Not found"} raise HTTPException(status_code=404, detail="Scan not found")
return { return {
"id": scan.id, "id": scan.id,
"package_name": scan.package_name, "package_name": scan.package_name,

View File

@@ -13,7 +13,7 @@ router = APIRouter(tags=["metrics"])
@router.get("/metrics") @router.get("/metrics")
async def metrics(session: AsyncSession = Depends(get_session)): async def metrics(session: AsyncSession = Depends(get_session)) -> Response:
total = await session.scalar(select(func.count(Scan.id))) or 0 total = await session.scalar(select(func.count(Scan.id))) or 0
flagged = await session.scalar(select(func.count(Scan.id)).where(Scan.flagged == True)) or 0 flagged = await session.scalar(select(func.count(Scan.id)).where(Scan.flagged == True)) or 0
findings_total = await session.scalar(select(func.count(Finding.id))) or 0 findings_total = await session.scalar(select(func.count(Finding.id))) or 0

View File

@@ -1,13 +1,13 @@
"""Web UI routes — Jinja2 + htmx pages.""" """Web UI routes — Jinja2 + htmx pages."""
import asyncio import asyncio
from urllib.parse import unquote
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from jinja2 import Environment, PackageLoader, select_autoescape from jinja2 import Environment, PackageLoader, select_autoescape
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ..config import config from ..config import config
from ..constants import ( from ..constants import (
@@ -17,6 +17,7 @@ from ..constants import (
DEFAULT_SORT_DIR, DEFAULT_SORT_DIR,
WEB_PER_PAGE, WEB_PER_PAGE,
) )
from ..core.nexus import parse_package_path
from ..db.engine import get_session from ..db.engine import get_session
from ..db.models import Finding, Scan from ..db.models import Finding, Scan
from ..db.queries import ( from ..db.queries import (
@@ -45,14 +46,20 @@ def _render(name: str, **context) -> HTMLResponse:
return HTMLResponse(template.render(**context), status_code=status_code) return HTMLResponse(template.render(**context), status_code=status_code)
def _parse_flagged(value: str) -> bool | None:
return True if value == "1" else None
@router.get("/", response_class=HTMLResponse) @router.get("/", response_class=HTMLResponse)
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)): async def dashboard(request: Request, session: AsyncSession = Depends(get_session)) -> HTMLResponse:
ctx = await get_dashboard_stats(session) ctx = await get_dashboard_stats(session)
return _render("dashboard.html", **ctx, request=request) return _render("dashboard.html", **ctx, request=request)
@router.get("/dashboard/stats", response_class=HTMLResponse) @router.get("/dashboard/stats", response_class=HTMLResponse)
async def dashboard_stats_fragment(request: Request, session: AsyncSession = Depends(get_session)): async def dashboard_stats_fragment(
request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
ctx = await get_dashboard_stats(session) ctx = await get_dashboard_stats(session)
return _render("dashboard_stats.html", request=request, **ctx) return _render("dashboard_stats.html", request=request, **ctx)
@@ -67,13 +74,11 @@ async def scans_list(
sort_by: str = DEFAULT_SORT_BY_SCANS, sort_by: str = DEFAULT_SORT_BY_SCANS,
sort_dir: str = DEFAULT_SORT_DIR, sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> HTMLResponse:
per_page = WEB_PER_PAGE per_page = WEB_PER_PAGE
offset = (page - 1) * per_page offset = (page - 1) * per_page
flagged_bool = None flagged_bool = _parse_flagged(flagged)
if flagged == "1":
flagged_bool = True
q, count_q = build_scan_list_query( q, count_q = build_scan_list_query(
flagged=flagged_bool, flagged=flagged_bool,
@@ -105,9 +110,9 @@ async def scans_list(
@router.get("/scans/{scan_id}", response_class=HTMLResponse) @router.get("/scans/{scan_id}", response_class=HTMLResponse)
async def scan_detail(scan_id: int, request: Request, session: AsyncSession = Depends(get_session)): async def scan_detail(
from sqlalchemy.orm import selectinload scan_id: int, request: Request, session: AsyncSession = Depends(get_session)
) -> HTMLResponse:
scan = await session.scalar( scan = await session.scalar(
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings)) select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
) )
@@ -126,13 +131,11 @@ async def packages_list(
sort_by: str = DEFAULT_SORT_BY_PACKAGES, sort_by: str = DEFAULT_SORT_BY_PACKAGES,
sort_dir: str = DEFAULT_SORT_DIR, sort_dir: str = DEFAULT_SORT_DIR,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> HTMLResponse:
per_page = WEB_PER_PAGE per_page = WEB_PER_PAGE
offset = (page - 1) * per_page offset = (page - 1) * per_page
flagged_bool = None flagged_bool = _parse_flagged(flagged)
if flagged == "1":
flagged_bool = True
rows_q, total_q = build_package_list_query( rows_q, total_q = build_package_list_query(
flagged=flagged_bool, flagged=flagged_bool,
@@ -166,14 +169,8 @@ async def package_detail(
name: str, name: str,
request: Request, request: Request,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> HTMLResponse:
# name:path captures the entire path after /packages/ pkg_name, pkg_version = parse_package_path(name)
# e.g. "eviltest/0.1.0" or "github.com/attacker/evilmodule/v0.1.0"
parts = name.rsplit("/", 1)
pkg_name = unquote(parts[0])
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
from sqlalchemy.orm import selectinload
scans = ( scans = (
( (
@@ -211,7 +208,7 @@ async def analyze_finding_htmx(
request: Request, request: Request,
retry: bool = False, retry: bool = False,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ) -> HTMLResponse:
"""HTMX fragment: trigger LLM analysis and return styled result HTML.""" """HTMX fragment: trigger LLM analysis and return styled result HTML."""
from ..config import config from ..config import config
from ..core.llm import analyze_finding from ..core.llm import analyze_finding

View File

@@ -67,7 +67,7 @@ async def nexus_webhook(
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
x_nexus_webhook_signature: str | None = Header(None, alias="X-Nexus-Webhook-Signature"), x_nexus_webhook_signature: str | None = Header(None, alias="X-Nexus-Webhook-Signature"),
): ) -> dict:
payload = await request.body() payload = await request.body()
if config.webhook_secret: if config.webhook_secret:

102
guarddog_nexus/schemas.py Normal file
View File

@@ -0,0 +1,102 @@
"""Pydantic schemas for API request/response models."""
from datetime import datetime
from pydantic import BaseModel
class ScanOut(BaseModel):
    """Summary of a single scan as exposed by the API.

    Element type of ``ScanListResponse.scans`` and base class of
    ``ScanDetailOut``.
    """

    id: int
    package_name: str
    package_version: str
    ecosystem: str
    repository: str
    status: str
    total_findings: int
    flagged: bool
    # Optional fields: omitted values default to None.
    started_at: datetime | None = None
    finished_at: datetime | None = None
    error_message: str | None = None

    # Lets Pydantic populate the model from an object's attributes,
    # in addition to plain dicts.
    model_config = {"from_attributes": True}
class ScanListResponse(BaseModel):
    """Paginated envelope for a list of scans: paging fields plus the page of items."""

    total: int
    limit: int
    offset: int
    scans: list[ScanOut]
class ScanDetailOut(ScanOut):
    """Detailed view of one scan: every ``ScanOut`` field plus provenance and findings."""

    nexus_asset_url: str | None = None
    sha256: str | None = None
    initiator: str | None = None
    source_ip: str | None = None
    # Findings are untyped dicts here; their keys are not constrained by this schema.
    # NOTE(review): Pydantic copies field defaults per instance, so the literal
    # [] should not be shared between models — confirm for the pinned version.
    findings: list[dict] = []
class FindingOut(BaseModel):
    """A single finding as exposed by the API.

    Element type of ``FindingsListResponse.findings``.
    """

    id: int
    scan_id: int
    # Textual fields default to the empty string rather than None.
    rule: str = ""
    severity: str = ""
    message: str = ""
    location: str = ""
    code: str = ""
    report: dict | None = None
    created_at: datetime | None = None

    # Lets Pydantic populate the model from an object's attributes,
    # in addition to plain dicts.
    model_config = {"from_attributes": True}
class FindingsListResponse(BaseModel):
    """Paginated envelope for a list of findings: paging fields plus the page of items."""

    total: int
    limit: int
    offset: int
    findings: list[FindingOut]
class PackageOut(BaseModel):
    """One package row in a package listing.

    Element type of ``PackageListResponse.packages``.
    """

    name: str
    version: str
    ecosystem: str
    repository: str
    last_scanned_at: datetime | None = None
    flagged: bool
    total_findings: int
    # presumably the id of the most recent scan of this package — verify against the query
    latest_scan_id: int
class PackageListResponse(BaseModel):
    """Paginated envelope for a list of packages: paging fields plus the page of items."""

    total: int
    limit: int
    offset: int
    packages: list[PackageOut]
class PackageScanOut(BaseModel):
    """Compact scan entry embedded in a package detail.

    Element type of ``PackageDetailOut.scans``.
    """

    id: int
    status: str
    total_findings: int
    flagged: bool
    started_at: datetime | None = None
class PackageDetailOut(BaseModel):
    """Detailed view of one package: identity, flag state, its scans and findings."""

    name: str
    version: str
    ecosystem: str
    repository: str
    flagged: bool
    scans: list[PackageScanOut]
    # Findings are untyped dicts here; their keys are not constrained by this schema.
    findings: list[dict]
class StatsResponse(BaseModel):
    """Aggregate scan statistics returned by the stats endpoint."""

    total_scans: int
    flagged_scans: int
    recent_flagged: int
    total_findings: int
    # Entries are untyped dicts; their keys are not constrained by this schema.
    top_rules: list[dict]
    latest_scan_at: datetime | None = None

View File

@@ -34,8 +34,7 @@ async def test_scan_stats_empty(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scan_not_found(client): async def test_scan_not_found(client):
resp = await client.get("/api/v1/scans/99999") resp = await client.get("/api/v1/scans/99999")
assert resp.status_code == 200 assert resp.status_code == 404
assert "detail" in resp.json()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -129,8 +128,7 @@ async def test_package_with_data(client, sample_flagged_scan):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_package_not_found(client): async def test_package_not_found(client):
resp = await client.get("/api/v1/packages/nonexistent/1.0") resp = await client.get("/api/v1/packages/nonexistent/1.0")
assert resp.status_code == 200 assert resp.status_code == 404
assert "detail" in resp.json()
# --- Findings --- # --- Findings ---