All checks were successful
ci / lint_and_test (push) Successful in 16s
Orgs can only grant a permission to their groups if the org itself has been granted that permission. Super-admin bypasses have not been added yet; they are flagged as TODOs.
290 lines
7.9 KiB
Python
290 lines
7.9 KiB
Python
import pytest
|
|
|
|
from typing import AsyncGenerator
|
|
from itertools import combinations
|
|
from fastapi.routing import APIRoute
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from src.user.models import User
|
|
from src.service.models import Service
|
|
from src.organisation.models import Organisation as Org, OrgUsers
|
|
from src.contact.models import Contact
|
|
from src.iam.models import Group, Permission, OrgPermissions
|
|
from src.auth.service import get_current_user, get_dev_user
|
|
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
|
|
from src.main import app # inited FastAPI app
|
|
from src.database import engine, Base, get_db
|
|
|
|
|
|
# Session factory bound to the test engine. Autoflush is off, so pending rows
# are only written when code calls flush()/commit() explicitly (the seeding
# helper flushes where it needs ids); autocommit stays disabled.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
@pytest.fixture()
def db_session():
    """Yield a session on a freshly rebuilt and seeded schema.

    All tables are dropped and recreated before seeding, so each test starts
    from the same data. On teardown, any uncommitted work is rolled back and
    the session is closed, even if seeding or the test itself raised.
    """
    # Rebuild the schema from scratch to isolate tests from one another.
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)

    session = SessionLocal()
    try:
        _seed(session)
        yield session
    finally:
        session.rollback()
        session.close()
|
|
|
|
|
|
@pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client for ``app`` with the default test overrides.

    Dependency overrides installed on the shared ``app``:

    - ``get_db`` -> returns the ``db_session`` fixture's session,
    - ``get_current_user`` -> ``get_dev_user``,
    - ``get_super_admin_list`` -> ``testing_su_list``.

    ``app.dependency_overrides`` is module-global state, so it is cleared in
    a ``finally`` block. Overrides therefore cannot leak into later tests,
    even if client setup/teardown raises or the fixture generator is closed
    early.
    """

    def get_db_override():
        # Hand every request the same test session instead of a new one.
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    app.dependency_overrides[get_current_user] = get_dev_user
    app.dependency_overrides[get_super_admin_list] = testing_su_list
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest.fixture
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client for ``app`` with only the DB overridden.

    Only ``get_db`` is overridden (to return the ``db_session`` fixture's
    session). ``get_current_user`` and ``get_super_admin_list`` keep their
    real implementations, presumably so tests can exercise the
    unauthenticated path — confirm against the auth dependencies.

    Overrides are cleared in a ``finally`` block so this global state on
    ``app`` cannot leak into later tests if setup or teardown raises.
    """

    def get_db_override():
        # Hand every request the same test session instead of a new one.
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client for ``app`` using ``empty_su_list``.

    Dependency overrides installed on the shared ``app``:

    - ``get_db`` -> returns the ``db_session`` fixture's session,
    - ``get_current_user`` -> ``get_dev_user``,
    - ``get_super_admin_list`` -> ``empty_su_list`` (presumably an empty
      super-admin list, so no request is treated as a super admin —
      confirm in ``src.auth.dependencies``).

    Overrides are cleared in a ``finally`` block so this global state on
    ``app`` cannot leak into later tests if setup or teardown raises.
    """

    def get_db_override():
        # Hand every request the same test session instead of a new one.
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    app.dependency_overrides[get_current_user] = get_dev_user
    app.dependency_overrides[get_super_admin_list] = empty_su_list
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
def _seed(db):
    """Populate a freshly created schema with the fixture data used by tests.

    No primary keys are set explicitly. The literal ids used in foreign-key
    columns below (``org_id=1``, ``root_user_id=3``, ``service_id=1`` ...)
    therefore depend on insertion order into freshly created tables.
    NOTE(review): this assumes the backend assigns autoincrement ids starting
    at 1 in insertion order — confirm for the configured engine.

    Resulting data, with ids as assumed above:

    - Users: 1 admin@test.com, 2 user@orgone.com, 3 root@orgtwo.com.
    - Contacts: 1-3 for org 1, 4-6 for org 2, 7-9 for org 3.
    - Orgs: 1 "Org One" (root user 1, "approved"), 2 "Org Two" (root user 3,
      "approved"), 3 "Org Three" (root user 1, "partial").
    - Org 1 links: user 2 via an ``OrgUsers`` row; user 1 appended to org 1's
      ``user_rel``.
    - Service 1 with permissions on "test_resource": 1 read, 2 move,
      3 delete. Org 1 holds permissions 1 and 2 only.
    - Groups: 1 "Org One Group" (org 1), 2 "Org Two Group" (org 2),
      3 "Org One Group Two" (org 1). Group 1 has permission 1 and user 1 in
      its relationships.

    Args:
        db: An open SQLAlchemy session. It is committed at the end.
    """
    # Users (ids 1-3 in insertion order).
    db.add(
        User(
            email="admin@test.com",
            first_name="Admin",
            last_name="Test",
            oidc_id="abcd-efgh-ijkl-mnop",
        )
    )
    db.add(
        User(
            email="user@orgone.com",
            first_name="User",
            last_name="Test",
            oidc_id="abcd-efgh-ijkl-qwer",
        )
    )
    db.add(
        User(
            email="root@orgtwo.com",
            first_name="Root",
            last_name="Test",
            oidc_id="abcd-efgh-ijkl-hjkl",
        )
    )
    # Contacts: billing/owner/security for each of orgs 1-3. They reference
    # org ids before any Org row exists. NOTE(review): this only works if the
    # org_id foreign key is not enforced at flush time (e.g. SQLite with
    # foreign keys off) — confirm for the configured engine.
    db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
    db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
    db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
    db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
    db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
    db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
    db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
    db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
    db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
    # Flush, presumably so the contact rows (ids 1-9) are written before the
    # orgs below reference them by id.
    db.flush()
    # Org 1: rooted at user 1, contacts 1-3.
    db.add(
        Org(
            name="Org One",
            root_user_id=1,
            billing_contact_id=1,
            owner_contact_id=2,
            security_contact_id=3,
            status="approved",
            intake_questionnaire={
                "metadata": {"version": 0, "submission_date": None},
                "questions": {"question_two": "answer two"},
            },
        )
    )
    # Org 2: rooted at user 3, contacts 4-6.
    db.add(
        Org(
            name="Org Two",
            root_user_id=3,
            billing_contact_id=4,
            owner_contact_id=5,
            security_contact_id=6,
            status="approved",
            intake_questionnaire={
                "metadata": {"version": 0, "submission_date": None},
                "questions": {"question_two": "answer two"},
            },
        )
    )
    # Org 3: rooted at user 1, contacts 7-9, status "partial" (not approved).
    db.add(
        Org(
            name="Org Three",
            root_user_id=1,
            billing_contact_id=7,
            owner_contact_id=8,
            security_contact_id=9,
            status="partial",
            intake_questionnaire={
                "metadata": {"version": 0, "submission_date": None},
                "questions": {"question_two": "answer two"},
            },
        )
    )
    # Link user 2 to org 1 via an explicit association row.
    db.add(OrgUsers(org_id=1, user_id=2))
    db.add(Service(name="Test Service", api_key="123456789"))
    # Permissions 1-3 on service 1: read, move, delete.
    db.add(Permission(service_id=1, resource="test_resource", action="read"))
    db.add(Permission(service_id=1, resource="test_resource", action="move"))
    db.add(Permission(service_id=1, resource="test_resource", action="delete"))
    # Org 1 is granted read (1) and move (2) but not delete (3). This is
    # presumably so tests can check that an org cannot grant its groups a
    # permission it does not hold itself.
    db.add(OrgPermissions(org_id=1, permission_id=1))
    db.add(OrgPermissions(org_id=1, permission_id=2))
    # Groups: 1 and 3 belong to org 1; 2 belongs to org 2.
    db.add(Group(name="Org One Group", org_id=1))
    db.add(Group(name="Org Two Group", org_id=2))
    db.add(Group(name="Org One Group Two", org_id=1))
    # Flush so the pending rows are written and can be loaded by id below.
    db.flush()
    group_model = db.get(Group, 1)
    perm_model = db.get(Permission, 1)
    # Attach permission 1 (read) to group 1.
    group_model.permission_rel.append(perm_model)
    user_model = db.get(User, 1)
    org_model = db.get(Org, 1)
    # Link user 1 to org 1 through the relationship collection.
    org_model.user_rel.append(user_model)
    # NOTE(review): group 1 was already created with org_id=1, so this append
    # looks redundant — confirm what group_rel maps to.
    org_model.group_rel.append(group_model)
    db.flush()
    # Add user 1 to group 1's user_rel.
    group_model.user_rel.append(user_model)
    db.commit()
|
|
|
|
|
|
def generate_query_and_status(params) -> list[tuple[str, int]]:
    """Build query strings and expected HTTP status codes for validation tests.

    Every parameter's valid value is ``1``. Two families of cases are
    produced, in this order:

    1. Missing parameters: every strict subset of the ``{param}=1`` pairs,
       including the empty query, in ``itertools.combinations`` order. These
       expect 422.
    2. Invalid values: the complete query with exactly one parameter
       replaced by each of ``0``, ``-1``, ``42``, ``"banana"`` and ``""``.
       ID 42 represents a non-existent entry, so those cases expect 404; the
       rest expect 422.

    Args:
        params: Query-parameter names, in the order they should appear.

    Returns:
        ``(query, status)`` tuples. Queries carry no leading ``?`` or ``&``.
    """
    params = list(params)  # accept any iterable; it is traversed twice
    possible_values = [0, -1, 42, "banana", ""]
    defaults = [f"{param}=1" for param in params]

    # Missing params: every combination except the complete one (r == len).
    query_and_status = [
        ("&".join(combo), 422)
        for r in range(len(defaults))
        for combo in combinations(defaults, r)
    ]

    # Invalid values: swap one pair of the complete query at a time.
    # Building the pair list directly (instead of str.replace on the joined
    # string) avoids "&&" separators and avoids mangling a param whose name
    # is a suffix of another (e.g. "id=1" inside "org_id=1").
    for index, param in enumerate(params):
        for value in possible_values:
            pairs = list(defaults)
            pairs[index] = f"{param}={value}"
            # ID 42 is used to represent a non-existent entry. So it should 404.
            status = 404 if value == 42 else 422
            query_and_status.append(("&".join(pairs), status))

    return query_and_status
|
|
|
|
|
|
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
    """Build JSON bodies and expected HTTP status codes for validation tests.

    Args:
        params: Mapping of body field name to its type, ``"int"`` or
            ``"str"``. Valid values are ``1`` for ints and ``"valid string"``
            for strs.

    Returns:
        ``(body, status)`` tuples, in this order:

        1. Missing fields: every strict subset of the fields with valid
           values, including ``{}``, in ``itertools.combinations`` order.
           These expect 422.
        2. Invalid values: the complete body with exactly one field replaced.
           Int fields try ``0``, ``-1``, ``42``, ``"banana"`` and ``""``; str
           fields try ``0``, ``""`` and ``"a"``. ID 42 represents a
           non-existent entry, so those bodies expect 404; the rest expect
           422.

    Raises:
        TypeError: If a field's type is neither ``"int"`` nor ``"str"``.
    """
    invalid_values_by_type = {
        "int": [0, -1, 42, "banana", ""],
        "str": [0, "", "a"],
    }
    valid_value_by_type = {"int": 1, "str": "valid string"}

    # Fail fast on an unsupported type, before any bodies are built.
    for typ in params.values():
        if typ not in invalid_values_by_type:
            raise TypeError(f"Unknown type {typ}")

    fields = list(params)
    default_body = {field: valid_value_by_type[params[field]] for field in fields}

    # Missing params: every combination except the complete body (r == len).
    body_and_status = [
        ({field: default_body[field] for field in combo}, 422)
        for r in range(len(fields))
        for combo in combinations(fields, r)
    ]

    # Invalid values: replace one field of the complete body at a time.
    for field in fields:
        for value in invalid_values_by_type[params[field]]:
            # ID 42 is used to represent a non-existent entry. So it should 404.
            status = 404 if value == 42 else 422
            body_and_status.append(({**default_body, field: value}, status))

    return body_and_status
|
|
|
|
|
|
def get_testable_routes():
    """List the endpoints registered on ``app`` as one tuple per route/method.

    Only ``APIRoute`` instances are considered; any other route type in
    ``app.routes`` is skipped. The ``HEAD`` and ``OPTIONS`` methods are
    excluded.

    Returns:
        A list of ``(method, path, status_code, response_model, summary)``
        tuples. Methods within a route are sorted so the output order is
        deterministic. Iterating the ``route.methods`` set directly depends on
        string-hash randomization, which can change the order between runs
        (relevant if this list feeds ``pytest.mark.parametrize``).
    """
    skipped_methods = {"HEAD", "OPTIONS"}
    routes = []

    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue

        for method in sorted(m for m in route.methods if m not in skipped_methods):
            routes.append(
                (
                    method,
                    route.path,
                    route.status_code,
                    route.response_model,
                    route.summary,
                )
            )

    return routes
|
|
|
|
|
|
# with open("endpoints.txt", "w") as f:
|
|
# for ep in get_testable_routes():
|
|
# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n")
|
|
#
|
|
#
|
|
### Docstring formatted output ###
|
|
# with open("endpoints.txt", "w") as f:
|
|
# for ep in get_testable_routes():
|
|
# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n")
|