feat: pydantic response schemas for misp endpoints

This commit is contained in:
Chris Milne 2026-05-13 12:42:08 +01:00
parent d81f865ec2
commit 6a9a31b973
4 changed files with 81 additions and 22 deletions

View file

@ -4,4 +4,16 @@ Constants and error codes for <this module>
Constants: Constants:
- List: Description - List: Description
- Consts: Description - Consts: Description
""" """
from enum import StrEnum, auto
class BlockedStatus(StrEnum):
BLOCKED = auto()
ALLOWED = auto()
IGNORED = auto()
class EventStatus(StrEnum):
ACTIVE = auto()
IGNORED = auto()

View file

@ -5,7 +5,8 @@ Endpoints:
- List: Description - List: Description
- Endpoints: Description - Endpoints: Description
""" """
from fastapi import APIRouter, HTTPException, Request, BackgroundTasks from typing import Annotated
from fastapi import APIRouter, HTTPException, Request, BackgroundTasks, Query
from sqlalchemy.sql import and_ from sqlalchemy.sql import and_
from src.auth.service import authed_dependency from src.auth.service import authed_dependency
@ -13,7 +14,8 @@ from src.database import db_dependency
from src.prometheus import prometheus from src.prometheus import prometheus
from src.misp.models import Domain from src.misp.models import Domain
from src.misp.schemas import MispUpdatePutRequest, MispUpdatePutResponse from src.misp.schemas import MispUpdatePutRequest, MispUpdatePutResponse, DomainBlockedGetResponse, \
DomainSearchGetResponse, AlwaysAllowPatchResponse, EventIgnorePatchResponse, DomainDetailsGetResponse
router = APIRouter( router = APIRouter(
tags=["misp"], tags=["misp"],
@ -48,7 +50,7 @@ async def manual_misp_update(request: Request, update_request: MispUpdatePutRequ
return {"time": published_time, "state": "Starting"} return {"time": published_time, "state": "Starting"}
@router.get("/domain/blocked/{domain}") @router.get("/domain/blocked/{domain}", response_model=DomainBlockedGetResponse)
async def domain_blocked(domain: str, db: db_dependency): async def domain_blocked(domain: str, db: db_dependency):
same_elements = and_( same_elements = and_(
Domain.events.contains(Domain.ignored_events), Domain.events.contains(Domain.ignored_events),
@ -60,10 +62,10 @@ async def domain_blocked(domain: str, db: db_dependency):
~same_elements ~same_elements
).first()) ).first())
return {"is_blocked": bool(domain_model)} return {"domain": domain, "is_blocked": bool(domain_model)}
@router.get("/domain/search") @router.get("/domain/search", response_model=DomainSearchGetResponse)
async def domain_search(domain: str, db: db_dependency): async def domain_search(domain: str, db: db_dependency):
domain = domain.replace("*", "%") domain = domain.replace("*", "%")
@ -83,10 +85,10 @@ async def domain_search(domain: str, db: db_dependency):
else: else:
return "blocked" return "blocked"
return {item[0].domain: domain_status(item) for item in results} return {"domains": {item[0].domain: domain_status(item) for item in results}}
@router.patch("/domain/always_allowed/{domain}") @router.patch("/domain/always_allowed/{domain}", response_model=AlwaysAllowPatchResponse)
async def always_allowed(db: db_dependency, domain: str, allow: bool): async def always_allowed(db: db_dependency, domain: str, allow: Annotated[bool, Query()]):
domain_model = db.query(Domain).filter(Domain.domain == domain).first() domain_model = db.query(Domain).filter(Domain.domain == domain).first()
if domain_model: if domain_model:
@ -96,10 +98,13 @@ async def always_allowed(db: db_dependency, domain: str, allow: bool):
db.add(domain_model) db.add(domain_model)
db.commit() db.commit()
db.refresh(domain_model)
return {"domain": domain, "is_always_allowed": domain_model.always_allowed}
@router.patch("/domain/events/{domain}/ignore") @router.patch("/domain/events/{domain}/ignore", response_model=EventIgnorePatchResponse)
async def event_ignore(domain: str, db: db_dependency, event: int): async def event_ignore(domain: str, db: db_dependency, event: Annotated[int, Query()]):
domain_model = db.query(Domain).filter(Domain.domain == domain).first() domain_model = db.query(Domain).filter(Domain.domain == domain).first()
if not domain_model: if not domain_model:
raise HTTPException(status_code=404, detail="Domain Not Found") raise HTTPException(status_code=404, detail="Domain Not Found")
@ -110,16 +115,17 @@ async def event_ignore(domain: str, db: db_dependency, event: int):
ignored_events = domain_model.ignored_events or [] ignored_events = domain_model.ignored_events or []
if event in ignored_events: if event in ignored_events:
return {"status": "Event Ignored"} return {"domain": domain, "event_id": event, "status": "ignored"}
domain_model.ignored_events = ignored_events + [event] domain_model.ignored_events = ignored_events + [event]
db.commit() db.commit()
return {"status": "Event Ignored"} return {"domain": domain, "event_id": event, "status": "ignored"}
@router.patch("/domain/events/{domain}/reinstate") # TODO: Combine event management routes
@router.patch("/domain/events/{domain}/reinstate", response_model=EventIgnorePatchResponse)
async def event_reinstate(domain: str, db: db_dependency, event: int): async def event_reinstate(domain: str, db: db_dependency, event: int):
domain_model = db.query(Domain).filter(Domain.domain == domain).first() domain_model = db.query(Domain).filter(Domain.domain == domain).first()
if not domain_model: if not domain_model:
@ -128,7 +134,7 @@ async def event_reinstate(domain: str, db: db_dependency, event: int):
ignored_events = domain_model.ignored_events ignored_events = domain_model.ignored_events
if not ignored_events or event not in ignored_events: if not ignored_events or event not in ignored_events:
return {"status": "Event Ignored"} return {"domain": domain, "event_id": event, "status": "active"}
ignored_events.remove(event) ignored_events.remove(event)
domain_model.ignored_events = ignored_events domain_model.ignored_events = ignored_events
@ -136,10 +142,18 @@ async def event_reinstate(domain: str, db: db_dependency, event: int):
db.add(domain_model) db.add(domain_model)
db.commit() db.commit()
return {"status": "Event Un-ignored"} return {"domain": domain, "event_id": event, "status": "active"}
@router.get("/domain/details/{domain}") @router.get("/domain/details/{domain}", response_model=DomainDetailsGetResponse)
async def domain_details(db: db_dependency, domain: str): async def domain_details(db: db_dependency, domain: str):
result = db.query(Domain).filter(Domain.domain==domain).first() result = db.query(Domain).filter(Domain.domain==domain).first()
if not result:
raise HTTPException(status_code=404, detail="Domain Not Found")
response = {
"domain": result.domain,
"always_allowed": result.always_allowed,
"events": result.events if result.events else [],
"ignored_events": result.ignored_events if result.ignored_events else [],
}
return result return response

View file

@ -10,6 +10,8 @@ from pydantic import Field
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from src.misp.constants import BlockedStatus, EventStatus
class MispUpdatePutRequest(CustomBaseModel): class MispUpdatePutRequest(CustomBaseModel):
published_timestamp: Optional[str] = Field(default=None, description="Timestamp for how far back to check for published timestamps") published_timestamp: Optional[str] = Field(default=None, description="Timestamp for how far back to check for published timestamps")
@ -18,3 +20,30 @@ class MispUpdatePutRequest(CustomBaseModel):
class MispUpdatePutResponse(CustomBaseModel): class MispUpdatePutResponse(CustomBaseModel):
time: str time: str
state: str state: str
class DomainBlockedGetResponse(CustomBaseModel):
domain: str
is_blocked: bool
class DomainSearchGetResponse(CustomBaseModel):
domains: dict[str, BlockedStatus]
class AlwaysAllowPatchResponse(CustomBaseModel):
domain: str
is_always_allowed: bool
class EventIgnorePatchResponse(CustomBaseModel):
domain: str
event_id: int
status: EventStatus
class DomainDetailsGetResponse(CustomBaseModel):
domain: str
always_allowed: bool
events: list[int] = []
ignored_events: list[int] = []

View file

@ -57,6 +57,7 @@ class UI {
this.domain_search_btn = document.getElementById("domain_search_btn") this.domain_search_btn = document.getElementById("domain_search_btn")
this.false_positive_actions = document.getElementById("false_positive_actions")
this.false_positive_load_btn = document.getElementById("false_positive_load_btn") this.false_positive_load_btn = document.getElementById("false_positive_load_btn")
this.event_container = document.getElementById("events_container") this.event_container = document.getElementById("events_container")
@ -105,12 +106,12 @@ class UI {
render_search_results(results) { render_search_results(results) {
const container = document.getElementById("domain_search_results"); const container = document.getElementById("domain_search_results");
container.innerHTML = "";
const result_count = Object.keys(results).length; const result_count = Object.keys(results).length;
if (!result_count) { if (!result_count) {
container.innerHTML = `<div class="search-result">No results found.</div>`; container.innerHTML = `<div class="search-result">No results found.</div>`;
return; return;
} }
container.innerHTML = "";
for (const [key, value] of Object.entries(results)) { for (const [key, value] of Object.entries(results)) {
const row = document.createElement("div"); const row = document.createElement("div");
row.className = "search-result"; row.className = "search-result";
@ -221,7 +222,7 @@ class UI {
} }
render_false_positive_controls(domain_data) { render_false_positive_controls(domain_data) {
document.getElementById("false_positive_actions").style.display = "block"; ui.false_positive_actions.style.display = "block";
const always_allow_btn = document.getElementById("always_allow_btn"); const always_allow_btn = document.getElementById("always_allow_btn");
always_allow_btn.innerText = domain_data.always_allow ? "Enabled" : "Disabled"; always_allow_btn.innerText = domain_data.always_allow ? "Enabled" : "Disabled";
@ -296,7 +297,7 @@ class UI {
ui.set_loading("domain_search_btn", true); ui.set_loading("domain_search_btn", true);
try { try {
const data = await api.domain_search() const data = await api.domain_search()
ui.render_search_results(data) ui.render_search_results(data.domains)
} catch (error) { } catch (error) {
console.error(error); console.error(error);
} finally { } finally {
@ -306,8 +307,12 @@ class UI {
async load_false_positive_controls(){ async load_false_positive_controls(){
ui.set_loading("false_positive_load_btn", true); ui.set_loading("false_positive_load_btn", true);
ui.false_positive_actions.style.display = "none";
ui.event_container.innerHTML = "";
console.log(".")
try { try {
const data = await api.load_domain_fp() const data = await api.load_domain_fp()
if(!data){return;}
const parsed_data= { const parsed_data= {
domain: data.domain, domain: data.domain,
always_allow: data.always_allowed, always_allow: data.always_allowed,
@ -317,7 +322,6 @@ class UI {
data.events.forEach(event => { data.events.forEach(event => {
parsed_data.events.push({"id": event, "ignored": data.ignored_events.includes(event)}); parsed_data.events.push({"id": event, "ignored": data.ignored_events.includes(event)});
}) })
ui.render_false_positive_controls(parsed_data) ui.render_false_positive_controls(parsed_data)
} catch (error) { } catch (error) {
console.error(error); console.error(error);