diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 20de15892..10ddaf00b 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,11 +1,10 @@ import itertools import logging import time -from collections.abc import Iterable +from collections.abc import Iterable, KeysView from concurrent.futures import Future from contextlib import suppress from functools import cached_property -from itertools import chain from pathlib import Path from typing import Any, Self @@ -56,6 +55,35 @@ log = logging.getLogger(__name__) +_REPR_MAX_LENGTH = 100 +_REPR_MAX_ARGS_INLINE = 3 + + +def _pretty_type(schema: dict[str, Any]) -> str: + if "$ref" in schema: + return schema["$ref"].split("/")[-1] + + if schema.get("type") == "array": + item_schema = schema.get("items", {}) + inner = _pretty_type(item_schema) + return f"list[{inner}]" + + if "anyOf" in schema: + return " | ".join(_pretty_type(s) for s in schema["anyOf"]) + + json_type = schema.get("type") + type_map = { + "string": "str", + "integer": "int", + "boolean": "bool", + "number": "float", + "object": "dict", + } + if isinstance(json_type, str): + return type_map.get(json_type, json_type.split(".")[-1]) + + return "Any" + class MissingInstrumentSessionError(Exception): pass @@ -164,7 +192,7 @@ def help_text(self) -> str: return self.model.description or f"Plan {self!r}" @property - def properties(self) -> set[str]: + def properties(self) -> KeysView[str]: return self.model.parameter_schema.get("properties", {}).keys() @property @@ -201,10 +229,39 @@ def _build_args(self, *args, **kwargs): raise TypeError(f"Missing argument(s) for {missing}") return params - def __repr__(self): - opts = [p for p in self.properties if p not in self.required] - params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) - return f"{self.name}({params})" + def __repr__(self) -> str: + def _format_arg(name: str, info: dict[str, Any], required: set[str]) -> str: + typ = _pretty_type(info) + + is_required = name in required + has_default = "default" in info + default = info.get("default") + + if is_required: + return f"{name}: {typ}" + + # optional with explicit default + if has_default: + if default is None: + return f"{name}: {typ} | None = None" + return f"{name}: {typ} = {repr(default)}" + + # optional with no default + return f"{name}: {typ} | None = None" + + props = self.model.parameter_schema.get("properties", {}) + args = [ + _format_arg(name, info, set(self.required)) for name, info in props.items() + ] + single_line = f"{self.name}({', '.join(args)})" + + if len(single_line) <= _REPR_MAX_LENGTH and len(args) <= _REPR_MAX_ARGS_INLINE: + return single_line + + indent = " " + # Fall back to multiline if too many arguments or too long. + multiline_args = ",\n".join(f"{indent}{arg}" for arg in args) + return f"{self.name}(\n{multiline_args}\n)" class BlueapiClient: diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 2e5d67de2..37c481cb6 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -5,7 +5,7 @@ from importlib import import_module, metadata from inspect import Parameter, isclass, signature from types import ModuleType, NoneType, UnionType -from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints from bluesky.protocols import HasName from bluesky.run_engine import RunEngine @@ -459,14 +459,16 @@ def _type_spec_for_function( ): default_factory = self._composite_factory(arg_type) _type = SkipJsonSchema[self._convert_type(arg_type, no_default)] + field_info = FieldInfo(default_factory=default_factory) else: - default_factory = DefaultFactory(para.default) _type = self._convert_type(arg_type, no_default) - factory = None if no_default else default_factory - new_args[name] = ( - _type, - FieldInfo(default_factory=factory), - ) + if no_default: + field_info = FieldInfo() + else: + field_info = FieldInfo(default=para.default) + + new_args[name] = (_type, field_info) + return new_args def _convert_type(self, typ: Any, no_default: bool = True) -> type: @@ -517,19 +519,3 @@ def _inject_composite(): return composite_class(**devices) return _inject_composite - - -D = TypeVar("D") - - -class DefaultFactory(Generic[D]): - _value: D - - def __init__(self, value: D): - self._value = value - - def __call__(self) -> D: - return self._value - - def __eq__(self, other) -> bool: - return other.__class__ == self.__class__ and self._value == other._value diff --git a/tests/system_tests/plans.json b/tests/system_tests/plans.json index 0124b8657..e55e3b474 100644 --- a/tests/system_tests/plans.json +++ b/tests/system_tests/plans.json @@ -20,7 +20,8 @@ }, "num": { "title": "Num", - "type": "integer" + "type": "integer", + "default": 1 }, "delay": { "anyOf": [ @@ -34,12 +35,14 @@ "type": "array" } ], + "default": 0.0, "title": "Delay" }, "metadata": { "additionalProperties": true, "title": "Metadata", - "type": "object" + "type": "object", + "default": null } }, "required": [ @@ -681,7 +684,8 @@ "metadata": { "additionalProperties": true, "title": "Metadata", - "type": "object" + "type": "object", + "default": null } }, "required": [ @@ -711,11 +715,13 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "wait": { "title": "Wait", - "type": "boolean" + "type": "boolean", + "default": false } }, "required": [ @@ -745,11 +751,13 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "wait": { "title": "Wait", - "type": "boolean" + "type": "boolean", + "default": false } }, "required": [ @@ -773,7 +781,8 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null } }, "required": [ @@ -796,7 +805,8 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null } }, "required": [ @@ -832,11 +842,13 @@ "properties": { "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "timeout": { "title": "Timeout", - "type": "number" + "type": "number", + "default": null } }, "title": "wait", diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index eaf96a3b2..b22e1d960 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -806,7 +806,72 @@ def test_plan_fallback_help_text(client): ), client, ) - assert plan.help_text == "Plan foo(one, two=None)" + assert plan.help_text == "Plan foo(one: Any, two: Any | None = None)" + + +def test_plan_multi_parameter_fallback_help_text(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={ + "properties": { + "one": {}, + "two": { + "anyOf": [{"items": {}, "type": "array"}, {"type": "boolean"}], + }, + "three": {"default": 3}, + "four": {"default": None}, + }, + "required": ["one", "two"], + }, + ), + client, + ) + assert ( + plan.help_text == "Plan foo(\n" + " one: Any,\n" + " two: list[Any] | bool,\n" + " three: Any = 3,\n" + " four: Any | None = None\n" + ")" + ) + + +def test_plan_help_text_with_ref(client): + schema = { + "$defs": { + "Spec": { + "properties": { + "foo": {"type": "integer"}, + "bar": {"$ref": "#/$defs/InnerSpec"}, + }, + "required": ["foo", "bar"], + }, + "InnerSpec": { + "properties": { + "x": {"type": "number"}, + "y": {"default": 10, "type": "number"}, + }, + "required": ["x"], + }, + }, + "properties": { + "spec": {"$ref": "#/$defs/Spec"}, + "meta": {"type": "string", "default": "abc"}, + }, + "required": ["spec"], + } + + plan = Plan( + "ref_plan", + PlanModel(name="ref_plan", schema=schema), + client, + ) + + expected = "Plan ref_plan(spec: Spec, meta: str = 'abc')" + + assert plan.help_text == expected def test_plan_properties(client): diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index 23462a148..8cb7d953f 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -1,9 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field +from inspect import Parameter from pathlib import Path from types import ModuleType, NoneType -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar, Union, get_args, get_type_hints from unittest.mock import MagicMock, Mock, patch import pytest @@ -44,7 +45,7 @@ TiledConfig, ) from blueapi.core import BlueskyContext, is_bluesky_compatible_device -from blueapi.core.context import DefaultFactory, generic_bounds, qualified_name +from blueapi.core.context import generic_bounds, qualified_name from blueapi.core.protocols import DeviceConnectResult, DeviceManager from blueapi.utils.invalid_config_error import InvalidConfigError @@ -85,6 +86,10 @@ def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator: yield from () +def has_optional_parameter(foo: dict[str, Any] | None = None) -> MsgGenerator: + yield from () + + def has_typeless_param(foo) -> MsgGenerator: yield from () @@ -166,7 +171,9 @@ def some_configurable() -> SomeConfigurable: return SomeConfigurable() -@pytest.mark.parametrize("plan", [has_no_params, has_one_param, has_some_params]) +@pytest.mark.parametrize( + "plan", [has_no_params, has_one_param, has_some_params, has_optional_parameter] +) def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator): empty_context.register_plan(plan) assert plan.__name__ in empty_context.plans @@ -353,12 +360,23 @@ def test_add_metadata_with_config( assert md in empty_context.run_engine.md.items() -def test_function_spec(empty_context: BlueskyContext): +def test_function_spec_with_some_params(empty_context: BlueskyContext): spec = empty_context._type_spec_for_function(has_some_params) assert spec["foo"][0] is int - assert spec["foo"][1].default_factory == DefaultFactory(42) + assert spec["foo"][1].default == 42 assert spec["bar"][0] is str - assert spec["bar"][1].default_factory == DefaultFactory("bar") + assert spec["bar"][1].default == "bar" + + +def test_function_spec_with_optional_params(empty_context: BlueskyContext): + spec = empty_context._type_spec_for_function(has_optional_parameter) + types = get_type_hints(has_optional_parameter) + arg_type = types.get("foo", Parameter.empty) + + _type = SkipJsonSchema[empty_context._convert_type(arg_type, False)] + inner_type, *annotations = get_args(_type) + assert spec["foo"][0] == inner_type + assert spec["foo"][1].default is None def test_basic_type_conversion(empty_context: BlueskyContext): @@ -439,7 +457,7 @@ def default_movable(mov: Movable = inject("demo")) -> MsgGenerator: spec = empty_context._type_spec_for_function(default_movable) movable_ref = empty_context._reference(Movable) assert spec["mov"][0] == movable_ref - assert spec["mov"][1].default_factory == DefaultFactory("demo") + assert spec["mov"][1].default == "demo" def test_generic_default_device_reference(empty_context: BlueskyContext): @@ -449,7 +467,7 @@ def default_movable(mov: Movable[float] = inject("demo")) -> MsgGenerator: spec = empty_context._type_spec_for_function(default_movable) motor_ref = empty_context._reference(Movable[float]) assert spec["mov"][0] == motor_ref - assert spec["mov"][1].default_factory == DefaultFactory("demo") + assert spec["mov"][1].default == "demo" class ConcreteStoppable(Stoppable): @@ -499,7 +517,7 @@ def test_str_default(empty_context: BlueskyContext, sim_motor: Motor, alt_motor: spec = empty_context._type_spec_for_function(has_default_reference) assert spec["m"][0] is movable_ref - assert (df := spec["m"][1].default_factory) and df() == SIM_MOTOR_NAME # type: ignore + assert spec["m"][1].default == SIM_MOTOR_NAME assert has_default_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_reference.__name__].model @@ -518,7 +536,7 @@ def test_nested_str_default( spec = empty_context._type_spec_for_function(has_default_nested_reference) assert spec["m"][0] == list[movable_ref] - assert (df := spec["m"][1].default_factory) and df() == [SIM_MOTOR_NAME] # type: ignore + assert spec["m"][1].default == [SIM_MOTOR_NAME] assert has_default_nested_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_nested_reference.__name__].model @@ -622,7 +640,7 @@ def demo_plan(foo: int | None = None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "type": "integer"}, + "foo": {"title": "Foo", "type": "integer", "default": None}, } assert "foo" not in schema.get("required", []) @@ -650,7 +668,11 @@ def demo_plan(foo: int | str | None = None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "string"}]} + "foo": { + "title": "Foo", + "anyOf": [{"type": "integer"}, {"type": "string"}], + "default": None, + } } assert "foo" not in schema.get("required", []) @@ -664,7 +686,10 @@ def demo_plan(foo: int | None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "null"}]} + "foo": { + "title": "Foo", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } } assert "foo" in schema.get("required", [])