refactor: FastAPI best practices — return types, Pydantic schemas, middleware, code dedup
- Все 18 роутов получили return type annotations
- Создан schemas.py с Pydantic-моделями (ScanOut, PackageOut, FindingOut, ...)
- API-роуты: response_model на list/detail/export/stats
- 404 через HTTPException(404) вместо {'detail':'Not found'} (200)
- RequestLoggingMiddleware: method, path, status, duration_ms
- Глобальный exception handler: ловит необработанные исключения → 500
- _parse_flagged(): вынесен дублирующийся string→bool
- parse_package_path(): общий для web.py и api_packages.py
- selectinload: вынесены в top-level imports
- harvester: makedirs/mkdtemp/rmtree обёрнуты в asyncio.to_thread()
This commit is contained in:
@@ -95,8 +95,8 @@ async def harvest(
|
||||
await session.refresh(scan)
|
||||
|
||||
try:
|
||||
os.makedirs(config.temp_dir, exist_ok=True)
|
||||
tmpdir = tempfile.mkdtemp(dir=config.temp_dir)
|
||||
await asyncio.to_thread(os.makedirs, config.temp_dir, exist_ok=True)
|
||||
tmpdir = await asyncio.to_thread(tempfile.mkdtemp, dir=config.temp_dir)
|
||||
scan.status = ScanStatus.SCANNING.value
|
||||
await session.commit()
|
||||
|
||||
@@ -201,7 +201,7 @@ async def harvest(
|
||||
return scan
|
||||
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
await asyncio.to_thread(shutil.rmtree, tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
async def _run_llm_analysis(findings: list[Finding], session: AsyncSession) -> list[dict]:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -89,6 +90,15 @@ def extract_package_info(asset_path: str, ecosystem: str) -> tuple[str, str] | N
|
||||
return None
|
||||
|
||||
|
||||
def parse_package_path(path: str) -> tuple[str, str]:
|
||||
"""Parse a URL path like 'eviltest/0.1.0' or 'github.com/attacker/evilmodule/v0.1.0'
|
||||
into (package_name, package_version)."""
|
||||
parts = path.rsplit("/", 1)
|
||||
pkg_name = unquote(parts[0])
|
||||
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
|
||||
return pkg_name, pkg_version
|
||||
|
||||
|
||||
async def download_asset(download_url: str, dest_dir: str) -> str | None:
|
||||
"""Download an asset from Nexus using async httpx."""
|
||||
dest_path = os.path.join(dest_dir, os.path.basename(download_url.split("?")[0]))
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""GuardDog Nexus — FastAPI application entry point."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
@@ -56,6 +58,21 @@ async def lifespan(app: FastAPI):
|
||||
log.info("%s shutting down", APP_NAME)
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
start = time.monotonic()
|
||||
response = await call_next(request)
|
||||
duration = (time.monotonic() - start) * 1000
|
||||
log.info(
|
||||
"%s %s %s %.1fms",
|
||||
request.method,
|
||||
request.url.path,
|
||||
response.status_code,
|
||||
duration,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=APP_NAME,
|
||||
version=APP_VERSION,
|
||||
@@ -63,6 +80,14 @@ app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.add_middleware(LangMiddleware)
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
log.exception("Unhandled exception on %s %s", request.method, request.url.path)
|
||||
return JSONResponse(status_code=500, content={"detail": "Internal server error"})
|
||||
|
||||
|
||||
app.include_router(webhook_router)
|
||||
app.include_router(metrics_router)
|
||||
@@ -76,7 +101,7 @@ if os.path.isdir(STATIC_DIR):
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "version": APP_VERSION}
|
||||
|
||||
|
||||
|
||||
@@ -13,11 +13,12 @@ from ..constants import (
|
||||
)
|
||||
from ..db.engine import get_session
|
||||
from ..db.models import Finding
|
||||
from ..schemas import FindingsListResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/findings", tags=["findings"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
@router.get("", response_model=FindingsListResponse)
|
||||
async def list_findings(
|
||||
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
|
||||
offset: int = Query(DEFAULT_OFFSET, ge=0),
|
||||
@@ -25,7 +26,7 @@ async def list_findings(
|
||||
severity: str | None = Query(None),
|
||||
scan_id: int | None = Query(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> dict:
|
||||
q = select(Finding)
|
||||
if rule:
|
||||
q = q.where(func.json_extract(Finding.data, JSON_PATH_RULE) == rule)
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
import csv
|
||||
import io
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
@@ -17,14 +16,16 @@ from ..constants import (
|
||||
DEFAULT_SORT_DIR,
|
||||
MAX_PAGE_SIZE,
|
||||
)
|
||||
from ..core.nexus import parse_package_path
|
||||
from ..db.engine import get_session
|
||||
from ..db.models import Scan
|
||||
from ..db.queries import build_package_list_query
|
||||
from ..schemas import PackageDetailOut, PackageListResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/packages", tags=["packages"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
@router.get("", response_model=PackageListResponse)
|
||||
async def list_packages(
|
||||
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
|
||||
offset: int = Query(DEFAULT_OFFSET, ge=0),
|
||||
@@ -35,7 +36,7 @@ async def list_packages(
|
||||
sort_by: str = Query(DEFAULT_SORT_BY_PACKAGES),
|
||||
sort_dir: str = Query(DEFAULT_SORT_DIR),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> dict:
|
||||
rows_q, total_q = build_package_list_query(
|
||||
flagged=flagged,
|
||||
ecosystem=ecosystem,
|
||||
@@ -74,7 +75,7 @@ async def export_packages_csv(
|
||||
flagged: bool | None = Query(None),
|
||||
search: str | None = Query(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> Response:
|
||||
rows_q, _total_q = build_package_list_query(
|
||||
flagged=flagged,
|
||||
search=search,
|
||||
@@ -118,14 +119,12 @@ async def export_packages_csv(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{name:path}")
|
||||
@router.get("/{name:path}", response_model=PackageDetailOut)
|
||||
async def get_package(
|
||||
name: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
parts = name.rsplit("/", 1)
|
||||
pkg_name = unquote(parts[0])
|
||||
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
|
||||
) -> dict:
|
||||
pkg_name, pkg_version = parse_package_path(name)
|
||||
|
||||
scans = (
|
||||
(
|
||||
@@ -141,7 +140,7 @@ async def get_package(
|
||||
)
|
||||
|
||||
if not scans:
|
||||
return {"detail": "Not found"}
|
||||
raise HTTPException(status_code=404, detail="Package not found")
|
||||
|
||||
all_findings: list[dict] = []
|
||||
for s in scans:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import csv
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
@@ -19,11 +19,12 @@ from ..constants import (
|
||||
from ..db.engine import get_session
|
||||
from ..db.models import Scan
|
||||
from ..db.queries import build_scan_list_query, get_dashboard_stats
|
||||
from ..schemas import ScanDetailOut, ScanListResponse, StatsResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1/scans", tags=["scans"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
@router.get("", response_model=ScanListResponse)
|
||||
async def list_scans(
|
||||
limit: int = Query(DEFAULT_PAGE_SIZE, le=MAX_PAGE_SIZE),
|
||||
offset: int = Query(DEFAULT_OFFSET, ge=0),
|
||||
@@ -34,7 +35,7 @@ async def list_scans(
|
||||
sort_by: str = Query(DEFAULT_SORT_BY_SCANS),
|
||||
sort_dir: str = Query(DEFAULT_SORT_DIR),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> dict:
|
||||
q, count_q = build_scan_list_query(
|
||||
flagged=flagged,
|
||||
status=status,
|
||||
@@ -77,7 +78,7 @@ async def export_scans_csv(
|
||||
search: str | None = Query(None),
|
||||
status: str | None = Query(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> Response:
|
||||
q, _count_q = build_scan_list_query(
|
||||
flagged=flagged,
|
||||
status=status,
|
||||
@@ -132,8 +133,8 @@ async def export_scans_csv(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def scan_stats(session: AsyncSession = Depends(get_session)):
|
||||
@router.get("/stats", response_model=StatsResponse)
|
||||
async def scan_stats(session: AsyncSession = Depends(get_session)) -> dict:
|
||||
dashboard = await get_dashboard_stats(session)
|
||||
return {
|
||||
"total_scans": dashboard["total_scans"],
|
||||
@@ -147,13 +148,13 @@ async def scan_stats(session: AsyncSession = Depends(get_session)):
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{scan_id}")
|
||||
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)):
|
||||
@router.get("/{scan_id}", response_model=ScanDetailOut)
|
||||
async def get_scan(scan_id: int, session: AsyncSession = Depends(get_session)) -> dict:
|
||||
scan = await session.scalar(
|
||||
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
|
||||
)
|
||||
if not scan:
|
||||
return {"detail": "Not found"}
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
return {
|
||||
"id": scan.id,
|
||||
"package_name": scan.package_name,
|
||||
|
||||
@@ -13,7 +13,7 @@ router = APIRouter(tags=["metrics"])
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def metrics(session: AsyncSession = Depends(get_session)):
|
||||
async def metrics(session: AsyncSession = Depends(get_session)) -> Response:
|
||||
total = await session.scalar(select(func.count(Scan.id))) or 0
|
||||
flagged = await session.scalar(select(func.count(Scan.id)).where(Scan.flagged == True)) or 0
|
||||
findings_total = await session.scalar(select(func.count(Finding.id))) or 0
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""Web UI routes — Jinja2 + htmx pages."""
|
||||
|
||||
import asyncio
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from jinja2 import Environment, PackageLoader, select_autoescape
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from ..config import config
|
||||
from ..constants import (
|
||||
@@ -17,6 +17,7 @@ from ..constants import (
|
||||
DEFAULT_SORT_DIR,
|
||||
WEB_PER_PAGE,
|
||||
)
|
||||
from ..core.nexus import parse_package_path
|
||||
from ..db.engine import get_session
|
||||
from ..db.models import Finding, Scan
|
||||
from ..db.queries import (
|
||||
@@ -45,14 +46,20 @@ def _render(name: str, **context) -> HTMLResponse:
|
||||
return HTMLResponse(template.render(**context), status_code=status_code)
|
||||
|
||||
|
||||
def _parse_flagged(value: str) -> bool | None:
|
||||
return True if value == "1" else None
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)):
|
||||
async def dashboard(request: Request, session: AsyncSession = Depends(get_session)) -> HTMLResponse:
|
||||
ctx = await get_dashboard_stats(session)
|
||||
return _render("dashboard.html", **ctx, request=request)
|
||||
|
||||
|
||||
@router.get("/dashboard/stats", response_class=HTMLResponse)
|
||||
async def dashboard_stats_fragment(request: Request, session: AsyncSession = Depends(get_session)):
|
||||
async def dashboard_stats_fragment(
|
||||
request: Request, session: AsyncSession = Depends(get_session)
|
||||
) -> HTMLResponse:
|
||||
ctx = await get_dashboard_stats(session)
|
||||
return _render("dashboard_stats.html", request=request, **ctx)
|
||||
|
||||
@@ -67,13 +74,11 @@ async def scans_list(
|
||||
sort_by: str = DEFAULT_SORT_BY_SCANS,
|
||||
sort_dir: str = DEFAULT_SORT_DIR,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> HTMLResponse:
|
||||
per_page = WEB_PER_PAGE
|
||||
offset = (page - 1) * per_page
|
||||
|
||||
flagged_bool = None
|
||||
if flagged == "1":
|
||||
flagged_bool = True
|
||||
flagged_bool = _parse_flagged(flagged)
|
||||
|
||||
q, count_q = build_scan_list_query(
|
||||
flagged=flagged_bool,
|
||||
@@ -105,9 +110,9 @@ async def scans_list(
|
||||
|
||||
|
||||
@router.get("/scans/{scan_id}", response_class=HTMLResponse)
|
||||
async def scan_detail(scan_id: int, request: Request, session: AsyncSession = Depends(get_session)):
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
async def scan_detail(
|
||||
scan_id: int, request: Request, session: AsyncSession = Depends(get_session)
|
||||
) -> HTMLResponse:
|
||||
scan = await session.scalar(
|
||||
select(Scan).where(Scan.id == scan_id).options(selectinload(Scan.findings))
|
||||
)
|
||||
@@ -126,13 +131,11 @@ async def packages_list(
|
||||
sort_by: str = DEFAULT_SORT_BY_PACKAGES,
|
||||
sort_dir: str = DEFAULT_SORT_DIR,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> HTMLResponse:
|
||||
per_page = WEB_PER_PAGE
|
||||
offset = (page - 1) * per_page
|
||||
|
||||
flagged_bool = None
|
||||
if flagged == "1":
|
||||
flagged_bool = True
|
||||
flagged_bool = _parse_flagged(flagged)
|
||||
|
||||
rows_q, total_q = build_package_list_query(
|
||||
flagged=flagged_bool,
|
||||
@@ -166,14 +169,8 @@ async def package_detail(
|
||||
name: str,
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
# name:path captures the entire path after /packages/
|
||||
# e.g. "eviltest/0.1.0" or "github.com/attacker/evilmodule/v0.1.0"
|
||||
parts = name.rsplit("/", 1)
|
||||
pkg_name = unquote(parts[0])
|
||||
pkg_version = unquote(parts[1]) if len(parts) == 2 else ""
|
||||
|
||||
from sqlalchemy.orm import selectinload
|
||||
) -> HTMLResponse:
|
||||
pkg_name, pkg_version = parse_package_path(name)
|
||||
|
||||
scans = (
|
||||
(
|
||||
@@ -211,7 +208,7 @@ async def analyze_finding_htmx(
|
||||
request: Request,
|
||||
retry: bool = False,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
) -> HTMLResponse:
|
||||
"""HTMX fragment: trigger LLM analysis and return styled result HTML."""
|
||||
from ..config import config
|
||||
from ..core.llm import analyze_finding
|
||||
|
||||
@@ -67,7 +67,7 @@ async def nexus_webhook(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_nexus_webhook_signature: str | None = Header(None, alias="X-Nexus-Webhook-Signature"),
|
||||
):
|
||||
) -> dict:
|
||||
payload = await request.body()
|
||||
|
||||
if config.webhook_secret:
|
||||
|
||||
102
guarddog_nexus/schemas.py
Normal file
102
guarddog_nexus/schemas.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Pydantic schemas for API request/response models."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScanOut(BaseModel):
|
||||
id: int
|
||||
package_name: str
|
||||
package_version: str
|
||||
ecosystem: str
|
||||
repository: str
|
||||
status: str
|
||||
total_findings: int
|
||||
flagged: bool
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScanListResponse(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
scans: list[ScanOut]
|
||||
|
||||
|
||||
class ScanDetailOut(ScanOut):
|
||||
nexus_asset_url: str | None = None
|
||||
sha256: str | None = None
|
||||
initiator: str | None = None
|
||||
source_ip: str | None = None
|
||||
findings: list[dict] = []
|
||||
|
||||
|
||||
class FindingOut(BaseModel):
|
||||
id: int
|
||||
scan_id: int
|
||||
rule: str = ""
|
||||
severity: str = ""
|
||||
message: str = ""
|
||||
location: str = ""
|
||||
code: str = ""
|
||||
report: dict | None = None
|
||||
created_at: datetime | None = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class FindingsListResponse(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
findings: list[FindingOut]
|
||||
|
||||
|
||||
class PackageOut(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
ecosystem: str
|
||||
repository: str
|
||||
last_scanned_at: datetime | None = None
|
||||
flagged: bool
|
||||
total_findings: int
|
||||
latest_scan_id: int
|
||||
|
||||
|
||||
class PackageListResponse(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
packages: list[PackageOut]
|
||||
|
||||
|
||||
class PackageScanOut(BaseModel):
|
||||
id: int
|
||||
status: str
|
||||
total_findings: int
|
||||
flagged: bool
|
||||
started_at: datetime | None = None
|
||||
|
||||
|
||||
class PackageDetailOut(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
ecosystem: str
|
||||
repository: str
|
||||
flagged: bool
|
||||
scans: list[PackageScanOut]
|
||||
findings: list[dict]
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_scans: int
|
||||
flagged_scans: int
|
||||
recent_flagged: int
|
||||
total_findings: int
|
||||
top_rules: list[dict]
|
||||
latest_scan_at: datetime | None = None
|
||||
@@ -34,8 +34,7 @@ async def test_scan_stats_empty(client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_not_found(client):
|
||||
resp = await client.get("/api/v1/scans/99999")
|
||||
assert resp.status_code == 200
|
||||
assert "detail" in resp.json()
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -129,8 +128,7 @@ async def test_package_with_data(client, sample_flagged_scan):
|
||||
@pytest.mark.asyncio
|
||||
async def test_package_not_found(client):
|
||||
resp = await client.get("/api/v1/packages/nonexistent/1.0")
|
||||
assert resp.status_code == 200
|
||||
assert "detail" in resp.json()
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# --- Findings ---
|
||||
|
||||
Reference in New Issue
Block a user