diff --git a/src/database.py b/src/database.py index d6b5a55..9b4a199 100644 --- a/src/database.py +++ b/src/database.py @@ -19,14 +19,12 @@ engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def get_db(): - db = SessionLocal() - try: - yield db - except: - db.rollback() - raise - finally: - db.close() + with SessionLocal.begin() as db: + try: + yield db + finally: + db.rollback() # Anything not explicitly commited is rolled back + db.close() db_dependency = Annotated[Session, Depends(get_db)] diff --git a/src/main.py b/src/main.py index 7717625..5b1b615 100644 --- a/src/main.py +++ b/src/main.py @@ -18,7 +18,6 @@ from src.auth.config import auth_settings from src.misp.service import MISPHandler # TODO: Create Pydantic request/response schemas -# TODO: Untie MISP connection from app loading @asynccontextmanager @@ -26,8 +25,7 @@ async def lifespan(_application: FastAPI) -> AsyncGenerator: # Startup yield # Shutdown - if hasattr(_application, "misp_handler"): - _application.misp_handler.stop_timer() + _application.misp_handler.stop_timer() if settings.ENVIRONMENT.is_deployed: diff --git a/src/misp/constants.py b/src/misp/constants.py index 35f91c2..e1df957 100644 --- a/src/misp/constants.py +++ b/src/misp/constants.py @@ -4,16 +4,4 @@ Constants and error codes for Constants: - List: Description - Consts: Description -""" -from enum import StrEnum, auto - - -class BlockedStatus(StrEnum): - BLOCKED = auto() - ALLOWED = auto() - IGNORED = auto() - - -class EventStatus(StrEnum): - ACTIVE = auto() - IGNORED = auto() +""" \ No newline at end of file diff --git a/src/misp/router.py b/src/misp/router.py index ab2a5c3..03b8405 100644 --- a/src/misp/router.py +++ b/src/misp/router.py @@ -5,8 +5,7 @@ Endpoints: - List: Description - Endpoints: Description """ -from typing import Annotated -from fastapi import APIRouter, HTTPException, Request, BackgroundTasks, Query +from fastapi import APIRouter, HTTPException, Request, BackgroundTasks from sqlalchemy.sql import and_ from src.auth.service import authed_dependency @@ -14,8 +13,7 @@ from src.database import db_dependency from src.prometheus import prometheus from src.misp.models import Domain -from src.misp.schemas import MispUpdatePutRequest, MispUpdatePutResponse, DomainBlockedGetResponse, \ - DomainSearchGetResponse, AlwaysAllowPatchResponse, EventIgnorePatchResponse, DomainDetailsGetResponse +from src.misp.schemas import MispUpdatePutRequest, MispUpdatePutResponse router = APIRouter( tags=["misp"], @@ -50,7 +48,7 @@ async def manual_misp_update(request: Request, update_request: MispUpdatePutRequ return {"time": published_time, "state": "Starting"} -@router.get("/domain/blocked/{domain}", response_model=DomainBlockedGetResponse) +@router.get("/domain/blocked/{domain}") async def domain_blocked(domain: str, db: db_dependency): same_elements = and_( Domain.events.contains(Domain.ignored_events), @@ -62,10 +60,10 @@ async def domain_blocked(domain: str, db: db_dependency): ~same_elements ).first()) - return {"domain": domain, "is_blocked": bool(domain_model)} + return {"is_blocked": bool(domain_model)} -@router.get("/domain/search", response_model=DomainSearchGetResponse) +@router.get("/domain/search") async def domain_search(domain: str, db: db_dependency): domain = domain.replace("*", "%") @@ -85,10 +83,10 @@ async def domain_search(domain: str, db: db_dependency): else: return "blocked" - return {"domains": {item[0].domain: domain_status(item) for item in results}} + return {item[0].domain: domain_status(item) for item in results} -@router.patch("/domain/always_allowed/{domain}", response_model=AlwaysAllowPatchResponse) -async def always_allowed(db: db_dependency, domain: str, allow: Annotated[bool, Query()]): +@router.patch("/domain/always_allowed/{domain}") +async def always_allowed(db: db_dependency, domain: str, allow: bool): domain_model = db.query(Domain).filter(Domain.domain == domain).first() if domain_model: @@ -98,13 +96,10 @@ async def always_allowed(db: db_dependency, domain: str, allow: Annotated[bool, db.add(domain_model) db.commit() - db.refresh(domain_model) - - return {"domain": domain, "is_always_allowed": domain_model.always_allowed} -@router.patch("/domain/events/{domain}/ignore", response_model=EventIgnorePatchResponse) -async def event_ignore(domain: str, db: db_dependency, event: Annotated[int, Query()]): +@router.patch("/domain/events/{domain}/ignore") +async def event_ignore(domain: str, db: db_dependency, event: int): domain_model = db.query(Domain).filter(Domain.domain == domain).first() if not domain_model: raise HTTPException(status_code=404, detail="Domain Not Found") @@ -115,17 +110,16 @@ async def event_ignore(domain: str, db: db_dependency, event: Annotated[int, Que ignored_events = domain_model.ignored_events or [] if event in ignored_events: - return {"domain": domain, "event_id": event, "status": "ignored"} + return {"status": "Event Ignored"} domain_model.ignored_events = ignored_events + [event] db.commit() - return {"domain": domain, "event_id": event, "status": "ignored"} + return {"status": "Event Ignored"} -# TODO: Combine event management routes -@router.patch("/domain/events/{domain}/reinstate", response_model=EventIgnorePatchResponse) +@router.patch("/domain/events/{domain}/reinstate") async def event_reinstate(domain: str, db: db_dependency, event: int): domain_model = db.query(Domain).filter(Domain.domain == domain).first() if not domain_model: @@ -134,7 +128,7 @@ async def event_reinstate(domain: str, db: db_dependency, event: int): ignored_events = domain_model.ignored_events if not ignored_events or event not in ignored_events: - return {"domain": domain, "event_id": event, "status": "active"} + return {"status": "Event Ignored"} ignored_events.remove(event) domain_model.ignored_events = ignored_events @@ -142,18 +136,10 @@ async def event_reinstate(domain: str, db: db_dependency, event: int): db.add(domain_model) db.commit() - return {"domain": domain, "event_id": event, "status": "active"} + return {"status": "Event Un-ignored"} -@router.get("/domain/details/{domain}", response_model=DomainDetailsGetResponse) +@router.get("/domain/details/{domain}") async def domain_details(db: db_dependency, domain: str): result = db.query(Domain).filter(Domain.domain==domain).first() - if not result: - raise HTTPException(status_code=404, detail="Domain Not Found") - response = { - "domain": result.domain, - "always_allowed": result.always_allowed, - "events": result.events if result.events else [], - "ignored_events": result.ignored_events if result.ignored_events else [], - } - return response + return result diff --git a/src/misp/schemas.py b/src/misp/schemas.py index e2264b0..6410d6d 100644 --- a/src/misp/schemas.py +++ b/src/misp/schemas.py @@ -10,8 +10,6 @@ from pydantic import Field from src.schemas import CustomBaseModel -from src.misp.constants import BlockedStatus, EventStatus - class MispUpdatePutRequest(CustomBaseModel): published_timestamp: Optional[str] = Field(default=None, description="Timestamp for how far back to check for published timestamps") @@ -20,30 +18,3 @@ class MispUpdatePutRequest(CustomBaseModel): class MispUpdatePutResponse(CustomBaseModel): time: str state: str - - -class DomainBlockedGetResponse(CustomBaseModel): - domain: str - is_blocked: bool - - -class DomainSearchGetResponse(CustomBaseModel): - domains: dict[str, BlockedStatus] - - -class AlwaysAllowPatchResponse(CustomBaseModel): - domain: str - is_always_allowed: bool - - -class EventIgnorePatchResponse(CustomBaseModel): - domain: str - event_id: int - status: EventStatus - - -class DomainDetailsGetResponse(CustomBaseModel): - domain: str - always_allowed: bool - events: list[int] = [] - ignored_events: list[int] = [] diff --git a/src/static/assets/index/scripts.js b/src/static/assets/index/scripts.js index d8541a2..56af2d7 100644 --- a/src/static/assets/index/scripts.js +++ b/src/static/assets/index/scripts.js @@ -57,7 +57,6 @@ class UI { this.domain_search_btn = document.getElementById("domain_search_btn") - this.false_positive_actions = document.getElementById("false_positive_actions") this.false_positive_load_btn = document.getElementById("false_positive_load_btn") this.event_container = document.getElementById("events_container") @@ -106,12 +105,12 @@ class UI { render_search_results(results) { const container = document.getElementById("domain_search_results"); - container.innerHTML = ""; const result_count = Object.keys(results).length; if (!result_count) { container.innerHTML = `
No results found.
`; return; } + container.innerHTML = ""; for (const [key, value] of Object.entries(results)) { const row = document.createElement("div"); row.className = "search-result"; @@ -222,7 +221,7 @@ class UI { } render_false_positive_controls(domain_data) { - ui.false_positive_actions.style.display = "block"; + document.getElementById("false_positive_actions").style.display = "block"; const always_allow_btn = document.getElementById("always_allow_btn"); always_allow_btn.innerText = domain_data.always_allow ? "Enabled" : "Disabled"; @@ -297,7 +296,7 @@ class UI { ui.set_loading("domain_search_btn", true); try { const data = await api.domain_search() - ui.render_search_results(data.domains) + ui.render_search_results(data) } catch (error) { console.error(error); } finally { @@ -307,12 +306,8 @@ class UI { async load_false_positive_controls(){ ui.set_loading("false_positive_load_btn", true); - ui.false_positive_actions.style.display = "none"; - ui.event_container.innerHTML = ""; - console.log(".") try { const data = await api.load_domain_fp() - if(!data){return;} const parsed_data= { domain: data.domain, always_allow: data.always_allowed, @@ -322,6 +317,7 @@ class UI { data.events.forEach(event => { parsed_data.events.push({"id": event, "ignored": data.ignored_events.includes(event)}); }) + ui.render_false_positive_controls(parsed_data) } catch (error) { console.error(error);