Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def on_event(
@controller.command(name="run")
@click.argument("name", type=str)
@click.argument("parameters", type=ParametersType(), default={}, required=False)
@click.option("--ws", type=bool, is_flag=True, default=False)
@click.option(
"--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True
)
Expand Down Expand Up @@ -348,6 +349,7 @@ def run_plan(
name: str,
timeout: float | None,
foreground: bool,
ws: bool,
instrument_session: str,
parameters: TaskParameters,
) -> None:
Expand All @@ -374,7 +376,13 @@ def on_event(event: AnyEvent) -> None:
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

resp = client.run_task(task, on_event=on_event)
client.add_callback(on_event)

if ws:
resp = client.run_blocking(task)
else:
resp = client.run_task(task)

match resp.result:
case TaskResult(result=None, type="NoneType"):
print("Plan succeeded")
Expand Down
20 changes: 20 additions & 0 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,26 @@ def get_active_task(self) -> WorkerTask:

return self.active_task

@start_as_current_span(TRACER, "request")
def run_blocking(
self, request: TaskRequest, on_event: OnAnyEvent | None = None
) -> TaskStatus:
for event in self._rest.run_blocking(request):
if on_event is not None:
on_event(event)
for cb in self._callbacks.values():
try:
cb(event)
except Exception as e:
log.error(f"Callback ({cb}) failed for event: {event}", exc_info=e)
if isinstance(event, WorkerEvent) and event.is_complete():
if event.task_status is None:
raise BlueskyRemoteControlError(
"Server completed without task status"
)
return event.task_status
raise BlueskyRemoteControlError("Connection closed before plan completed.")

@start_as_current_span(TRACER, "task", "timeout")
def run_task(
self,
Expand Down
33 changes: 30 additions & 3 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
get_tracer,
start_as_current_span,
)
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError, WebsocketUrl
from pydantic_core import PydanticSerializationError
from websockets.sync.client import connect

from blueapi import __version__
from blueapi.client import client
from blueapi.client.event_bus import AnyEvent
from blueapi.config import RestConfig
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
Expand All @@ -39,6 +41,8 @@

LOGGER = logging.getLogger(__name__)

USER_AGENT = f"blueapi cli {__version__}"


class BlueskyRequestError(Exception):
"""An error response from the blueapi server."""
Expand Down Expand Up @@ -307,14 +311,15 @@ def _request_and_deserialize(
) -> T:
url = self._config.url.unicode_string().removesuffix("/") + suffix
# Get the trace context to propagate to the REST API
carr = get_context_propagator()
headers = get_context_propagator()
headers["User-Agent"] = USER_AGENT
try:
response = self._pool.request(
method,
url,
json=data,
params=params,
headers=carr,
headers=headers,
auth=JWTAuth(self._session_manager),
)
except requests.exceptions.ConnectionError as ce:
Expand All @@ -340,6 +345,28 @@ def _request_and_deserialize(
)
return deserialized

def run_blocking(self, req: TaskRequest):
url = self._ws_address().unicode_string().removesuffix("/") + "/run_plan"
headers = get_context_propagator()
if self._session_manager:
auth = self._session_manager.get_valid_access_token()
headers["Authorization"] = f"Bearer {auth}"
with connect(
url,
additional_headers=headers,
user_agent_header=USER_AGENT,
) as ws:
ws.send(req.model_dump_json())
for message in ws:
event = TypeAdapter(AnyEvent).validate_json(message)
yield event

def _ws_address(self) -> WebsocketUrl:
# url = WebsocketUrl.build(
# scheme="ws", host=api.host, port=api.port, path=api.path
# )
return WebsocketUrl("ws://localhost:8000/")


# https://github.com/DiamondLightSource/blueapi/issues/1256 - remove before 2.0
def __getattr__(name: str):
Expand Down
12 changes: 7 additions & 5 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import httpx
import jwt
import requests
from fastapi import Depends, HTTPException, Request
from fastapi import Depends, HTTPException
from fastapi.requests import HTTPConnection
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import TypeAdapter
from requests.auth import AuthBase
Expand Down Expand Up @@ -278,14 +279,15 @@ def sync_auth_flow(self, request):
yield request


def unchecked_bearer_token(req: Request) -> str | None:
def unchecked_bearer_token(req: HTTPConnection) -> str | None:
"""Get bearer token value from authorization header"""
# This is an abridged version of the same feature of
# OAuth2AuthorizationCodeBearer from fastapi. Replicating here prevents
# passing unused configuration and means the schema does not include auth
# details for servers that do not support it.
auth = req.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(auth)
auth_cookie = req.cookies.get("Authorization")
scheme, param = get_authorization_scheme_param(auth or auth_cookie)
if scheme.casefold() != "bearer":
return None
return param.strip()
Expand All @@ -303,7 +305,7 @@ def build_access_token_check(config: OIDCConfig):
"""
jwkclient = jwt.PyJWKClient(config.jwks_uri)

def validate_bearer_token(request: Request, token: UncheckedBearerToken):
def validate_bearer_token(request: HTTPConnection, token: UncheckedBearerToken):
"""Check that a bearer token is valid and inject into request state"""
if not token:
raise HTTPException(
Expand All @@ -326,7 +328,7 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken):
return validate_bearer_token


def access_token(request: Request) -> Mapping[str, Any] | None:
def access_token(request: HTTPConnection) -> Mapping[str, Any] | None:
"""Get the decoded and verified access token of the user making the request"""
return getattr(request.state, "decoded_access_token", None)

Expand Down
42 changes: 40 additions & 2 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from collections.abc import Mapping
from dataclasses import dataclass
from functools import cache
from multiprocessing.connection import Connection
from typing import Any

from bluesky.callbacks.tiled_writer import TiledWriter
Expand All @@ -9,6 +12,7 @@

from blueapi.cli.scratch import get_python_environment
from blueapi.config import ApplicationConfig, OIDCConfig, ServiceAccount, StompConfig
from blueapi.core.bluesky_types import DataEvent
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.log import set_up_logging
Expand All @@ -22,14 +26,14 @@
WorkerTask,
)
from blueapi.utils.serialization import access_blob
from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask

"""This module provides interface between web application and underlying Bluesky
context and worker"""


LOGGER = logging.getLogger(__name__)
_CONFIG: ApplicationConfig = ApplicationConfig()


Expand Down Expand Up @@ -286,3 +290,37 @@ def get_python_env(
"""Retrieve information about the Python environment"""
scratch = config().scratch
return get_python_environment(config=scratch, name=name, source=source)


@dataclass
class SubHandles:
worker: int
progress: int
data: int


def pipe_events(tx: Connection) -> SubHandles:
tw = worker()

def handler(
worker_event: WorkerEvent | DataEvent | ProgressEvent,
_cor_id: str | None,
) -> None:

try:
tx.send(worker_event)
except BrokenPipeError:
LOGGER.warning("Sending event to broken pipe")
pass
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
tpoliaw marked this conversation as resolved.
Dismissed

w = tw.worker_events.subscribe(handler)
d = tw.data_events.subscribe(handler)
p = tw.progress_events.subscribe(handler)
return SubHandles(worker=w, data=d, progress=p)


def unpipe_events(hnd: SubHandles):
tw = worker()
tw.worker_events.unsubscribe(hnd.worker)
tw.data_events.unsubscribe(hnd.data)
tw.progress_events.unsubscribe(hnd.progress)
59 changes: 57 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
HTTPException,
Request,
Response,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi.datastructures import Address
Expand All @@ -32,14 +34,19 @@
from super_state_machine.errors import TransitionError

from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.core.bluesky_types import DataEvent
from blueapi.service import interface
from blueapi.service.authentication import Fedid, build_access_token_check
from blueapi.service.authentication import (
Fedid,
build_access_token_check,
)
from blueapi.service.middleware import (
ObservabilityContextPropagator,
VersionHeaders,
)
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum
from blueapi.worker.event import ProgressEvent, TaskStatusEnum, WorkerEvent
from blueapi.worker.worker_errors import WorkerBusyError

from .authorization import OpaClient, validate_tiled_config
from .model import (
Expand All @@ -66,6 +73,9 @@
TRACER = get_tracer("interface")


AnyEvent = WorkerEvent | DataEvent | ProgressEvent


def _runner() -> WorkerDispatcher:
"""Intended to be used only with FastAPI Depends"""
if RUNNER is None:
Expand Down Expand Up @@ -563,6 +573,51 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response:
)


@secure_router_v1.websocket("/run_plan")
async def run_plan(
ws: WebSocket, runner: Annotated[WorkerDispatcher, Depends(_runner)], user: Fedid
):
LOGGER.info("Starting WS plan as %s", user)
await ws.accept()
rq = await ws.receive_json()
LOGGER.info("Raw request: %s", rq)
try:
task_request: TaskRequest = TaskRequest.model_validate(rq)
LOGGER.info("Plan request: %s", task_request)
task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
LOGGER.info("Task ID: %s", task_id)
except ValidationError:
LOGGER.error("Args not valid", exc_info=True)
await ws.close(code=1003, reason="invalid args")
return
except KeyError:
LOGGER.error("Plan not found", exc_info=True)
await ws.close(code=1003, reason="unknown plan")
return

try:
with runner.event_pipe() as events:
LOGGER.info("Created event pipe")
runner.run(interface.begin_task, task=WorkerTask(task_id=task_id))
async for evt in events:
LOGGER.debug("Event: %s", evt)
await ws.send_json(evt.model_dump(mode="json"))
if isinstance(evt, WorkerEvent) and evt.is_complete():
LOGGER.info("End of stream")
break
except WorkerBusyError:
LOGGER.error("Worker was busy")
await ws.close(code=1013, reason="Worker busy")
except WebSocketDisconnect:
LOGGER.info("Client disconnected")
runner.run(
interface.cancel_active_task, failure=True, reason="Client disconnected"
)
else:
LOGGER.info("Plan complete")
await ws.close()


@start_as_current_span(TRACER, "config")
def start(config: ApplicationConfig):
import uvicorn
Expand Down
Loading
Loading