diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 8c746fe92..6368da0ac 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -321,6 +321,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) @click.argument("parameters", type=ParametersType(), default={}, required=False) +@click.option("--ws", type=bool, is_flag=True, default=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -348,6 +349,7 @@ def run_plan( name: str, timeout: float | None, foreground: bool, + ws: bool, instrument_session: str, parameters: TaskParameters, ) -> None: @@ -374,7 +376,13 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + client.add_callback(on_event) + + if ws: + resp = client.run_blocking(task) + else: + resp = client.run_task(task) + match resp.result: case TaskResult(result=None, type="NoneType"): print("Plan succeeded") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 20de15892..3d6e3bf6a 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -459,6 +459,26 @@ def get_active_task(self) -> WorkerTask: return self.active_task + @start_as_current_span(TRACER, "request") + def run_blocking( + self, request: TaskRequest, on_event: OnAnyEvent | None = None + ) -> TaskStatus: + for event in self._rest.run_blocking(request): + if on_event is not None: + on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e) + if isinstance(event, WorkerEvent) and event.is_complete(): + if event.task_status is None: + raise BlueskyRemoteControlError( + "Server completed without task status" + ) + return event.task_status + raise BlueskyRemoteControlError("Connection closed before plan completed.") + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 0bddb5c87..e9d313e9a 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -10,11 +10,13 @@ get_tracer, start_as_current_span, ) -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl from pydantic_core import PydanticSerializationError +from websockets.sync.client import connect from blueapi import __version__ from blueapi.client import client +from blueapi.client.event_bus import AnyEvent from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -39,6 +41,8 @@ LOGGER = logging.getLogger(__name__) +USER_AGENT = f"blueapi cli {__version__}" + class BlueskyRequestError(Exception): """An error response from the blueapi server.""" @@ -307,14 +311,15 @@ def _request_and_deserialize( ) -> T: url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API - carr = get_context_propagator() + headers = get_context_propagator() + headers["User-Agent"] = USER_AGENT try: response = self._pool.request( method, url, json=data, params=params, - headers=carr, + headers=headers, auth=JWTAuth(self._session_manager), ) except requests.exceptions.ConnectionError as ce: @@ -340,6 +345,28 @@ def _request_and_deserialize( ) return deserialized + def run_blocking(self, req: TaskRequest): + url = self._ws_address().unicode_string().removesuffix("/") + "/api/v2/run_plan" + headers = get_context_propagator() + if self._session_manager: + auth = self._session_manager.get_valid_access_token() + headers["Authorization"] = f"Bearer {auth}" + with connect( + url, + additional_headers=headers, + user_agent_header=USER_AGENT, + ) as ws: + ws.send(req.model_dump_json()) + for message in ws: + event = TypeAdapter(AnyEvent).validate_json(message) + yield event + + def _ws_address(self) -> WebsocketUrl: + # url = WebsocketUrl.build( + # scheme="ws", host=api.host, port=api.port, path=api.path + # ) + return WebsocketUrl("ws://localhost:8000/") + # https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0 def __getattr__(name: str): diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 6761256de..edbf495d2 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -15,7 +15,8 @@ import httpx import jwt import requests -from fastapi import Depends, HTTPException, Request +from fastapi import Depends, HTTPException +from fastapi.requests import HTTPConnection from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase @@ -278,14 +279,15 @@ def sync_auth_flow(self, request): yield request -def unchecked_bearer_token(req: Request) -> str | None: +def unchecked_bearer_token(req: HTTPConnection) -> str | None: """Get bearer token value from authorization header""" # This is an abridged version of the same feature of # OAuth2AuthorizationCodeBearer from fastapi. Replicating here prevents # passing unused configuration and means the schema does not include auth # details for servers that do not support it. auth = req.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(auth) + auth_cookie = req.cookies.get("Authorization") + scheme, param = get_authorization_scheme_param(auth or auth_cookie) if scheme.casefold() != "bearer": return None return param.strip() @@ -303,7 +305,7 @@ def build_access_token_check(config: OIDCConfig): """ jwkclient = jwt.PyJWKClient(config.jwks_uri) - def validate_bearer_token(request: Request, token: UncheckedBearerToken): + def validate_bearer_token(request: HTTPConnection, token: UncheckedBearerToken): """Check that a bearer token is valid and inject into request state""" if not token: raise HTTPException( @@ -326,7 +328,7 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken): return validate_bearer_token -def access_token(request: Request) -> Mapping[str, Any] | None: +def access_token(request: HTTPConnection) -> Mapping[str, Any] | None: """Get the decoded and verified access token of the user making the request""" return getattr(request.state, "decoded_access_token", None) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 335d00477..65c59242b 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,8 @@ +import logging from collections.abc import Mapping +from dataclasses import dataclass from functools import cache +from multiprocessing.connection import Connection from typing import Any from bluesky.callbacks.tiled_writer import TiledWriter @@ -9,6 +12,7 @@ from blueapi.cli.scratch import get_python_environment from blueapi.config import ApplicationConfig, OIDCConfig, ServiceAccount, StompConfig +from blueapi.core.bluesky_types import DataEvent from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.log import set_up_logging @@ -22,14 +26,14 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob -from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask """This module provides interface between web application and underlying Bluesky context and worker""" - +LOGGER = logging.getLogger(__name__) _CONFIG: ApplicationConfig = ApplicationConfig() @@ -286,3 +290,37 @@ def get_python_env( """Retrieve information about the Python environment""" scratch = config().scratch return get_python_environment(config=scratch, name=name, source=source) + + +@dataclass +class SubHandles: + worker: int + progress: int + data: int + + +def pipe_events(tx: Connection) -> SubHandles: + tw = worker() + + def handler( + worker_event: WorkerEvent | DataEvent | ProgressEvent, + _cor_id: str | None, + ) -> None: + + try: + tx.send(worker_event) + except BrokenPipeError: + LOGGER.warning("Sending event to broken pipe") + pass + + w = tw.worker_events.subscribe(handler) + d = tw.data_events.subscribe(handler) + p = tw.progress_events.subscribe(handler) + return SubHandles(worker=w, data=d, progress=p) + + +def unpipe_events(hnd: SubHandles) -> None: + tw = worker() + tw.worker_events.unsubscribe(hnd.worker) + tw.data_events.unsubscribe(hnd.data) + tw.progress_events.unsubscribe(hnd.progress) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 432cc5455..e8896a2d8 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -14,6 +14,8 @@ HTTPException, Request, Response, + WebSocket, + WebSocketDisconnect, status, ) from fastapi.datastructures import Address @@ -32,14 +34,19 @@ from super_state_machine.errors import TransitionError from blueapi.config import ApplicationConfig, OIDCConfig, Tag +from blueapi.core.bluesky_types import DataEvent from blueapi.service import interface -from blueapi.service.authentication import Fedid, build_access_token_check +from blueapi.service.authentication import ( + Fedid, + build_access_token_check, +) from blueapi.service.middleware import ( ObservabilityContextPropagator, VersionHeaders, ) from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent +from blueapi.worker.worker_errors import WorkerBusyError from .authorization import OpaClient, validate_tiled_config from .model import ( @@ -66,6 +73,9 @@ TRACER = get_tracer("interface") +AnyEvent = WorkerEvent | DataEvent | ProgressEvent + + def _runner() -> WorkerDispatcher: """Intended to be used only with FastAPI Depends""" if RUNNER is None: @@ -109,6 +119,7 @@ async def inner(app: FastAPI): open_router = APIRouter() secure_router = APIRouter(deprecated=True) secure_router_v1 = APIRouter(prefix="/api/v1") +secure_router_v2 = APIRouter(prefix="/api/v2") def get_app(config: ApplicationConfig): @@ -130,6 +141,7 @@ def get_app(config: ApplicationConfig): } app.include_router(open_router) app.include_router(secure_router_v1, dependencies=dependencies) + app.include_router(secure_router_v2, dependencies=dependencies) app.include_router(secure_router, dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) @@ -563,6 +575,51 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) +@secure_router_v2.websocket("/run_plan") +async def run_plan( + ws: WebSocket, runner: Annotated[WorkerDispatcher, Depends(_runner)], user: Fedid +): + LOGGER.info("Starting WS plan as %s", user) + await ws.accept() + rq = await ws.receive_json() + LOGGER.info("Raw request: %s", rq) + try: + task_request: TaskRequest = TaskRequest.model_validate(rq) + LOGGER.info("Plan request: %s", task_request) + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + LOGGER.info("Task ID: %s", task_id) + except ValidationError: + LOGGER.error("Args not valid", exc_info=True) + await ws.close(code=1003, reason="invalid args") + return + except KeyError: + LOGGER.error("Plan not found", exc_info=True) + await ws.close(code=1003, reason="unknown plan") + return + + try: + with runner.event_pipe() as events: + LOGGER.info("Created event pipe") + runner.run(interface.begin_task, task=WorkerTask(task_id=task_id)) + async for evt in events: + LOGGER.debug("Event: %s", evt) + await ws.send_json(evt.model_dump(mode="json")) + if isinstance(evt, WorkerEvent) and evt.is_complete(): + LOGGER.info("End of stream") + break + except WorkerBusyError: + LOGGER.error("Worker was busy") + await ws.close(code=1013, reason="Worker busy") + except WebSocketDisconnect: + LOGGER.info("Client disconnected") + runner.run( + interface.cancel_active_task, failure=True, reason="Client disconnected" + ) + else: + LOGGER.info("Plan complete") + await ws.close() + + @start_as_current_span(TRACER, "config") def start(config: ApplicationConfig): import uvicorn diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index 2b5a5f37f..83153ed5f 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -1,10 +1,12 @@ +import asyncio import inspect import logging import signal import uuid -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from importlib import import_module from multiprocessing import Pool, set_start_method +from multiprocessing.connection import Connection, Pipe from multiprocessing.pool import Pool as PoolClass from typing import Any, ParamSpec, TypeVar @@ -18,8 +20,11 @@ from pydantic import TypeAdapter from blueapi.config import ApplicationConfig -from blueapi.service.interface import setup, teardown +from blueapi.core.bluesky_types import DataEvent +from blueapi.service import interface +from blueapi.service.interface import SubHandles, setup, teardown from blueapi.service.model import EnvironmentResponse +from blueapi.worker.event import ProgressEvent, WorkerEvent # The default multiprocessing start method is fork set_start_method("spawn", force=True) @@ -145,11 +150,57 @@ def run( kwargs, ) + def event_pipe(self): + return EventPipe(self) + @property def state(self) -> EnvironmentResponse: return self._state +class EventStream: + def __init__(self, rx: Connection): + self._rx = rx + + def __aiter__(self) -> AsyncIterator[WorkerEvent | DataEvent | ProgressEvent]: + return self + + async def __anext__(self) -> WorkerEvent | DataEvent | ProgressEvent: + data_available = asyncio.Event() + asyncio.get_event_loop().add_reader(self._rx.fileno(), data_available.set) + try: + while not self._rx.poll(): + await data_available.wait() + data_available.clear() + return self._rx.recv() + except BrokenPipeError: + raise StopAsyncIteration() from None + finally: + asyncio.get_event_loop().remove_reader(self._rx.fileno()) + + +class EventPipe: + runner: WorkerDispatcher + handles: list[tuple[SubHandles, Connection]] + + def __init__(self, runner: WorkerDispatcher): + self.runner = runner + self.handles = [] + + def __enter__(self) -> EventStream: + tx, rx = Pipe() + hnd = self.runner.run(interface.pipe_events, tx) + LOGGER.debug("Subscribing new event pipe: %s", hnd) + self.handles.append((hnd, tx)) + return EventStream(rx) + + def __exit__(self, *exc): + hnd, conn = self.handles.pop() + LOGGER.debug("Unsubscribing event pipe: %s", hnd) + conn.close() + self.runner.run(interface.unpipe_events, hnd) + + class InvalidRunnerStateError(Exception): def __init__(self, message): super().__init__(message) diff --git a/tests/unit_tests/cli/test_cli.py b/tests/unit_tests/cli/test_cli.py index 3e27cfc98..fe71825c1 100644 --- a/tests/unit_tests/cli/test_cli.py +++ b/tests/unit_tests/cli/test_cli.py @@ -7,7 +7,6 @@ from pathlib import Path from textwrap import dedent from typing import Any, TypeVar -from unittest import mock from unittest.mock import Mock, patch import pytest @@ -385,9 +384,9 @@ def test_run_plan_feedback( main, ["controller", "run", "-i", "cm12345-1", "name"], ) + bc.add_callback.assert_called_once() bc.run_task.assert_called_once_with( TaskRequest(name="name", params={}, instrument_session="cm12345-1"), - on_event=mock.ANY, ) assert res.exit_code == 0 assert res.stdout == message diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 01bc426e2..e3ee86f9d 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -189,18 +189,25 @@ def test_tiled_auth_sync_auth_flow(): @pytest.mark.parametrize( - "header,token", + "header,cookie,token", [ - (None, None), - ("ApiKey foobar", None), - ("Bearer foobar", "foobar"), - ("Bearer with_whitespace ", "with_whitespace"), - ("Bearerfoobar", None), + (None, None, None), + ("", None, None), + ("ApiKey foobar", None, None), + ("Bearer foobar", None, "foobar"), + ("Bearer with_whitespace ", None, "with_whitespace"), + ("Bearerfoobar", None, None), + (None, "Bearer foobar", "foobar"), + ("", "Bearer foo", "foo"), + ("Bearer foo", "bearer bar", "foo"), ], ) -def test_unchecked_bearer_token(header: str | None, token: str | None): +def test_unchecked_bearer_token( + header: str | None, cookie: str | None, token: str | None +): req = Mock() req.headers.get.side_effect = lambda key: header if key == "Authorization" else None + req.cookies.get.side_effect = lambda key: cookie if key == "Authorization" else None assert unchecked_bearer_token(req) == token