diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index d64f3ce6b..3b9d03138 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -307,11 +307,16 @@ "submit_task_check": { "title": "Submit Task Check", "type": "string" + }, + "admin_check": { + "title": "Admin Check", + "type": "string" } }, "required": [ "tiled_service_account_check", - "submit_task_check" + "submit_task_check", + "admin_check" ], "title": "OpaConfig", "type": "object", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 1cb719910..c16ca03dc 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -712,9 +712,14 @@ "type": "object", "required": [ "tiled_service_account_check", - "submit_task_check" + "submit_task_check", + "admin_check" ], "properties": { + "admin_check": { + "title": "Admin Check", + "type": "string" + }, "audience": { "title": "Audience", "default": "account", diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 5b2b4453d..9367c9f88 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -314,6 +314,7 @@ class OpaConfig(BlueapiBaseModel): audience: str = "account" tiled_service_account_check: str submit_task_check: str + admin_check: str class ApplicationConfig(BlueapiBaseModel): @@ -323,7 +324,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.4.0" + REST_API_VERSION: ClassVar[str] = "1.4.1" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 9d9c96ddd..f9008138a 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -1,19 +1,18 @@ import logging -import re from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext from typing import Annotated, Any, Self, cast from aiohttp import ClientSession from fastapi import Depends, HTTPException, Request -from starlette import status +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authentication import TiledAuth, unchecked_bearer_token from blueapi.service.model import TaskRequest +from blueapi.utils import INSTRUMENT_SESSION_RE LOGGER = logging.getLogger(__name__) -INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") class OpaClient: @@ -66,14 +65,19 @@ async def require_submit_task(self, instrument_session: str, token: str): raise ValueError("Invalid instrument session") if not await self._call_opa( - self._conf.submit_task_check, + self._config.submit_task_check, { "token": token, "proposal": int(match["proposal"]), "visit": int(match["visit"]), }, ): - raise HTTPException(status_code=status.HTTP_403_UNORTHORIZED) + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authorized to submit task" + ) + + async def is_admin(self, token: str) -> bool: + return await self._call_opa(self._config.admin_check, {"token": token}) class OpaUserClient: @@ -88,6 +92,9 @@ async def can_submit_task(self, task: TaskRequest): LOGGER.info("Checking permissions to run task") await self.client.require_submit_task(task.instrument_session, self.token) + async def admin(self) -> bool: + return await self.client.is_admin(self.token) + async def validate_tiled_config( tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None @@ -112,12 +119,16 @@ async def opa( if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): if not token: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="Authentication missing" + ) return OpaUserClient(opa, token) return None + async def submit_permission( - opa: Annotated[OpaUserClient, Depends(opa)], + opa: Annotated[OpaUserClient | None, Depends(opa)], task_request: TaskRequest, ): - await opa.can_submit_task(task_request) + if opa: + await opa.can_submit_task(task_request) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 7e9c2f017..d35ffa38b 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -41,7 +41,13 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient, submit_permission, validate_tiled_config +from .authorization import ( + OpaClient, + OpaUserClient, + opa, + submit_permission, + validate_tiled_config, +) from .model import ( DeviceModel, DeviceResponse, @@ -148,31 +154,27 @@ def get_app(config: ApplicationConfig): return app -def access_task_permission( - request: Request, +async def access_task_permission( + opa: Annotated[OpaUserClient | None, Depends(opa)], task_id: str, + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - try: - task = runner.run(interface.get_task_by_id, task_id) - except KeyError: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None + task = runner.run(interface.get_task_by_id, task_id) if ( - access_token - and task - and access_token.get("fedid") != task.task.metadata.get("user") + opa + and not await opa.admin() + and (task and fedid != task.task.metadata.get("user")) ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # start_task_permission is used when there is WorkerTask -def start_task_permission( - request: Request, +async def start_task_permission( task: WorkerTask, + opa: Annotated[OpaUserClient, Depends(opa)], + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): if not task.task_id: @@ -180,7 +182,7 @@ def start_task_permission( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="No task id provided", ) - access_task_permission(request, task.task_id, runner) + await access_task_permission(opa, task.task_id, fedid, runner) async def on_key_error_404(_: Request, __: Exception): @@ -310,12 +312,11 @@ def submit_task( task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], - user: Fedid, + fedid: Fedid, ) -> TaskResponse: """Submit a task to the worker.""" try: - user = user or "Unknown" - task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + task_id: str = runner.run(interface.submit_task, task_request, {"user": fedid}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: @@ -364,9 +365,10 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) -def get_tasks( - request: Request, +async def get_tasks( + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], + opa: Annotated[OpaUserClient, Depends(opa)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -387,13 +389,8 @@ def get_tasks( else: tasks = runner.run(interface.get_tasks) - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - user = access_token.get("fedid") if access_token else None - - if user: - tasks = [t for t in tasks if t.task.metadata.get("user") == user] + if opa and not await opa.admin(): + tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] return TasksListResponse(tasks=tasks) @@ -518,10 +515,11 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt tags=[Tag.TASK], ) @start_as_current_span(TRACER, "state_change_request.new_state") -def set_state( +async def set_state( state_change_request: StateChangeRequest, response: Response, - _: Annotated[None, Depends(access_task_permission)], + fedid: Fedid, + opa: Annotated[OpaUserClient, Depends(opa)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -548,14 +546,24 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + + if ( + opa + and not await opa.admin() + and active + and active.task.metadata.get("user") != fedid + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to set worker state", + ) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: runner.run(interface.resume_worker) elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: - # active = runner.run(interface.get_active_task) - # if active.task.metadata.get("user"): - try: runner.run( interface.cancel_active_task, diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index e5cc85166..f722c5b42 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,3 +1,4 @@ +import re from collections.abc import Callable, Mapping from functools import wraps from logging import Logger @@ -30,6 +31,8 @@ Args = ParamSpec("Args") Return = TypeVar("Return") +INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") + def report_successful_devices( devices: Mapping[str, Any], sim_backend: bool, logger: Logger diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py index deee82b1e..8918cf882 100644 --- a/src/blueapi/utils/serialization.py +++ b/src/blueapi/utils/serialization.py @@ -1,9 +1,10 @@ import json -import re from typing import Any from pydantic import BaseModel +from blueapi import utils + def serialize(obj: Any) -> Any: """ @@ -28,13 +29,8 @@ def serialize(obj: Any) -> Any: return obj -_INSTRUMENT_SESSION_AUTHZ_REGEX: re.Pattern = re.compile( - r"^[a-zA-Z]{2}(?P\d+)-(?P\d+)$" -) - - def access_blob(instrument_session: str, beamline: str) -> str: - m = _INSTRUMENT_SESSION_AUTHZ_REGEX.match(instrument_session) + m = utils.INSTRUMENT_SESSION_RE.match(instrument_session) if m is None: raise ValueError( "Unable to extract proposal and visit from " diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 37c1d7e3f..a2e602f21 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -8,9 +8,12 @@ from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authorization import ( OpaClient, + OpaUserClient, opa, + submit_permission, validate_tiled_config, ) +from blueapi.service.model import TaskRequest # Reusable client patch decorator patch_client_session = patch( @@ -25,6 +28,7 @@ def opa_config() -> OpaConfig: return OpaConfig( root=HttpUrl("http://auth.example.com"), submit_task_check="/auth/submit", + admin_check="/auth/admin", tiled_service_account_check="/auth/tiled", ) @@ -108,6 +112,105 @@ async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): ) +@pytest.mark.parametrize( + "result,context", + [(True, nullcontext()), (False, pytest.raises(HTTPException, match="403"))], +) +@patch_client_session +async def test_require_submit_task( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_submit_task( + instrument_session="cm12345-1", token="foo_bar" + ) + + session().post.assert_called_once_with( + "/auth/submit", + json={ + "input": { + "token": "foo_bar", + "beamline": "p99", + "audience": "account", + "visit": 1, + "proposal": 12345, + } + }, + ) + + +@patch_client_session +async def test_opa_require_submit_task_invalid_session( + session: MagicMock, opa_config: OpaConfig +): + client = OpaClient(instrument="p45", config=opa_config) + + with pytest.raises(ValueError, match="Invalid instrument session"): + await client.require_submit_task( + instrument_session="not a session", token="foo_bar" + ) + + +@pytest.mark.parametrize("result", [True, False]) +@patch_client_session +async def test_opa_is_admin(session: MagicMock, opa_config: OpaConfig, result: bool): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + client = OpaClient(instrument="p45", config=opa_config) + + admin = await client.is_admin("foo_bar") + + assert admin == result + + session().post.assert_called_once_with( + "/auth/admin", + json={"input": {"token": "foo_bar", "beamline": "p45", "audience": "account"}}, + ) + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_user_client_can_submit_task(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.require_submit_task = AsyncMock(side_effect=result) + + user_client = OpaUserClient(opa, "foo_bar") + + with context: + await user_client.can_submit_task( + TaskRequest(name="foo", params={}, instrument_session="cm12345-1") + ) + opa.require_submit_task.assert_called_once_with("cm12345-1", "foo_bar") + + +@pytest.mark.parametrize("result", [True, False]) +async def test_user_client_admin(result: bool): + opa = MagicMock(spec=OpaUserClient) + opa.is_admin = AsyncMock(return_value=result) + + user_client = OpaUserClient(opa, "foo_bar") + + admin = await user_client.admin() + + assert admin == result + + async def test_validate_tiled_config(): opa = MagicMock(spec=OpaClient) tiled = ServiceAccount() @@ -177,3 +280,21 @@ async def test_opa_dependency_without_authz(token): del request.app.state.authz user_client = await opa(request, token) assert user_client is None + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_submit_permission_dependency(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.can_submit_task.side_effect = result + with context: + await submit_permission(opa, Mock()) + + +async def test_submit_permission_dependency_without_opa(): + assert await submit_permission(None, Mock()) is None diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 1ddf2c6ca..8ec1b9b65 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -7,16 +7,22 @@ import jwt import pytest from bluesky.protocols import Stoppable -from fastapi import status +from fastapi import HTTPException, status from fastapi.testclient import TestClient from httpx import Headers from pydantic import BaseModel, ValidationError from pydantic_core import InitErrorDetails from super_state_machine.errors import TransitionError -from blueapi.config import ApplicationConfig, CORSConfig, OIDCConfig, RestConfig +from blueapi.config import ( + ApplicationConfig, + CORSConfig, + OIDCConfig, + RestConfig, +) from blueapi.core.bluesky_types import Plan -from blueapi.service import main +from blueapi.service import interface, main +from blueapi.service.authorization import OpaUserClient, opa from blueapi.service.interface import ( cancel_active_task, get_device, @@ -54,6 +60,11 @@ def mock_runner() -> Mock: return Mock(spec=WorkerDispatcher) +@pytest.fixture +def mock_opa_client() -> Mock: + return Mock(spec=OpaUserClient) + + @pytest.fixture def client(mock_runner: Mock) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): @@ -79,6 +90,27 @@ def client_with_auth( main.teardown_runner() +@pytest.fixture +def access_token(valid_token_with_jwt: dict[str, Any]) -> str: + return valid_token_with_jwt["access_token"] + + +@pytest.fixture +def client_with_opa( + mock_runner: Mock, + oidc_config: OIDCConfig, + mock_opa_client: Mock, + mock_authn_server, +): + with patch("blueapi.service.interface.worker"): + main.setup_runner(runner=mock_runner) + app = main.get_app(ApplicationConfig(oidc=oidc_config)) + app.dependency_overrides[opa] = lambda: mock_opa_client + client = TestClient(app) + yield client + main.teardown_runner() + + @pytest.fixture def rest_config_with_cors() -> RestConfig: cors_config = CORSConfig( @@ -251,10 +283,27 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id} +def test_submit_task_requires_permission( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, +): + task = TaskRequest(name="sleep", params={"time": 2}, instrument_session="cm12345-2") + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_opa_client.can_submit_task.side_effect = HTTPException(status_code=403) + mock_runner.run.side_effect = RuntimeError("Task should not be submitted") + + resp = client_with_opa.post("/tasks", json=task.model_dump()) + + assert resp.status_code == 403 + mock_runner.run.assert_not_called() + + def test_create_task_inserts_auth_metadata( mock_runner: Mock, client_with_auth: TestClient, @@ -416,6 +465,27 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST +@pytest.mark.parametrize("admin,task_ids", [(True, ["foo", "bar"]), (False, ["foo"])]) +def test_get_tasks_filters_by_user( + mock_runner: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, + admin: bool, + task_ids: list[str], +): + + mock_runner.run.return_value = [ + TrackableTask(task_id="foo", task=Task(name="f1", metadata={"user": "jd1"})), + TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), + ] + mock_opa_client.admin.return_value = admin + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + tasks = client_with_opa.get("/tasks").json().get("tasks") + + assert [t["task_id"] for t in tasks] == task_ids + + def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) mock_runner.run.return_value = task_id @@ -423,6 +493,28 @@ def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: assert response.json() == {"task_id": f"{task_id}"} +def test_cant_delete_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, +): + mock_opa_client.admin.return_value = False + mock_runner.run.side_effect = lambda mth, *args: { + interface.get_task_by_id: TrackableTask( + task_id="bar", task=Task(name="t2", metadata={"user": "jd2"}) + ), + }[mth] + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.delete("/tasks/bar") + + # 404 to obfuscate whether task exists when inaccessible + assert resp.status_code == 404 + + mock_runner.run.assert_called_once() + + def test_set_active_task(client: TestClient) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) @@ -471,6 +563,35 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_set_other_users_task_active( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + + task_id = "foo" + task = WorkerTask(task_id=task_id) + mock_opa_client.admin.return_value = admin + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + mock_runner.run.side_effect = lambda mth, *a, **kw: { + interface.get_task_by_id: TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ), + interface.get_active_task: None, + interface.begin_task: None, + }[mth] + + resp = client_with_opa.put("/worker/task", json=task.model_dump()) + + assert resp.status_code == status + + def test_get_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( @@ -503,6 +624,25 @@ def test_get_task(mock_runner: Mock, client: TestClient): } +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_get_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_runner.run.return_value = TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ) + mock_opa_client.admin.return_value = admin + + resp = client_with_opa.get("/tasks/foo") + assert resp.status_code == status + + def test_get_all_tasks(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) tasks = [ @@ -574,7 +714,12 @@ def test_get_state(mock_runner: Mock, client: TestClient): def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -588,7 +733,12 @@ def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -602,7 +752,12 @@ def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -619,7 +774,12 @@ def test_set_state_running_to_stopping_including_reason( current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", @@ -635,7 +795,11 @@ def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - mock_runner.run.side_effect = [current_state, TransitionError()] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + TransitionError(), + ] response = client.put( "/worker/state", @@ -666,6 +830,35 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): } +@pytest.mark.parametrize("admin,status", [(True, 202), (False, 403)]) +def test_set_state_of_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + + mock_opa_client.admin.return_value = admin + mock_runner.run.side_effect = lambda mth, *a, **kw: { + interface.get_active_task: TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ), + interface.get_worker_state: WorkerState.RUNNING, + interface.cancel_active_task: WorkerState.ABORTING, + }[mth] + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.put( + "/worker/state", + json=StateChangeRequest(new_state=WorkerState.ABORTING).model_dump(), + ) + + assert resp.status_code == status + + def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: environment_id = uuid.uuid4() mock_runner.state = EnvironmentResponse( diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 999e01d23..ed00587a1 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -345,6 +345,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): "audience": "account", "tiled_service_account_check": "v1/tiled_service_account", "submit_task_check": "v1/submit_task", + "admin_check": "v1/admin_check", }, }, {