From 93cb2538351184aba5f93245083cbce5a308fb1a Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:00:23 -0700 Subject: [PATCH 01/15] Updated Cerner and FHIR connection --- playground/scenario_post_visit.py | 29 ++--- src/agents/fhir_cerner_mcp.py | 8 +- src/agents/fhir_epic_mcp.py | 8 +- src/agents/mcp_entrypoint.py | 52 ++++----- src/bindings/factory.py | 18 +-- src/bindings/mcp_server/server.py | 6 +- src/bindings/rest_api/app.py | 10 +- src/connectors/fhir_cerner/logic.py | 158 +++++++++------------------ src/connectors/fhir_cerner/schema.py | 114 ++++++++++--------- src/connectors/fhir_epic/logic.py | 114 ++++--------------- src/connectors/fhir_epic/schema.py | 123 ++++++++++----------- src/connectors/manifest.py | 47 ++++++-- tests/test_fhir_cerner.py | 137 ++++++++++++++--------- tests/test_fhir_epic.py | 102 +++++++++-------- tests/test_toolhive_agent.py | 10 +- 15 files changed, 440 insertions(+), 496 deletions(-) diff --git a/playground/scenario_post_visit.py b/playground/scenario_post_visit.py index 9581c66..f47b681 100644 --- a/playground/scenario_post_visit.py +++ b/playground/scenario_post_visit.py @@ -61,10 +61,9 @@ async def run_scenario(): logger.info(f"Searching for patient: {patient_search_params}") try: - patient_action = connector.get_action("read_patient") - patient_result = await patient_action.internal_execute( - FhirPatientReadInput(search_params=patient_search_params), - trace_id=trace_id + patient_result = await connector.internal_execute( + FhirPatientReadInput(action="read_patient", search_params=patient_search_params), + trace_id=trace_id, ) patient_id = patient_result.resource.get("id") logger.info(f"Found Patient ID: {patient_id}") @@ -82,17 +81,19 @@ async def run_scenario(): logger.info(f"Finding encounter for patient {patient_id} on {today}") try: - encounter_action = connector.get_action("search_encounter") - enc_result = await encounter_action.internal_execute( 
- FhirEncounterSearchInput(search_params=encounter_params), - trace_id=trace_id + enc_result = await connector.internal_execute( + FhirEncounterSearchInput(action="search_encounter", search_params=encounter_params), + trace_id=trace_id, ) - + if not enc_result.resources: logger.warning("No encounters found for this patient today. Falling back to most recent.") - enc_result = await encounter_action.internal_execute( - FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), - trace_id=trace_id + enc_result = await connector.internal_execute( + FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": patient_id, "status": "finished"}, + ), + trace_id=trace_id, ) if not enc_result.resources: @@ -110,6 +111,7 @@ async def run_scenario(): encoded_note = base64.b64encode(note_content.encode('utf-8')).decode('utf-8') doc_input = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": f"NOTE-{datetime.now().timestamp()}"}], status="current", type={"coding": [{"system": "http://loinc.org", "code": "11506-3", "display": "Progress Note"}]}, @@ -125,8 +127,7 @@ async def run_scenario(): logger.info(f"Uploading clinical note for Encounter {encounter_id}") try: - doc_action = connector.get_action("create_document_reference") - doc_result = await doc_action.internal_execute(doc_input, trace_id=trace_id) + doc_result = await connector.internal_execute(doc_input, trace_id=trace_id) logger.info(f"SUCCESS! Created DocumentReference: {doc_result.resource_id}") print(f"\nWorkflow Complete. 
Resource Created: {doc_result.resource_id}") except Exception as e: diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index 5628bd6..d329170 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -55,10 +55,8 @@ async def fhir_cerner_read_patient( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("read_patient") - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) + params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name: search = { k: v @@ -69,11 +67,11 @@ async def fhir_cerner_read_patient( }.items() if v } - params = FhirCernerPatientReadInput(search_params=search) + params = FhirCernerPatientReadInput(action="read_patient", search_params=search) else: raise ValueError("Provide patient_id OR at least family_name/given_name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index d7f6335..b196b7a 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -56,10 +56,8 @@ async def fhir_epic_read_patient( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("read_patient") - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) + params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name: search = { k: v @@ -70,11 +68,11 @@ async def fhir_epic_read_patient( }.items() if v } - params = FhirEpicPatientReadInput(search_params=search) + params = FhirEpicPatientReadInput(action="read_patient", search_params=search) else: raise ValueError("Provide patient_id OR at least family_name/given_name") - result = await 
action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index dee264e..9d974eb 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -4,12 +4,14 @@ This module is the main entrypoint for the Node Wire MCP server. When run, it exposes healthcare workflow tools via the MCP stdio transport: - • fhir_cerner_read_patient — fetch a single patient from Cerner FHIR R4 - • fhir_cerner_search_patients — fetch multiple patients from Cerner (multi-ID or name) - • fhir_epic_read_patient — fetch a single patient from Epic FHIR R4 - • fhir_epic_search_patients — fetch multiple patients from Epic (multi-ID or name) - • google_drive_upload_file — write a file to Google Drive - • smtp_send_email — send an email via SMTP + • fhir_cerner_read_patient — fetch a patient from Cerner FHIR R4 + • fhir_cerner_search_patients — search multiple patients in Cerner + • fhir_cerner_search_encounters — search encounters in Cerner + • fhir_epic_read_patient — fetch a patient from Epic FHIR R4 + • fhir_epic_search_patients — search multiple patients in Epic + • fhir_epic_search_encounters — search encounters in Epic + • google_drive_upload_file — write a file to Google Drive + • smtp_send_email — send an email via SMTP ToolHive manages the container lifecycle, injects secrets as environment variables, and proxies the stdio MCP stream to HTTP/SSE for clients. 
@@ -107,12 +109,11 @@ async def fhir_cerner_read_patient( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("read_patient") - if patient_id: - params = FhirCernerPatientReadInput(resource_id=patient_id) + params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name or name: params = FhirCernerPatientReadInput( + action="read_patient", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -121,7 +122,7 @@ async def fhir_cerner_read_patient( else: raise ValueError("Provide patient_id OR at least family_name / given_name / name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource # Extract a clean summary for the LLM @@ -184,12 +185,11 @@ async def fhir_epic_read_patient( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("read_patient") - if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) + params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) elif family_name or given_name or name: params = FhirEpicPatientReadInput( + action="read_patient", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -198,7 +198,7 @@ async def fhir_epic_read_patient( else: raise ValueError("Provide patient_id OR at least family_name / given_name / name") - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource # Clean extract for LLM @@ -258,13 +258,12 @@ async def fhir_cerner_search_patients( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("search_patients") - if patient_ids.strip(): ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = 
FhirCernerPatientSearchInput(resource_ids=ids) + params = FhirCernerPatientSearchInput(action="search_patients", resource_ids=ids) elif family_name or given_name or name or birthdate: params = FhirCernerPatientSearchInput( + action="search_patients", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -276,7 +275,7 @@ async def fhir_cerner_search_patients( "family_name / given_name / name / birthdate" ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: @@ -337,13 +336,12 @@ async def fhir_epic_search_patients( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("search_patients") - if patient_ids.strip(): ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirEpicPatientSearchInput(resource_ids=ids) + params = FhirEpicPatientSearchInput(action="search_patients", resource_ids=ids) elif family_name or given_name or name or birthdate: params = FhirEpicPatientSearchInput( + action="search_patients", given_name=given_name or None, family_name=family_name or None, name=name or None, @@ -355,7 +353,7 @@ async def fhir_epic_search_patients( "family_name / given_name / name / birthdate" ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: @@ -408,15 +406,14 @@ async def fhir_cerner_search_encounters( if not cerner: raise RuntimeError("fhir_cerner connector not configured") - action = cerner.get_action("search_encounter") - params = FhirCernerEncounterSearchInput( + action="search_encounter", patient_id=patient_id or None, status=status or None, date=date or None, ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await cerner.internal_execute(params, trace_id=trace_id) summaries = 
[] for resource in result.resources: @@ -466,15 +463,14 @@ async def fhir_epic_search_encounters( if not epic: raise RuntimeError("fhir_epic connector not configured") - action = epic.get_action("search_encounter") - params = FhirEpicEncounterSearchInput( + action="search_encounter", patient_id=patient_id or None, status=status or None, date=date or None, ) - result = await action.internal_execute(params, trace_id=trace_id) + result = await epic.internal_execute(params, trace_id=trace_id) summaries = [] for resource in result.resources: diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 8a28256..76e4df8 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -133,7 +133,9 @@ def _instantiate(self, connector_id: str) -> Any: raise ValueError(f"Unknown connector id {connector_id!r}") - def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[str] = None) -> Optional[BaseConnector[Any, Any]]: + def get_for_protocol( + self, connector_id: str, protocol: str, action: Optional[str] = None + ) -> Optional[BaseConnector[Any, Any]]: cfg = self._configs.get(connector_id) if cfg is None: logger.warning( @@ -160,9 +162,11 @@ def get_for_protocol(self, connector_id: str, protocol: str, action: Optional[st if connector is None: return None - # Multi-action connectors (e.g. fhir_epic) expose a get_action() helper. 
- if action and hasattr(connector, "get_action"): - return connector.get_action(action) + if action: + logger.debug( + "get_for_protocol resolved connector (action from URL is merged into payload by REST)", + extra={"connector_id": connector_id, "protocol": protocol, "action": action}, + ) return connector # type: ignore[return-value] @@ -170,9 +174,5 @@ def list_for_protocol(self, protocol: str) -> List[BaseConnector[Any, Any]]: result: List[BaseConnector[Any, Any]] = [] for connector_id, connector in self._connectors.items(): if protocol in self._configs[connector_id].exposed_via: - # Multi-action connectors expose all their actions via list_actions(). - if hasattr(connector, "list_actions"): - result.extend(connector.list_actions()) - else: - result.append(connector) # type: ignore[arg-type] + result.append(connector) # type: ignore[arg-type] return result diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index ce98707..6b9f0a2 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -50,7 +50,11 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A if connector is None: raise ValueError(f"Connector {connector_id!r} is not available via MCP.") - response = await connector.run(arguments) + run_args = dict(arguments) + if connector_id in ("fhir_cerner", "fhir_epic"): + run_args.setdefault("action", action) + + response = await connector.run(run_args) return response.model_dump() diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index 7d27dcc..f442734 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -73,7 +73,10 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: return 503 return 500 -def _make_endpoint(cid: str, act: str) -> Any: +_FHIR_REST_IDS = frozenset({"fhir_cerner", "fhir_epic"}) + + +def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( payload: Dict[str, Any], factory_dep: 
ConnectorFactory = Depends(get_factory), @@ -89,9 +92,12 @@ async def endpoint( connector = factory_dep.get_for_protocol(cid, "rest", action=act) if connector is None: raise HTTPException(status_code=404, detail="Connector not available for REST") + run_payload = dict(payload) + if cid in _FHIR_REST_IDS: + run_payload.setdefault("action", act) # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. - response: ConnectorResponse = await connector.run(payload) + response: ConnectorResponse = await connector.run(run_payload) status = _http_status_for_category(response.error_category) if not response.success: diff --git a/src/connectors/fhir_cerner/logic.py b/src/connectors/fhir_cerner/logic.py index 03cc6b0..c05281c 100644 --- a/src/connectors/fhir_cerner/logic.py +++ b/src/connectors/fhir_cerner/logic.py @@ -6,7 +6,7 @@ import logging import uuid from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import httpx import jwt @@ -21,6 +21,8 @@ FhirCernerDocumentReferenceSearchOutput, FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, + FhirCernerOperationInput, + FhirCernerOperationOutput, FhirCernerPatientReadInput, FhirCernerPatientReadOutput, FhirCernerPatientSearchInput, @@ -30,60 +32,14 @@ logger = logging.getLogger("connectors.fhir_cerner") -class _FhirCernerAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirCernerConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. 
- """ - - connector_id = "fhir_cerner" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) - - -class FhirCernerConnector: +class FhirCernerConnector(BaseConnector[FhirCernerOperationInput, FhirCernerOperationOutput]): """ Single FHIR/Cerner connector. - ``connector_id = "fhir_cerner"``. All authentication helpers and action - implementations live here. The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - Authentication uses Cerner's SMART Backend Services (private_key_jwt) flow, identical to Epic's implementation — RS384-signed JWT exchanged for an OAuth2 access token at the configured token endpoint. - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict. - - .. note:: - Cerner's sandbox name search is case-sensitive. Supply names exactly - as stored in the system. Special characters in search values should be - URL-encoded (httpx handles this automatically). 
- Required secrets (configured via SecretProvider): - cerner_fhir_base_url : Cerner FHIR R4 base URL - cerner_private_key : RSA private key PEM (newlines may be escaped) @@ -94,44 +50,26 @@ class FhirCernerConnector: """ connector_id = "fhir_cerner" + action = "execute" def __init__(self, *, secret_provider: SecretProvider) -> None: + super().__init__(FhirCernerOperationInput, FhirCernerOperationOutput, secret_provider=secret_provider) self._secret_provider = secret_provider - self._actions: Dict[str, _FhirCernerAction] = { - "read_patient": _FhirCernerAction( - "read_patient", FhirCernerPatientReadInput, FhirCernerPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirCernerAction( - "search_patients", FhirCernerPatientSearchInput, FhirCernerPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirCernerAction( - "search_encounter", FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirCernerAction( - "create_document_reference", FhirCernerDocumentReferenceCreateInput, FhirCernerDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirCernerAction( - "search_document_reference", FhirCernerDocumentReferenceSearchInput, FhirCernerDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # ------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ - - def list_actions(self) -> List[_FhirCernerAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) - - def get_action(self, name: str) -> Optional[_FhirCernerAction]: - """Return 
the action connector for the given action name.""" - return self._actions.get(name) + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + # Back-compat: allow calling with either the RootModel union or a concrete action input model. + op = params.root if hasattr(params, "root") else params + if op.action == "read_patient": + return await self._read_patient(op, trace_id=trace_id) + if op.action == "search_patients": + return await self._search_patients(op, trace_id=trace_id) + if op.action == "search_encounter": + return await self._search_encounter(op, trace_id=trace_id) + if op.action == "create_document_reference": + return await self._create_document_reference(op, trace_id=trace_id) + if op.action == "search_document_reference": + return await self._search_document_reference(op, trace_id=trace_id) + raise ValueError(f"Unsupported action: {op.action!r}") # ------------------------------------------------------------------ # Shared authentication helpers @@ -251,15 +189,7 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority. - - .. note:: - Cerner's sandbox name matching is case-sensitive — supply names - with the same capitalisation as stored in the system. 
- """ + """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) if given_name and given_name.strip(): @@ -443,7 +373,7 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ raise data = response.json() - resources = [] + resources: List[Dict[str, Any]] = [] total = data.get("total") if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] @@ -508,19 +438,13 @@ async def _create_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - # Validate context early so callers get the most actionable error. - if params.context: - ctx = dict(params.context) - if ctx.get("encounter") and not ctx.get("period"): - raise ValueError("Cerner requires 'context.period' when 'context.encounter' is provided.") - # Cerner sandbox strictly requires a charset (lowercase, no space) for text types. # Failing to provide it results in: "a character set must be specified" (422). content_type = (params.content_type or "text/plain").strip().lower() if content_type.startswith("text/"): + content_type = content_type.replace(" ", "") if "charset=" not in content_type: - # Match the formatting expected by tests and common HTTP conventions. - content_type = f"{content_type}; charset=UTF-8" + content_type = f"{content_type};charset=utf-8" attachment: Dict[str, Any] = {"contentType": content_type} if params.data: @@ -530,8 +454,12 @@ async def _create_document_reference( else: raise ValueError("Either 'text' or 'data' must be provided") - # Some Cerner tenants require title/creation; default safely when omitted. - attachment["title"] = params.attachment_title or "Document" + # Cerner requires title and creation on the attachment + if not params.attachment_title: + raise ValueError( + "Cerner requires 'attachment_title' on DocumentReference create." 
+ ) + attachment["title"] = params.attachment_title attachment["creation"] = params.attachment_creation or datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") doc_ref: Dict[str, Any] = { @@ -593,7 +521,22 @@ async def _create_document_reference( # Note: 'description' is intentionally omitted by default # as Cerner can reject it depending on tenant configuration. if params.context: - doc_ref["context"] = dict(params.context) + context = dict(params.context) + # Cerner REQUIRES context.period whenever context.encounter is set. + # Auto-inject a period using the document date if the caller didn't supply one. + if context.get("encounter") and not context.get("period"): + # Force .000Z precision and provide a 1-hour clinical window + start_dt = datetime.now(tz=timezone.utc) + end_dt = start_dt + timedelta(hours=1) + context["period"] = { + "start": start_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + "end": end_dt.strftime("%Y-%m-%dT%H:%M:%S.000Z"), + } + logger.debug( + "Auto-injected context.period (required by Cerner when encounter is set)", + extra={"trace_id": trace_id}, + ) + doc_ref["context"] = context if params.additional_fields: doc_ref.update(params.additional_fields) @@ -602,9 +545,12 @@ async def _create_document_reference( for field in ["text", "data", "content_type", "attachment_title", "attachment_creation", "doc_status"]: doc_ref.pop(field, None) - # Note: Some Cerner tenants require author/authenticator. The connector does not - # enforce those fields universally; tenants that require them will return 4xx - # with OperationOutcome diagnostics. + # Cerner requires at least one author for clinical note document types. + if not params.author: + raise ValueError( + "Cerner requires 'author' for clinical note document types. " + "Provide at least one author reference, e.g. 
[{'reference': 'Practitioner/{id}'}]" + ) logger.info("FHIR DocumentReference create", extra={"trace_id": trace_id}) diff --git a/src/connectors/fhir_cerner/schema.py b/src/connectors/fhir_cerner/schema.py index eba29c1..e24d915 100644 --- a/src/connectors/fhir_cerner/schema.py +++ b/src/connectors/fhir_cerner/schema.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field, RootModel # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- class FhirCernerPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Cerner.""" + """Input for reading a FHIR Patient resource from Cerner.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. '12345678').""" @@ -23,21 +26,13 @@ class FhirCernerPatientReadInput(BaseModel): """Patient family / last name (used in name-based search).""" name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ + """Full or partial name string — mapped to FHIR 'name' search parameter.""" birthdate: Optional[str] = None """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. 
- """ + """Raw FHIR search parameters (e.g. {"family": "Smith", "given": "John"}).""" class FhirCernerPatientReadOutput(BaseModel): @@ -52,65 +47,27 @@ class FhirCernerPatientReadOutput(BaseModel): # --------------------------------------------------------------------------- class FhirCernerPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Cerner. - - Two modes are supported: - - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirCernerPatientSearchOutput.errors`` rather than raising globally. - - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. - - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. + """Input for searching / fetching multiple FHIR Patient resources from Cerner.""" - .. note:: - Cerner's sandbox name search is case-sensitive. Use the exact - capitalisation stored in the system (e.g. ``family_name="Smith"`` not - ``"smith"``). The ``name`` parameter maps to the standard FHIR - ``name`` token which Cerner supports as a partial-match. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Cerner Patient IDs to fetch concurrently (e.g. 
['12345678', '87654321']).""" + """List of Cerner Patient IDs to fetch concurrently.""" given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirCernerPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources from Cerner.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. 
- """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- @@ -120,6 +77,9 @@ class FhirCernerPatientSearchOutput(BaseModel): class FhirCernerEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources in Cerner.""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None """Cerner Patient ID to find encounters for (maps to 'patient' FHIR param).""" @@ -150,6 +110,9 @@ class FhirCernerEncounterSearchOutput(BaseModel): class FhirCernerDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource in Cerner.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: Optional[list[Dict[str, Any]]] = None """Document identifier. @@ -295,8 +258,11 @@ class FhirCernerDocumentReferenceCreateOutput(BaseModel): class FhirCernerDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources in Cerner.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"12345678\"}).""" + """Search parameters (e.g. 
{"patient": "12345678"}).""" class FhirCernerDocumentReferenceSearchOutput(BaseModel): @@ -307,3 +273,35 @@ class FhirCernerDocumentReferenceSearchOutput(BaseModel): total: Optional[int] = None """Total number of results reported by the Bundle.""" + + +# --------------------------------------------------------------------------- +# Unified operation input/output (one endpoint, multiple actions) +# --------------------------------------------------------------------------- + +_FhirCernerOperationUnion = Annotated[ + Union[ + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + ], + Field(discriminator="action"), +] + +FhirCernerOperationInput = RootModel[_FhirCernerOperationUnion] + + +class FhirCernerOperationOutput(BaseModel): + """ + Combined output shape for schema documentation/manifest generation. + + Individual handlers still return their specific output models. 
+ """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/fhir_epic/logic.py b/src/connectors/fhir_epic/logic.py index e9cc615..5cbe8c3 100644 --- a/src/connectors/fhir_epic/logic.py +++ b/src/connectors/fhir_epic/logic.py @@ -6,7 +6,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import httpx import jwt @@ -20,6 +20,8 @@ FhirDocumentReferenceSearchOutput, FhirEncounterSearchInput, FhirEncounterSearchOutput, + FhirEpicOperationInput, + FhirEpicOperationOutput, FhirPatientReadInput, FhirPatientReadOutput, FhirPatientSearchInput, @@ -29,92 +31,35 @@ logger = logging.getLogger("connectors.fhir_epic") -class _FhirAction(BaseConnector[Any, Any]): - """ - Lightweight BaseConnector that delegates execution to a FhirEpicConnector - instance method. One of these is created per action so that the manifest - and REST router can discover each action's schema and route automatically. - """ - - connector_id = "fhir_epic" - - def __init__( - self, - action: str, - input_model: type, - output_model: type, - handler: Callable, - *, - secret_provider: Optional[SecretProvider] = None, - ) -> None: - super().__init__(input_model, output_model, secret_provider=secret_provider) - self.action = action # instance attribute, overrides absent class-level action - self._handler = handler - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - return await self._handler(params, trace_id=trace_id) - - -class FhirEpicConnector: +class FhirEpicConnector(BaseConnector[FhirEpicOperationInput, FhirEpicOperationOutput]): """ Single FHIR/Epic connector. - ``connector_id = "fhir_epic"``. All authentication helpers and action - implementations live here. 
The factory registers ONE instance of this - class; ``list_actions()`` and ``get_action()`` are used by the factory to - expose each action to the manifest and REST router. - - Supported actions: - • read_patient — fetch a single Patient by ID or name search - • search_patients — fetch multiple Patients by list of IDs or name search - • search_encounter - • create_document_reference - • search_document_reference - - Name-based search parameters (``given_name``, ``family_name``, ``name``, - ``birthdate``) are prioritised over the raw ``search_params`` dict and are - normalised (stripped, lowercased for ``name`` token search). + Exposes one endpoint (`execute`) and dispatches actions via the + `action` discriminator on the request payload. """ connector_id = "fhir_epic" + action = "execute" def __init__(self, *, secret_provider: SecretProvider) -> None: + super().__init__(FhirEpicOperationInput, FhirEpicOperationOutput, secret_provider=secret_provider) self._secret_provider = secret_provider - self._actions: Dict[str, _FhirAction] = { - "read_patient": _FhirAction( - "read_patient", FhirPatientReadInput, FhirPatientReadOutput, - self._read_patient, secret_provider=secret_provider, - ), - "search_patients": _FhirAction( - "search_patients", FhirPatientSearchInput, FhirPatientSearchOutput, - self._search_patients, secret_provider=secret_provider, - ), - "search_encounter": _FhirAction( - "search_encounter", FhirEncounterSearchInput, FhirEncounterSearchOutput, - self._search_encounter, secret_provider=secret_provider, - ), - "create_document_reference": _FhirAction( - "create_document_reference", FhirDocumentReferenceCreateInput, FhirDocumentReferenceCreateOutput, - self._create_document_reference, secret_provider=secret_provider, - ), - "search_document_reference": _FhirAction( - "search_document_reference", FhirDocumentReferenceSearchInput, FhirDocumentReferenceSearchOutput, - self._search_document_reference, secret_provider=secret_provider, - ), - } - - # 
------------------------------------------------------------------ - # Action discovery — consumed by ConnectorFactory - # ------------------------------------------------------------------ - - def list_actions(self) -> List[_FhirAction]: - """Return all registered action connectors (used by list_for_protocol).""" - return list(self._actions.values()) - - def get_action(self, name: str) -> Optional[_FhirAction]: - """Return the action connector for the given action name.""" - return self._actions.get(name) + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + # Back-compat: allow calling with either the RootModel union or a concrete action input model. + op = params.root if hasattr(params, "root") else params + if op.action == "read_patient": + return await self._read_patient(op, trace_id=trace_id) + if op.action == "search_patients": + return await self._search_patients(op, trace_id=trace_id) + if op.action == "search_encounter": + return await self._search_encounter(op, trace_id=trace_id) + if op.action == "create_document_reference": + return await self._create_document_reference(op, trace_id=trace_id) + if op.action == "search_document_reference": + return await self._search_document_reference(op, trace_id=trace_id) + raise ValueError(f"Unsupported action: {op.action!r}") # ------------------------------------------------------------------ # Shared authentication helpers @@ -194,22 +139,14 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields. - - Priority: given_name/family_name > name > (nothing). - The ``extra`` dict (raw search_params) is merged at lowest priority so - callers can pass additional filters without overriding name fields. 
- """ + """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) - # Normalize: strip whitespace; FHIR name search is typically case-insensitive - # on compliant servers but we preserve original case per FHIR spec. if given_name and given_name.strip(): params["given"] = given_name.strip() if family_name and family_name.strip(): params["family"] = family_name.strip() if name and name.strip() and "given" not in params and "family" not in params: - # Only fall back to the combined 'name' token when no split fields given params["name"] = name.strip() if birthdate and birthdate.strip(): params["birthdate"] = birthdate.strip() @@ -296,7 +233,6 @@ async def _search_patients( base_url = self._get_base_url() auth_header = await self._get_auth_header() - # ---- Mode 1: Multi-ID fan-out ---- if params.resource_ids: ids = [rid.strip() for rid in params.resource_ids if rid.strip()] if not ids: @@ -309,7 +245,6 @@ async def _search_patients( ) async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[str]]: - """Return (rid, resource_or_None, error_or_None).""" try: async with httpx.AsyncClient() as client: resp = await client.get( @@ -344,7 +279,6 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ ) return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) - # ---- Mode 2: Name-based search (returns Bundle) ---- name_params = self._build_name_search_params( params.given_name, params.family_name, params.name, params.birthdate, params.search_params, @@ -386,7 +320,7 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ raise data = response.json() - resources = [] + resources: List[Dict[str, Any]] = [] total = data.get("total") if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] diff --git a/src/connectors/fhir_epic/schema.py 
b/src/connectors/fhir_epic/schema.py index 99aa9b5..eaeef26 100644 --- a/src/connectors/fhir_epic/schema.py +++ b/src/connectors/fhir_epic/schema.py @@ -1,43 +1,30 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field, RootModel # --------------------------------------------------------------------------- -# Patient – Read (single patient by ID or name search) +# Patient – Read # --------------------------------------------------------------------------- class FhirPatientReadInput(BaseModel): - """Input for reading a single FHIR Patient resource from Epic.""" + """Input for reading a FHIR Patient resource.""" + + action: Literal["read_patient"] = "read_patient" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_id: Optional[str] = None """Direct Patient ID lookup (e.g. 'eXYZ123').""" - # Convenience name fields — take priority over raw search_params when set. given_name: Optional[str] = None - """Patient given / first name (used in name-based search).""" - family_name: Optional[str] = None - """Patient family / last name (used in name-based search).""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter. - - Use this when you only have a single combined name string. When both - ``name`` and ``given_name``/``family_name`` are set, the explicit given/ - family fields take precedence. - """ - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format — used alongside name search.""" search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters (e.g. {\"family\": \"Smith\", \"given\": \"John\"}). - - Lowest priority — used only when no ID or explicit name fields are set. - """ + """Search parameters (e.g. 
{"family": "Smith", "given": "John"}).""" class FhirPatientReadOutput(BaseModel): @@ -52,59 +39,25 @@ class FhirPatientReadOutput(BaseModel): # --------------------------------------------------------------------------- class FhirPatientSearchInput(BaseModel): - """Input for searching / fetching multiple FHIR Patient resources from Epic. - - Two modes are supported: + """Input for searching / fetching multiple FHIR Patient resources from Epic.""" - 1. **Multi-ID lookup** — pass ``resource_ids`` (list of Patient IDs). - Each ID is fetched concurrently; partial failures are captured in - ``FhirPatientSearchOutput.errors`` rather than raising globally. - - 2. **Name-based search** — pass ``given_name``, ``family_name``, ``name``, - and/or ``birthdate``. A single FHIR search request is issued and all - matching Bundle entries are returned. - - Only one mode should be used per request. If ``resource_ids`` is set it - takes priority over the name/search fields. - """ + action: Literal["search_patients"] = "search_patients" + """Action discriminator (one endpoint, multiple actions pattern).""" resource_ids: Optional[List[str]] = None - """List of Epic Patient IDs to fetch concurrently (e.g. 
['eABC', 'eDEF']).""" - given_name: Optional[str] = None - """Patient given / first name.""" - family_name: Optional[str] = None - """Patient family / last name.""" - name: Optional[str] = None - """Full or partial name string — mapped to FHIR 'name' search parameter.""" - birthdate: Optional[str] = None - """Date of birth in YYYY-MM-DD format.""" - search_params: Optional[Dict[str, str]] = None - """Additional raw FHIR search parameters merged with the name fields.""" class FhirPatientSearchOutput(BaseModel): """Output for searching multiple FHIR Patient resources.""" resources: List[Dict[str, Any]] - """List of successfully retrieved FHIR Patient JSON objects.""" - total: Optional[int] = None - """Total number of matches reported by the server Bundle (name-search mode).""" - - errors: List[Dict[str, Any]] = [] - """Per-ID errors encountered during multi-ID fan-out. - - Each entry has the shape:: - - {"resource_id": "", "error": ""} - - An empty list means all lookups succeeded. - """ + errors: List[Dict[str, Any]] = Field(default_factory=list) # --------------------------------------------------------------------------- @@ -114,17 +67,13 @@ class FhirPatientSearchOutput(BaseModel): class FhirEncounterSearchInput(BaseModel): """Input for searching FHIR Encounter resources.""" - patient_id: Optional[str] = None - """FHIR Patient ID to find encounters for (maps to 'patient' FHIR param).""" + action: Literal["search_encounter"] = "search_encounter" + """Action discriminator (one endpoint, multiple actions pattern).""" + patient_id: Optional[str] = None status: Optional[str] = None - """Status of the encounters to find (e.g. 'finished', 'arrived').""" - date: Optional[str] = None - """Date or date range for the encounters (e.g. '2024', 'gt2023-01-01').""" - search_params: Optional[Dict[str, str]] = None - """Raw FHIR search parameters. 
Used if explicit fields above are not provided.""" class FhirEncounterSearchOutput(BaseModel): @@ -144,6 +93,9 @@ class FhirEncounterSearchOutput(BaseModel): class FhirDocumentReferenceCreateInput(BaseModel): """Input for creating a FHIR DocumentReference resource.""" + action: Literal["create_document_reference"] = "create_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + identifier: list[Dict[str, Any]] """Document identifier.""" @@ -213,8 +165,11 @@ class FhirDocumentReferenceCreateOutput(BaseModel): class FhirDocumentReferenceSearchInput(BaseModel): """Input for searching FHIR DocumentReference resources.""" + action: Literal["search_document_reference"] = "search_document_reference" + """Action discriminator (one endpoint, multiple actions pattern).""" + search_params: Dict[str, str] - """Search parameters (e.g. {\"patient\": \"eXYZ123\"}).""" + """Search parameters (e.g. {"patient": "eXYZ123"}).""" class FhirDocumentReferenceSearchOutput(BaseModel): @@ -224,4 +179,36 @@ class FhirDocumentReferenceSearchOutput(BaseModel): """The list of raw FHIR DocumentReference JSON objects found.""" total: Optional[int] = None - """Total number of results reported by the Bundle.""" \ No newline at end of file + """Total number of results reported by the Bundle.""" + + +# --------------------------------------------------------------------------- +# Unified operation input/output (one endpoint, multiple actions) +# --------------------------------------------------------------------------- + +_FhirEpicOperationUnion = Annotated[ + Union[ + FhirPatientReadInput, + FhirPatientSearchInput, + FhirEncounterSearchInput, + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + ], + Field(discriminator="action"), +] + +FhirEpicOperationInput = RootModel[_FhirEpicOperationUnion] + + +class FhirEpicOperationOutput(BaseModel): + """ + Combined output shape for schema documentation/manifest generation. 
+ + Individual handlers still return their specific output models. + """ + + resource: Optional[Dict[str, Any]] = None + resources: Optional[list[Dict[str, Any]]] = None + total: Optional[int] = None + resource_id: Optional[str] = None + errors: Optional[list[Dict[str, Any]]] = None diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index a13f9f5..2033234 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -6,6 +6,25 @@ from runtime import BaseConnector +# FHIR connectors expose a single `execute` entrypoint with a discriminated `action` +# field; expand these for REST/MCP discovery so routes remain per-operation. +_FHIR_DISCRIMINATED_ACTIONS: Dict[str, List[str]] = { + "fhir_cerner": [ + "read_patient", + "search_patients", + "search_encounter", + "create_document_reference", + "search_document_reference", + ], + "fhir_epic": [ + "read_patient", + "search_patients", + "search_encounter", + "create_document_reference", + "search_document_reference", + ], +} + def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() @@ -23,13 +42,25 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, for connector in connectors: input_model = connector._input_model_cls # type: ignore[attr-defined] output_model = connector._output_model_cls # type: ignore[attr-defined] - manifest.append( - { - "connector_id": connector.connector_id, - "action": connector.action, - "input_schema": _schema_for(input_model), - "output_schema": _schema_for(output_model), - } - ) + cid = connector.connector_id + if cid in _FHIR_DISCRIMINATED_ACTIONS and getattr(connector, "action", None) == "execute": + for sub_action in _FHIR_DISCRIMINATED_ACTIONS[cid]: + manifest.append( + { + "connector_id": cid, + "action": sub_action, + "input_schema": _schema_for(input_model), + "output_schema": _schema_for(output_model), + } + ) + else: + manifest.append( + { + "connector_id": cid, + "action": 
connector.action, + "input_schema": _schema_for(input_model), + "output_schema": _schema_for(output_model), + } + ) return manifest diff --git a/tests/test_fhir_cerner.py b/tests/test_fhir_cerner.py index a48eb72..903a927 100644 --- a/tests/test_fhir_cerner.py +++ b/tests/test_fhir_cerner.py @@ -36,18 +36,13 @@ def _connector() -> FhirCernerConnector: # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_cerner_connector_exposes_five_actions(): +def test_fhir_cerner_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == "fhir_cerner" + assert c.action == "execute" # --------------------------------------------------------------------------- @@ -56,9 +51,9 @@ def test_fhir_cerner_connector_exposes_five_actions(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_id(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(resource_id="12345678") + params = FhirCernerPatientReadInput(action="read_patient", resource_id="12345678") patient_response = MagicMock() patient_response.status_code = 200 @@ -67,7 +62,7 @@ async def test_fhir_cerner_read_patient_by_id(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await 
action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "12345678" assert result.resource["name"][0]["family"] == "Smith" @@ -79,9 +74,12 @@ async def test_fhir_cerner_read_patient_by_id(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_search(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(search_params={"family": "Smith", "given": "John"}) + params = FhirCernerPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -93,7 +91,7 @@ async def test_fhir_cerner_read_patient_by_search(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99887766" @@ -104,9 +102,14 @@ async def test_fhir_cerner_read_patient_by_search(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(given_name=" Jane ", family_name="Doe", birthdate="1990-06-15") + params = FhirCernerPatientReadInput( + action="read_patient", + given_name=" Jane ", + family_name="Doe", + birthdate="1990-06-15", + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -118,7 +121,7 @@ async def test_fhir_cerner_read_patient_by_explicit_name_fields(): with 
patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "55551234" call_kwargs = mock_get.call_args @@ -134,9 +137,9 @@ async def test_fhir_cerner_read_patient_by_explicit_name_fields(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_by_name_field(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput(name="Johnson") + params = FhirCernerPatientReadInput(action="read_patient", name="Johnson") patient_response = MagicMock() patient_response.status_code = 200 @@ -148,7 +151,7 @@ async def test_fhir_cerner_read_patient_by_name_field(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "99990001" call_kwargs = mock_get.call_args @@ -162,14 +165,14 @@ async def test_fhir_cerner_read_patient_by_name_field(): @pytest.mark.asyncio async def test_fhir_cerner_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - params = FhirCernerPatientReadInput() - + params = FhirCernerPatientReadInput(action="read_patient") + with patch("connectors.fhir_cerner.logic.jwt.encode", 
return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -178,9 +181,12 @@ async def test_fhir_cerner_read_patient_no_params_raises(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_multi_id(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["11111111", "22222222"]) + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["11111111", "22222222"], + ) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -193,7 +199,7 @@ def _patient_resp(pid: str) -> MagicMock: with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"11111111", "22222222"} @@ -207,9 +213,12 @@ def _patient_resp(pid: str) -> MagicMock: @pytest.mark.asyncio async def test_fhir_cerner_search_patients_partial_failure(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(resource_ids=["99999999", "00000000"]) + params = FhirCernerPatientSearchInput( + action="search_patients", + resource_ids=["99999999", "00000000"], + ) good_resp = MagicMock() good_resp.status_code = 200 @@ 
-222,7 +231,7 @@ async def test_fhir_cerner_search_patients_partial_failure(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "99999999" @@ -236,9 +245,9 @@ async def test_fhir_cerner_search_patients_partial_failure(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_by_name(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput(family_name="Smith") + params = FhirCernerPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -254,7 +263,7 @@ async def test_fhir_cerner_search_patients_by_name(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -270,14 +279,14 @@ async def test_fhir_cerner_search_patients_by_name(): @pytest.mark.asyncio async def test_fhir_cerner_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput - params = FhirCernerPatientSearchInput() - + params = 
FhirCernerPatientSearchInput(action="search_patients") + with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -286,9 +295,12 @@ async def test_fhir_cerner_search_patients_no_params_raises(): @pytest.mark.asyncio async def test_fhir_cerner_search_encounter(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(search_params={"patient": "12345678", "status": "finished"}) + params = FhirCernerEncounterSearchInput( + action="search_encounter", + search_params={"patient": "12345678", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 @@ -303,7 +315,7 @@ async def test_fhir_cerner_search_encounter(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" @@ -311,9 +323,9 @@ async def test_fhir_cerner_search_encounter(): @pytest.mark.asyncio async def test_fhir_cerner_search_encounter_by_patient(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerEncounterSearchInput - params = FhirCernerEncounterSearchInput(patient_id="12345678") + params = 
FhirCernerEncounterSearchInput(action="search_encounter", patient_id="12345678") enc_response = MagicMock() enc_response.status_code = 200 @@ -325,7 +337,7 @@ async def test_fhir_cerner_search_encounter_by_patient(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "enc-1" @@ -340,14 +352,26 @@ async def test_fhir_cerner_search_encounter_by_patient(): @pytest.mark.asyncio async def test_fhir_cerner_create_document_reference(): - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", - type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, + doc_status="final", + type={ + "coding": [{ + "system": "urn:oid:4.5.6", + "code": "18100", + "display": "Employer Group Scan", + "userSelected": True, + }], + "text": "Employer Group Scan", + }, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Document", + author=[{"reference": "Practitioner/p1"}], context={ "encounter": [{"reference": "Encounter/enc-1"}], "period": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T01:00:00Z"}, @@ -363,14 +387,14 @@ async def test_fhir_cerner_create_document_reference(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: mock_post.side_effect = 
[_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" _, kwargs = mock_post.call_args_list[1] assert kwargs["json"]["resourceType"] == "DocumentReference" assert kwargs["json"]["subject"] == {"reference": "Patient/12724066"} # Verify that charset was added to contentType - assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain; charset=UTF-8" + assert kwargs["json"]["content"][0]["attachment"]["contentType"] == "text/plain;charset=utf-8" # --------------------------------------------------------------------------- @@ -379,9 +403,12 @@ async def test_fhir_cerner_create_document_reference(): @pytest.mark.asyncio async def test_fhir_cerner_search_document_reference(): - action = _connector().get_action("search_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceSearchInput - params = FhirCernerDocumentReferenceSearchInput(search_params={"patient": "12345678"}) + params = FhirCernerDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "12345678"}, + ) search_response = MagicMock() search_response.status_code = 200 @@ -394,7 +421,7 @@ async def test_fhir_cerner_search_document_reference(): with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" @@ -407,18 +434,22 @@ async def test_fhir_cerner_search_document_reference(): @pytest.mark.asyncio async def 
test_fhir_cerner_create_document_reference_validation(): """Verify that ValueError is raised when period is missing but encounter is present.""" - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_cerner.schema import FhirCernerDocumentReferenceCreateInput params = FhirCernerDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", + doc_status="final", type={"coding": [{"system": "http://loinc.org", "code": "11488-4"}]}, subject="Patient/12724066", data="dGVzdA==", + attachment_title="Doc", + author=[{"reference": "Practitioner/p1"}], context={"encounter": [{"reference": "Encounter/enc-1"}]}, ) with patch("connectors.fhir_cerner.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Cerner requires the proprietary CodeSet 72 system"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") diff --git a/tests/test_fhir_epic.py b/tests/test_fhir_epic.py index b076da5..38eee40 100644 --- a/tests/test_fhir_epic.py +++ b/tests/test_fhir_epic.py @@ -36,18 +36,13 @@ def _connector() -> FhirEpicConnector: # --------------------------------------------------------------------------- -# Sanity: connector exposes all 5 actions +# Sanity: unified connector (single execute entrypoint) # --------------------------------------------------------------------------- -def test_fhir_epic_connector_exposes_five_actions(): +def test_fhir_epic_connector_is_unified_execute(): c = _connector() - actions = {a.action for a in c.list_actions()} - assert actions == { - "read_patient", "search_patients", - "search_encounter", "create_document_reference", "search_document_reference", - } - for name in actions: - assert c.get_action(name) is not None + assert c.connector_id == 
"fhir_epic" + assert c.action == "execute" # --------------------------------------------------------------------------- @@ -56,9 +51,9 @@ def test_fhir_epic_connector_exposes_five_actions(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_id(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(resource_id="eXYZ123") + params = FhirPatientReadInput(action="read_patient", resource_id="eXYZ123") patient_response = MagicMock() patient_response.status_code = 200 @@ -67,7 +62,7 @@ async def test_fhir_epic_read_patient_by_id(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eXYZ123" assert result.resource["name"][0]["family"] == "Smith" @@ -79,9 +74,12 @@ async def test_fhir_epic_read_patient_by_id(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_search(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(search_params={"family": "Smith", "given": "John"}) + params = FhirPatientReadInput( + action="read_patient", + search_params={"family": "Smith", "given": "John"}, + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -93,7 +91,7 @@ async def test_fhir_epic_read_patient_by_search(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, 
return_value=patient_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eABC" @@ -104,9 +102,14 @@ async def test_fhir_epic_read_patient_by_search(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_explicit_name_fields(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(given_name=" John ", family_name="Smith", birthdate="1980-01-01") + params = FhirPatientReadInput( + action="read_patient", + given_name=" John ", + family_name="Smith", + birthdate="1980-01-01", + ) patient_response = MagicMock() patient_response.status_code = 200 @@ -118,7 +121,7 @@ async def test_fhir_epic_read_patient_by_explicit_name_fields(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eDEF" # Verify the correct FHIR params were built (stripped whitespace) @@ -135,9 +138,9 @@ async def test_fhir_epic_read_patient_by_explicit_name_fields(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_by_name_field(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput(name="Johnson") + params = FhirPatientReadInput(action="read_patient", name="Johnson") patient_response = MagicMock() patient_response.status_code = 200 @@ -149,7 +152,7 @@ async def test_fhir_epic_read_patient_by_name_field(): with patch("connectors.fhir_epic.logic.jwt.encode", 
return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=patient_response) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource["id"] == "eGHI" call_kwargs = mock_get.call_args @@ -163,14 +166,14 @@ async def test_fhir_epic_read_patient_by_name_field(): @pytest.mark.asyncio async def test_fhir_epic_read_patient_no_params_raises(): - action = _connector().get_action("read_patient") + c = _connector() from connectors.fhir_epic.schema import FhirPatientReadInput - params = FhirPatientReadInput() # nothing provided - + params = FhirPatientReadInput(action="read_patient") + with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError, match="Provide resource_id"): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -179,9 +182,9 @@ async def test_fhir_epic_read_patient_no_params_raises(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_multi_id(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eABC", "eDEF"]) + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eABC", "eDEF"]) def _patient_resp(pid: str) -> MagicMock: m = MagicMock() @@ -194,7 +197,7 @@ def _patient_resp(pid: str) -> MagicMock: with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ 
patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=responses): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") ids = {r["id"] for r in result.resources} assert ids == {"eABC", "eDEF"} @@ -208,9 +211,9 @@ def _patient_resp(pid: str) -> MagicMock: @pytest.mark.asyncio async def test_fhir_epic_search_patients_partial_failure(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(resource_ids=["eGOOD", "eBAD"]) + params = FhirPatientSearchInput(action="search_patients", resource_ids=["eGOOD", "eBAD"]) good_resp = MagicMock() good_resp.status_code = 200 @@ -223,7 +226,7 @@ async def test_fhir_epic_search_patients_partial_failure(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, side_effect=[good_resp, bad_resp]): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert len(result.resources) == 1 assert result.resources[0]["id"] == "eGOOD" @@ -237,9 +240,9 @@ async def test_fhir_epic_search_patients_partial_failure(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_by_name(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput(family_name="Smith") + params = FhirPatientSearchInput(action="search_patients", family_name="Smith") bundle_resp = MagicMock() bundle_resp.status_code = 200 @@ -255,7 +258,7 @@ async def test_fhir_epic_search_patients_by_name(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ 
patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=bundle_resp) as mock_get: - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert len(result.resources) == 2 @@ -272,14 +275,14 @@ async def test_fhir_epic_search_patients_by_name(): @pytest.mark.asyncio async def test_fhir_epic_search_patients_no_params_raises(): - action = _connector().get_action("search_patients") + c = _connector() from connectors.fhir_epic.schema import FhirPatientSearchInput - params = FhirPatientSearchInput() - + params = FhirPatientSearchInput(action="search_patients") + with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()): with pytest.raises(ValueError): - await action.internal_execute(params, trace_id="test-trace") + await c.internal_execute(params, trace_id="test-trace") # --------------------------------------------------------------------------- @@ -288,9 +291,12 @@ async def test_fhir_epic_search_patients_no_params_raises(): @pytest.mark.asyncio async def test_fhir_epic_search_encounter(): - action = _connector().get_action("search_encounter") + c = _connector() from connectors.fhir_epic.schema import FhirEncounterSearchInput - params = FhirEncounterSearchInput(search_params={"patient": "eXYZ123", "status": "finished"}) + params = FhirEncounterSearchInput( + action="search_encounter", + search_params={"patient": "eXYZ123", "status": "finished"}, + ) enc_response = MagicMock() enc_response.status_code = 200 @@ -305,7 +311,7 @@ async def test_fhir_epic_search_encounter(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ 
patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=enc_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 2 assert result.resources[0]["id"] == "enc-1" @@ -317,9 +323,10 @@ async def test_fhir_epic_search_encounter(): @pytest.mark.asyncio async def test_fhir_epic_create_document_reference(): - action = _connector().get_action("create_document_reference") + c = _connector() from connectors.fhir_epic.schema import FhirDocumentReferenceCreateInput params = FhirDocumentReferenceCreateInput( + action="create_document_reference", identifier=[{"system": "urn:oid:1.2.3", "value": "ID.123"}], status="current", type={"coding": [{"system": "urn:oid:4.5.6", "code": "18100", "display": "Employer Group Scan"}]}, @@ -337,7 +344,7 @@ async def test_fhir_epic_create_document_reference(): with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post: mock_post.side_effect = [_token_mock(), create_response] - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.resource_id == "doc-456" _, kwargs = mock_post.call_args_list[1] @@ -351,9 +358,12 @@ async def test_fhir_epic_create_document_reference(): @pytest.mark.asyncio async def test_fhir_epic_search_document_reference(): - action = _connector().get_action("search_document_reference") + c = _connector() from connectors.fhir_epic.schema import FhirDocumentReferenceSearchInput - params = FhirDocumentReferenceSearchInput(search_params={"patient": "eXYZ123"}) + params = FhirDocumentReferenceSearchInput( + action="search_document_reference", + search_params={"patient": "eXYZ123"}, + ) search_response = MagicMock() search_response.status_code = 200 @@ -366,7 +376,7 @@ async def test_fhir_epic_search_document_reference(): 
with patch("connectors.fhir_epic.logic.jwt.encode", return_value="dummy-jwt"), \ patch("httpx.AsyncClient.post", new_callable=AsyncMock, return_value=_token_mock()), \ patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=search_response): - result = await action.internal_execute(params, trace_id="test-trace") + result = await c.internal_execute(params, trace_id="test-trace") assert result.total == 1 assert result.resources[0]["id"] == "doc-789" diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index 5f92366..50aa61a 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -248,8 +248,8 @@ async def test_agent_fails_when_mcp_unreachable() -> None: # MCP entrypoint smoke test # --------------------------------------------------------------------------- -def test_mcp_entrypoint_registers_four_tools() -> None: - """The FastMCP server should expose exactly 4 tools.""" +def test_mcp_entrypoint_registers_eight_tools() -> None: + """The FastMCP server should expose the full FHIR + integration tool surface.""" # We patch all external deps before importing the module to avoid side effects with ( patch("bindings.factory.ConnectorFactory") as mock_factory_cls, @@ -280,9 +280,13 @@ def fake_tool(*args: Any, **kwargs: Any): from agents.mcp_entrypoint import _make_server _make_server() - assert len(registered_tools) == 4 + assert len(registered_tools) == 8 assert "fhir_cerner_read_patient" in registered_tools + assert "fhir_cerner_search_patients" in registered_tools + assert "fhir_cerner_search_encounters" in registered_tools assert "fhir_epic_read_patient" in registered_tools + assert "fhir_epic_search_patients" in registered_tools + assert "fhir_epic_search_encounters" in registered_tools assert "google_drive_upload_file" in registered_tools assert "smtp_send_email" in registered_tools From 018ba6296ec4c5e4330aa489581949a18a5f169c Mon Sep 17 00:00:00 2001 From: vinaayakh-aot 
<61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:19:59 -0700 Subject: [PATCH 02/15] cleanup and update mcp --- src/agents/fhir_cerner_mcp.py | 22 ++++++------ src/agents/fhir_epic_mcp.py | 22 ++++++------ src/agents/mcp_entrypoint.py | 8 ++++- src/connectors/manifest.py | 66 +++++++++++++++++++++++------------ tests/test_fhir_cerner.py | 2 +- 5 files changed, 71 insertions(+), 49 deletions(-) diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index d329170..fd2067c 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -48,6 +48,7 @@ async def fhir_cerner_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) @@ -57,19 +58,16 @@ async def fhir_cerner_read_patient( if patient_id: params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirCernerPatientReadInput(action="read_patient", search_params=search) + elif family_name or given_name or name: + params = FhirCernerPatientReadInput( + action="read_patient", + given_name=given_name or None, + family_name=family_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least family_name / given_name / name") result = await cerner.internal_execute(params, trace_id=trace_id) resource = result.resource diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index b196b7a..5e6798e 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -49,6 +49,7 @@ async def fhir_epic_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", 
birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) @@ -58,19 +59,16 @@ async def fhir_epic_read_patient( if patient_id: params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirEpicPatientReadInput(action="read_patient", search_params=search) + elif family_name or given_name or name: + params = FhirEpicPatientReadInput( + action="read_patient", + given_name=given_name or None, + family_name=family_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least family_name / given_name / name") result = await epic.internal_execute(params, trace_id=trace_id) resource = result.resource diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index 9d974eb..ba9ac46 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -406,6 +406,9 @@ async def fhir_cerner_search_encounters( if not cerner: raise RuntimeError("fhir_cerner connector not configured") + if not (patient_id or status or date): + raise ValueError("Provide at least one of patient_id / status / date") + params = FhirCernerEncounterSearchInput( action="search_encounter", patient_id=patient_id or None, @@ -463,6 +466,9 @@ async def fhir_epic_search_encounters( if not epic: raise RuntimeError("fhir_epic connector not configured") + if not (patient_id or status or date): + raise ValueError("Provide at least one of patient_id / status / date") + params = FhirEpicEncounterSearchInput( action="search_encounter", patient_id=patient_id or None, @@ -544,7 +550,7 @@ async def google_drive_upload_file( } # ------------------------------------------------------------------ - # Tool 4: Send email via SMTP + # Tool 8: Send email via 
SMTP # ------------------------------------------------------------------ @mcp.tool( diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index 2033234..ffb5cfa 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -6,30 +6,46 @@ from runtime import BaseConnector -# FHIR connectors expose a single `execute` entrypoint with a discriminated `action` -# field; expand these for REST/MCP discovery so routes remain per-operation. -_FHIR_DISCRIMINATED_ACTIONS: Dict[str, List[str]] = { - "fhir_cerner": [ - "read_patient", - "search_patients", - "search_encounter", - "create_document_reference", - "search_document_reference", - ], - "fhir_epic": [ - "read_patient", - "search_patients", - "search_encounter", - "create_document_reference", - "search_document_reference", - ], -} - def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() +def _fhir_action_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]: + """Return per-action input model classes for FHIR connectors (lazy import).""" + from connectors.fhir_cerner.schema import ( + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + ) + from connectors.fhir_epic.schema import ( + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + FhirEncounterSearchInput, + FhirPatientReadInput, + FhirPatientSearchInput, + ) + + return { + "fhir_cerner": { + "read_patient": FhirCernerPatientReadInput, + "search_patients": FhirCernerPatientSearchInput, + "search_encounter": FhirCernerEncounterSearchInput, + "create_document_reference": FhirCernerDocumentReferenceCreateInput, + "search_document_reference": FhirCernerDocumentReferenceSearchInput, + }, + "fhir_epic": { + "read_patient": FhirPatientReadInput, + "search_patients": FhirPatientSearchInput, + "search_encounter": FhirEncounterSearchInput, + "create_document_reference": 
FhirDocumentReferenceCreateInput, + "search_document_reference": FhirDocumentReferenceSearchInput, + }, + } + + def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, Any]]: """ Build a simple manifest for discovery. @@ -39,21 +55,25 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, REST route generation and MCP tool manifests. """ manifest: List[Dict[str, Any]] = [] + fhir_schemas: Dict[str, Dict[str, Type[BaseModel]]] | None = None + for connector in connectors: - input_model = connector._input_model_cls # type: ignore[attr-defined] output_model = connector._output_model_cls # type: ignore[attr-defined] cid = connector.connector_id - if cid in _FHIR_DISCRIMINATED_ACTIONS and getattr(connector, "action", None) == "execute": - for sub_action in _FHIR_DISCRIMINATED_ACTIONS[cid]: + if getattr(connector, "action", None) == "execute" and cid in ("fhir_cerner", "fhir_epic"): + if fhir_schemas is None: + fhir_schemas = _fhir_action_schemas() + for sub_action, input_cls in fhir_schemas[cid].items(): manifest.append( { "connector_id": cid, "action": sub_action, - "input_schema": _schema_for(input_model), + "input_schema": _schema_for(input_cls), "output_schema": _schema_for(output_model), } ) else: + input_model = connector._input_model_cls # type: ignore[attr-defined] manifest.append( { "connector_id": cid, diff --git a/tests/test_fhir_cerner.py b/tests/test_fhir_cerner.py index 903a927..9aa7fe7 100644 --- a/tests/test_fhir_cerner.py +++ b/tests/test_fhir_cerner.py @@ -16,7 +16,7 @@ class MockSecretProvider(SecretProvider): def get_secret(self, key: str) -> str: return { "cerner_fhir_base_url": "https://fhir-myrecord.cerner.com/r4/tenant-id", - "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\\\nMEowIQ...dummy\\\\n-----END RSA PRIVATE KEY-----", + "cerner_private_key": "-----BEGIN RSA PRIVATE KEY-----\\nMEowIQ...dummy\\n-----END RSA PRIVATE KEY-----", "cerner_kid": "dummy-kid", "cerner_client_id": 
"dummy-client-id", "cerner_token_url": "https://authorization.cerner.com/tenants/tenant-id/protocols/oauth2/profiles/smart-v1/token", From cc803d261bc104f3d67d55597ca2087f994d0245 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:08:43 -0700 Subject: [PATCH 03/15] SDK connector added --- src/bindings/factory.py | 51 +-- src/bindings/mcp_server/server.py | 3 +- src/bindings/rest_api/app.py | 11 +- src/connectors/__init__.py | 27 +- src/connectors/fhir_cerner/__init__.py | 1 + src/connectors/fhir_cerner/logic.py | 79 +++-- src/connectors/fhir_cerner/schema.py | 26 +- src/connectors/fhir_epic/__init__.py | 1 + src/connectors/fhir_epic/logic.py | 292 ++++++++++------- src/connectors/fhir_epic/schema.py | 26 +- src/connectors/google_drive/logic.py | 413 ++++++++++++++----------- src/connectors/google_drive/schema.py | 18 +- src/connectors/http_generic/logic.py | 2 - src/connectors/manifest.py | 61 +--- src/connectors/stripe/logic.py | 42 ++- src/connectors/stripe/schema.py | 4 +- src/runtime/__init__.py | 4 + src/runtime/base.py | 3 +- src/runtime/sdk_connector.py | 217 +++++++++++++ tests/test_connectors_basic.py | 6 +- tests/test_google_drive.py | 9 +- tests/test_sdk_connector_manifest.py | 58 ++++ 22 files changed, 836 insertions(+), 518 deletions(-) create mode 100644 src/connectors/fhir_cerner/__init__.py create mode 100644 src/connectors/fhir_epic/__init__.py create mode 100644 src/runtime/sdk_connector.py create mode 100644 tests/test_sdk_connector_manifest.py diff --git a/src/bindings/factory.py b/src/bindings/factory.py index 76e4df8..2f87234 100644 --- a/src/bindings/factory.py +++ b/src/bindings/factory.py @@ -7,24 +7,15 @@ import yaml -from connectors.fhir_epic.logic import FhirEpicConnector -from connectors.fhir_cerner.logic import FhirCernerConnector from connectors.http_generic.logic import HttpGenericConnector from connectors.http_generic.schema import HttpRequestInput, 
HttpResponseOutput -from connectors.google_drive.logic import GoogleDriveConnector -from connectors.google_drive.schema import ( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, -) from connectors.smtp.logic import SmtpConnector from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput from runtime import BaseConnector, SecretProvider +from runtime.sdk_connector import _CONNECTOR_REGISTRY logger = logging.getLogger("bindings.factory") -# Resolve default config relative to platform root so it works from any cwd. _PLATFORM_ROOT = Path(__file__).resolve().parent.parent.parent _DEFAULT_CONFIG_PATH = _PLATFORM_ROOT / "config" / "connectors.yaml" @@ -38,11 +29,7 @@ class ConnectorConfig: class EnvSecretProvider(SecretProvider): - """ - Simple SecretProvider implementation backed by environment variables. - - Keys are looked up directly from os.environ for the POC. - """ + """SecretProvider backed by environment variables.""" def __init__(self) -> None: import os @@ -56,16 +43,13 @@ def get_secret(self, key: str) -> str: val = self._env.get(key.upper()) if val is not None: return val.strip(" '\"") - # Return empty string instead of raising RuntimeError for zero-config/local testing. return "" class ConnectorFactory: """ - Factory responsible for: - - Loading connector configuration from config/connectors.yaml - - Instantiating connector adapters - - Enforcing exposed_via rules per protocol + Loads config/connectors.yaml, instantiates connectors from the SDK registry + or legacy explicit constructors. """ def __init__(self, config_path: str | Path | None = None) -> None: @@ -74,7 +58,6 @@ def __init__(self, config_path: str | Path | None = None) -> None: elif _DEFAULT_CONFIG_PATH.is_file(): self._config_path = str(_DEFAULT_CONFIG_PATH) else: - # Fallback when run from platform dir (e.g. 
package installed from wheel) cwd_config = Path.cwd() / "config" / "connectors.yaml" self._config_path = str(cwd_config) self._secret_provider: SecretProvider = EnvSecretProvider() @@ -114,22 +97,22 @@ def load(self) -> None: self._connectors[connector_id] = self._instantiate(connector_id) def _instantiate(self, connector_id: str) -> Any: + sdk_cls = _CONNECTOR_REGISTRY.get(connector_id) + if sdk_cls is not None: + return sdk_cls(secret_provider=self._secret_provider) + if connector_id == "http_generic": - return HttpGenericConnector(HttpRequestInput, HttpResponseOutput, secret_provider=self._secret_provider) + return HttpGenericConnector( + HttpRequestInput, + HttpResponseOutput, + secret_provider=self._secret_provider, + ) if connector_id == "smtp": - return SmtpConnector(SmtpSendInput, SmtpSendOutput, secret_provider=self._secret_provider) - if connector_id == "stripe": - return StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=self._secret_provider) - if connector_id == "google_drive": - return GoogleDriveConnector( - GoogleDriveOperationInput, - GoogleDriveOperationOutput, + return SmtpConnector( + SmtpSendInput, + SmtpSendOutput, secret_provider=self._secret_provider, ) - if connector_id == "fhir_epic": - return FhirEpicConnector(secret_provider=self._secret_provider) - if connector_id == "fhir_cerner": - return FhirCernerConnector(secret_provider=self._secret_provider) raise ValueError(f"Unknown connector id {connector_id!r}") @@ -164,7 +147,7 @@ def get_for_protocol( if action: logger.debug( - "get_for_protocol resolved connector (action from URL is merged into payload by REST)", + "get_for_protocol resolved connector", extra={"connector_id": connector_id, "protocol": protocol, "action": action}, ) diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 6b9f0a2..5bbed57 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -7,6 +7,7 @@ from bindings.factory import 
ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest +from runtime import SDKConnector logger = logging.getLogger("bindings.mcp_server") @@ -51,7 +52,7 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A raise ValueError(f"Connector {connector_id!r} is not available via MCP.") run_args = dict(arguments) - if connector_id in ("fhir_cerner", "fhir_epic"): + if isinstance(connector, SDKConnector): run_args.setdefault("action", action) response = await connector.run(run_args) diff --git a/src/bindings/rest_api/app.py b/src/bindings/rest_api/app.py index f442734..4fd283a 100644 --- a/src/bindings/rest_api/app.py +++ b/src/bindings/rest_api/app.py @@ -6,7 +6,6 @@ from fastapi import Depends, FastAPI, HTTPException from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, create_model from dotenv import load_dotenv load_dotenv() # Load environmental variables from .env @@ -14,7 +13,7 @@ from bindings.factory import ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest -from runtime import ConnectorResponse, ErrorCategory +from runtime import ConnectorResponse, ErrorCategory, SDKConnector from opentelemetry import trace from opentelemetry.trace import Status, StatusCode from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -36,9 +35,6 @@ app = FastAPI(title="Node Wire - REST API") FastAPIInstrumentor.instrument_app(app) -import os -from pathlib import Path - # Include the professional scenarios orchestrator app.include_router(scenarios_router) @@ -73,9 +69,6 @@ def _http_status_for_category(category: ErrorCategory | None) -> int: return 503 return 500 -_FHIR_REST_IDS = frozenset({"fhir_cerner", "fhir_epic"}) - - def _make_endpoint(cid: str, act: str) -> Any: async def endpoint( payload: Dict[str, Any], @@ -93,7 +86,7 @@ async def endpoint( if connector is None: raise 
HTTPException(status_code=404, detail="Connector not available for REST") run_payload = dict(payload) - if cid in _FHIR_REST_IDS: + if isinstance(connector, SDKConnector): run_payload.setdefault("action", act) # Let the runtime (Layer A) perform full schema validation. # Any validation errors will be mapped into ConnectorResponse. diff --git a/src/connectors/__init__.py b/src/connectors/__init__.py index f9c7b0f..50d93b4 100644 --- a/src/connectors/__init__.py +++ b/src/connectors/__init__.py @@ -3,48 +3,51 @@ """ Node Wire - Layer B: System Adapters. -Each connector lives in its own subpackage and follows the three-file pattern: +Each connector lives in its own subpackage: connector_name/ schema.py logic.py - registration.py + registration.py (optional — legacy connectors) -Registration modules are auto-discovered so they can register system-specific -exceptions with the global ErrorMapper in Layer A. +SDKConnector-based connectors self-register when their `logic` module is +imported. Legacy connectors may still use `registration.py` for ErrorMapper. """ from importlib import import_module from pkgutil import iter_modules -from typing import Iterable, List +from typing import List def auto_register() -> List[str]: """ - Import all `registration` modules in connector subpackages. + Import connector subpackages so SDK connectors register and legacy mappings apply. - Returns the list of successfully imported module names. This should be - called once at process startup (e.g. by Layer C bindings) to ensure all - connector-specific error mappings are registered. + Imports `logic` first (triggers SDKConnector.__init_subclass__), then + `registration` when present. 
""" imported: List[str] = [] package_name = __name__ for module_info in iter_modules(__path__, prefix=f"{package_name}."): - # We only care about subpackages; each is expected to expose registration.py if not module_info.ispkg: continue + logic_module = f"{module_info.name}.logic" + try: + import_module(logic_module) + imported.append(logic_module) + except ModuleNotFoundError: + pass + registration_module = f"{module_info.name}.registration" try: import_module(registration_module) imported.append(registration_module) except ModuleNotFoundError: - # Connector without a registration module; skip silently. continue return imported __all__ = ["auto_register"] - diff --git a/src/connectors/fhir_cerner/__init__.py b/src/connectors/fhir_cerner/__init__.py new file mode 100644 index 0000000..9fa8ea5 --- /dev/null +++ b/src/connectors/fhir_cerner/__init__.py @@ -0,0 +1 @@ +"""FHIR Cerner connector package.""" diff --git a/src/connectors/fhir_cerner/logic.py b/src/connectors/fhir_cerner/logic.py index c05281c..94b453d 100644 --- a/src/connectors/fhir_cerner/logic.py +++ b/src/connectors/fhir_cerner/logic.py @@ -11,7 +11,7 @@ import httpx import jwt -from runtime import BaseConnector, SecretProvider +from runtime import SDKConnector, sdk_action from . import registration from .schema import ( @@ -21,7 +21,6 @@ FhirCernerDocumentReferenceSearchOutput, FhirCernerEncounterSearchInput, FhirCernerEncounterSearchOutput, - FhirCernerOperationInput, FhirCernerOperationOutput, FhirCernerPatientReadInput, FhirCernerPatientReadOutput, @@ -32,44 +31,56 @@ logger = logging.getLogger("connectors.fhir_cerner") -class FhirCernerConnector(BaseConnector[FhirCernerOperationInput, FhirCernerOperationOutput]): +class FhirCernerConnector(SDKConnector): """ - Single FHIR/Cerner connector. - - Authentication uses Cerner's SMART Backend Services (private_key_jwt) flow, - identical to Epic's implementation — RS384-signed JWT exchanged for an - OAuth2 access token at the configured token endpoint. 
- - Required secrets (configured via SecretProvider): - - cerner_fhir_base_url : Cerner FHIR R4 base URL - - cerner_private_key : RSA private key PEM (newlines may be escaped) - - cerner_kid : Key ID registered in the Cerner code console - - cerner_client_id : Client ID from Cerner app registration - - cerner_token_url : OAuth2 token endpoint URL (from .well-known/smart-configuration - or the Cerner code console) + FHIR/Cerner connector: SMART Backend Services (private_key_jwt), RS384. + + Required secrets: cerner_fhir_base_url, cerner_private_key, cerner_kid, + cerner_client_id, cerner_token_url (optional cerner_scopes). """ connector_id = "fhir_cerner" action = "execute" + output_model = FhirCernerOperationOutput + + @sdk_action("read_patient") + async def read_patient( + self, params: FhirCernerPatientReadInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource=out.resource) + + @sdk_action("search_patients") + async def search_patients( + self, params: FhirCernerPatientSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirCernerOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) - def __init__(self, *, secret_provider: SecretProvider) -> None: - super().__init__(FhirCernerOperationInput, FhirCernerOperationOutput, secret_provider=secret_provider) - self._secret_provider = secret_provider - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - # Back-compat: allow calling with either the RootModel union or a concrete action input model. 
- op = params.root if hasattr(params, "root") else params - if op.action == "read_patient": - return await self._read_patient(op, trace_id=trace_id) - if op.action == "search_patients": - return await self._search_patients(op, trace_id=trace_id) - if op.action == "search_encounter": - return await self._search_encounter(op, trace_id=trace_id) - if op.action == "create_document_reference": - return await self._create_document_reference(op, trace_id=trace_id) - if op.action == "search_document_reference": - return await self._search_document_reference(op, trace_id=trace_id) - raise ValueError(f"Unsupported action: {op.action!r}") + @sdk_action("search_encounter") + async def search_encounter( + self, params: FhirCernerEncounterSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) + + @sdk_action("create_document_reference") + async def create_document_reference( + self, params: FhirCernerDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resource_id=out.resource_id, resource=out.resource) + + @sdk_action("search_document_reference") + async def search_document_reference( + self, params: FhirCernerDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirCernerOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirCernerOperationOutput(resources=out.resources, total=out.total) # ------------------------------------------------------------------ # Shared authentication helpers diff --git a/src/connectors/fhir_cerner/schema.py b/src/connectors/fhir_cerner/schema.py index e24d915..123c81d 100644 --- a/src/connectors/fhir_cerner/schema.py +++ b/src/connectors/fhir_cerner/schema.py @@ -1,8 +1,8 @@ from __future__ import annotations -from 
typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, RootModel +from pydantic import BaseModel, Field # --------------------------------------------------------------------------- @@ -275,29 +275,11 @@ class FhirCernerDocumentReferenceSearchOutput(BaseModel): """Total number of results reported by the Bundle.""" -# --------------------------------------------------------------------------- -# Unified operation input/output (one endpoint, multiple actions) -# --------------------------------------------------------------------------- - -_FhirCernerOperationUnion = Annotated[ - Union[ - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - FhirCernerEncounterSearchInput, - FhirCernerDocumentReferenceCreateInput, - FhirCernerDocumentReferenceSearchInput, - ], - Field(discriminator="action"), -] - -FhirCernerOperationInput = RootModel[_FhirCernerOperationUnion] - - class FhirCernerOperationOutput(BaseModel): """ - Combined output shape for schema documentation/manifest generation. + Unified output for all Cerner FHIR actions (SDKConnector single output_model). - Individual handlers still return their specific output models. + Fields are populated depending on the action; unused fields are None. 
""" resource: Optional[Dict[str, Any]] = None diff --git a/src/connectors/fhir_epic/__init__.py b/src/connectors/fhir_epic/__init__.py new file mode 100644 index 0000000..aa47436 --- /dev/null +++ b/src/connectors/fhir_epic/__init__.py @@ -0,0 +1 @@ +"""FHIR Epic connector package.""" diff --git a/src/connectors/fhir_epic/logic.py b/src/connectors/fhir_epic/logic.py index 5cbe8c3..9e72e58 100644 --- a/src/connectors/fhir_epic/logic.py +++ b/src/connectors/fhir_epic/logic.py @@ -11,7 +11,7 @@ import httpx import jwt -from runtime import BaseConnector, SecretProvider +from runtime import SDKConnector, sdk_action from .schema import ( FhirDocumentReferenceCreateInput, @@ -20,7 +20,6 @@ FhirDocumentReferenceSearchOutput, FhirEncounterSearchInput, FhirEncounterSearchOutput, - FhirEpicOperationInput, FhirEpicOperationOutput, FhirPatientReadInput, FhirPatientReadOutput, @@ -31,69 +30,82 @@ logger = logging.getLogger("connectors.fhir_epic") -class FhirEpicConnector(BaseConnector[FhirEpicOperationInput, FhirEpicOperationOutput]): - """ - Single FHIR/Epic connector. - - Exposes one endpoint (`execute`) and dispatches actions via the - `action` discriminator on the request payload. 
- """ +class FhirEpicConnector(SDKConnector): + """FHIR/Epic connector: one @sdk_action per operation.""" connector_id = "fhir_epic" action = "execute" + output_model = FhirEpicOperationOutput + + @sdk_action("read_patient") + async def read_patient( + self, params: FhirPatientReadInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._read_patient(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource=out.resource) + + @sdk_action("search_patients") + async def search_patients( + self, params: FhirPatientSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_patients(params, trace_id=trace_id) + return FhirEpicOperationOutput( + resources=out.resources, + total=out.total, + errors=out.errors, + ) - def __init__(self, *, secret_provider: SecretProvider) -> None: - super().__init__(FhirEpicOperationInput, FhirEpicOperationOutput, secret_provider=secret_provider) - self._secret_provider = secret_provider - - async def internal_execute(self, params: Any, *, trace_id: str) -> Any: - # Back-compat: allow calling with either the RootModel union or a concrete action input model. 
- op = params.root if hasattr(params, "root") else params - if op.action == "read_patient": - return await self._read_patient(op, trace_id=trace_id) - if op.action == "search_patients": - return await self._search_patients(op, trace_id=trace_id) - if op.action == "search_encounter": - return await self._search_encounter(op, trace_id=trace_id) - if op.action == "create_document_reference": - return await self._create_document_reference(op, trace_id=trace_id) - if op.action == "search_document_reference": - return await self._search_document_reference(op, trace_id=trace_id) - raise ValueError(f"Unsupported action: {op.action!r}") + @sdk_action("search_encounter") + async def search_encounter( + self, params: FhirEncounterSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_encounter(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) + + @sdk_action("create_document_reference") + async def create_document_reference( + self, params: FhirDocumentReferenceCreateInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._create_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resource_id=out.resource_id, resource=out.resource) + + @sdk_action("search_document_reference") + async def search_document_reference( + self, params: FhirDocumentReferenceSearchInput, *, trace_id: str + ) -> FhirEpicOperationOutput: + out = await self._search_document_reference(params, trace_id=trace_id) + return FhirEpicOperationOutput(resources=out.resources, total=out.total) # ------------------------------------------------------------------ # Shared authentication helpers # ------------------------------------------------------------------ def _get_base_url(self) -> str: - return self._secret_provider.get_secret("epic_fhir_base_url").rstrip("/") + return self.secret_provider.get_secret("epic_fhir_base_url").rstrip("/") async def _get_auth_header(self) -> 
Dict[str, str]: - """ - Obtain an access token via Epic's SMART Backend Services (private_key_jwt) - and return ready-to-use request headers. - - Algorithm: RS384. Token lifetime: 5 minutes (Epic maximum). - Reference: https://fhir.epic.com/Documentation?docId=oauth2tutorial&section=cloud-based-app - """ headers = { "Content-Type": "application/fhir+json", "Accept": "application/fhir+json", } - private_key_str = self._secret_provider.get_secret("epic_private_key") - kid = self._secret_provider.get_secret("epic_kid") - client_id = self._secret_provider.get_secret("epic_client_id") - token_url = self._secret_provider.get_secret("epic_token_url") + private_key_str = self.secret_provider.get_secret("epic_private_key") + kid = self.secret_provider.get_secret("epic_kid") + client_id = self.secret_provider.get_secret("epic_client_id") + token_url = self.secret_provider.get_secret("epic_token_url") - # Environment variables sometimes store newlines as escape sequences. private_key_pem = codecs.decode(private_key_str, "unicode_escape") now = int(datetime.now(tz=timezone.utc).timestamp()) jwt_token = jwt.encode( { - "iss": client_id, "sub": client_id, "aud": token_url, - "jti": str(uuid.uuid4()), "iat": now, "nbf": now, "exp": now + 300, + "iss": client_id, + "sub": client_id, + "aud": token_url, + "jti": str(uuid.uuid4()), + "iat": now, + "nbf": now, + "exp": now + 300, }, private_key_pem, algorithm="RS384", @@ -115,7 +127,8 @@ async def _get_auth_header(self) -> Dict[str, str]: if token_response.status_code != 200: logger.error( "OAuth token exchange failed | status=%s | body=%s", - token_response.status_code, token_response.text, + token_response.status_code, + token_response.text, ) token_response.raise_for_status() token_data = token_response.json() @@ -127,10 +140,6 @@ async def _get_auth_header(self) -> Dict[str, str]: headers["Authorization"] = f"Bearer {access_token}" return headers - # ------------------------------------------------------------------ - # Internal
name-field helpers - # ------------------------------------------------------------------ - @staticmethod def _build_name_search_params( given_name: Optional[str], @@ -139,7 +148,6 @@ def _build_name_search_params( birthdate: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict from explicit name/date fields.""" params: Dict[str, str] = dict(extra or {}) if given_name and given_name.strip(): @@ -160,7 +168,6 @@ def _build_encounter_search_params( date: Optional[str], extra: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - """Build a FHIR search params dict for Encounter from explicit fields.""" params: Dict[str, str] = dict(extra or {}) if patient_id and patient_id.strip(): @@ -172,10 +179,6 @@ def _build_encounter_search_params( return params - # ------------------------------------------------------------------ - # Action: read_patient - # ------------------------------------------------------------------ - async def _read_patient( self, params: FhirPatientReadInput, *, trace_id: str ) -> FhirPatientReadOutput: @@ -185,18 +188,30 @@ async def _read_patient( if params.resource_id: url = f"{base_url}/Patient/{params.resource_id}" query_params: Optional[Dict[str, str]] = None - logger.info("FHIR Patient read by ID", extra={"trace_id": trace_id, "resource_id": params.resource_id}) + logger.info( + "FHIR Patient read by ID", + extra={"trace_id": trace_id, "resource_id": params.resource_id}, + ) elif params.given_name or params.family_name or params.name: url = f"{base_url}/Patient" query_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, + ) + logger.info( + "FHIR Patient read by name fields", + extra={"trace_id": trace_id, "query_params": query_params}, ) - logger.info("FHIR Patient read by name fields", 
extra={"trace_id": trace_id, "query_params": query_params}) elif params.search_params: url = f"{base_url}/Patient" query_params = params.search_params - logger.info("FHIR Patient read by search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Patient read by search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError( "Provide resource_id, or name fields (given_name/family_name/name), " @@ -205,10 +220,17 @@ async def _read_patient( try: async with httpx.AsyncClient() as client: - response = await client.get(url, headers=auth_header, params=query_params, timeout=30.0) + response = await client.get( + url, headers=auth_header, params=query_params, timeout=30.0 + ) response.raise_for_status() except Exception as exc: - logger.error("FHIR Patient read failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Patient read failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -220,13 +242,12 @@ async def _read_patient( else: resource = data - logger.info("FHIR Patient read completed", extra={"trace_id": trace_id, "status_code": response.status_code}) + logger.info( + "FHIR Patient read completed", + extra={"trace_id": trace_id, "status_code": response.status_code}, + ) return FhirPatientReadOutput(resource=resource) - # ------------------------------------------------------------------ - # Action: search_patients (multi-ID fan-out OR name search) - # ------------------------------------------------------------------ - async def _search_patients( self, params: FhirPatientSearchInput, *, trace_id: str ) -> FhirPatientSearchOutput: @@ -257,7 +278,8 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ except Exception as exc: logger.warning( "FHIR Patient fetch failed | resource_id=%s | error=%s", - rid, str(exc), + rid, + str(exc), 
extra={"trace_id": trace_id}, ) return rid, None, str(exc) @@ -274,14 +296,20 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ logger.info( "FHIR Patient multi-ID lookup completed | found=%s | errors=%s", - len(resources), len(errors), + len(resources), + len(errors), extra={"trace_id": trace_id}, ) - return FhirPatientSearchOutput(resources=resources, total=len(resources), errors=errors) + return FhirPatientSearchOutput( + resources=resources, total=len(resources), errors=errors + ) name_params = self._build_name_search_params( - params.given_name, params.family_name, params.name, - params.birthdate, params.search_params, + params.given_name, + params.family_name, + params.name, + params.birthdate, + params.search_params, ) if not name_params: raise ValueError( @@ -307,14 +335,16 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ except httpx.HTTPStatusError as exc: logger.error( "FHIR Patient name search failed | status=%s | body=%s", - exc.response.status_code, exc.response.text, + exc.response.status_code, + exc.response.text, extra={"trace_id": trace_id}, ) raise except Exception as exc: logger.error( "FHIR Patient name search failed | error=%s: %s", - type(exc).__name__, str(exc), + type(exc).__name__, + str(exc), extra={"trace_id": trace_id}, ) raise @@ -327,15 +357,12 @@ async def _fetch_one(rid: str) -> tuple[str, Optional[Dict[str, Any]], Optional[ logger.info( "FHIR Patient name search completed | found=%s | total=%s", - len(resources), total, + len(resources), + total, extra={"trace_id": trace_id}, ) return FhirPatientSearchOutput(resources=resources, total=total) - # ------------------------------------------------------------------ - # Action: search_encounter - # ------------------------------------------------------------------ - async def _search_encounter( self, params: FhirEncounterSearchInput, *, trace_id: str ) -> FhirEncounterSearchOutput: @@ -346,24 +373,43 @@ async def 
_search_encounter( query_params = self._build_encounter_search_params( params.patient_id, params.status, params.date, params.search_params ) - logger.info("FHIR Encounter search by explicit fields", extra={"trace_id": trace_id, "query_params": query_params}) + logger.info( + "FHIR Encounter search by explicit fields", + extra={"trace_id": trace_id, "query_params": query_params}, + ) elif params.search_params: query_params = params.search_params - logger.info("FHIR Encounter search by raw params", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR Encounter search by raw params", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) else: raise ValueError("Provide at least patient_id, status, date OR search_params") try: async with httpx.AsyncClient() as client: response = await client.get( - f"{base_url}/Encounter", headers=auth_header, params=query_params, timeout=30.0, + f"{base_url}/Encounter", + headers=auth_header, + params=query_params, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR Encounter search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR Encounter search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR Encounter search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -372,13 +418,13 @@ async def _search_encounter( if data.get("resourceType") == "Bundle" and data.get("entry"): resources = [e["resource"] for e in data["entry"] if "resource" in e] - logger.info("FHIR Encounter search completed | found=%s", len(resources), extra={"trace_id": 
trace_id}) + logger.info( + "FHIR Encounter search completed | found=%s", + len(resources), + extra={"trace_id": trace_id}, + ) return FhirEncounterSearchOutput(resources=resources, total=total) - # ------------------------------------------------------------------ - # Action: create_document_reference - # ------------------------------------------------------------------ - async def _create_document_reference( self, params: FhirDocumentReferenceCreateInput, *, trace_id: str ) -> FhirDocumentReferenceCreateOutput: @@ -392,7 +438,14 @@ async def _create_document_reference( "type": params.type, "subject": {"reference": params.subject}, "date": datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), - "content": [{"attachment": {"contentType": params.content_type or "text/plain", "data": params.data}}], + "content": [ + { + "attachment": { + "contentType": params.content_type or "text/plain", + "data": params.data, + } + } + ], } if params.category: doc_ref["category"] = params.category @@ -410,7 +463,10 @@ async def _create_document_reference( try: async with httpx.AsyncClient() as client: response = await client.post( - f"{base_url}/DocumentReference", json=doc_ref, headers=auth_header, timeout=30.0, + f"{base_url}/DocumentReference", + json=doc_ref, + headers=auth_header, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -427,13 +483,19 @@ async def _create_document_reference( logger.error( "FHIR DocumentReference create failed | status=%s | epic_error=%s | sent_payload=%s", - exc.response.status_code, error_detail, json.dumps(doc_ref), + exc.response.status_code, + error_detail, + json.dumps(doc_ref), extra={"trace_id": trace_id}, ) - # Raise a more descriptive error for the API to catch raise ValueError(f"Epic Error: {error_detail}") from exc except Exception as exc: - logger.error("FHIR DocumentReference create failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR 
DocumentReference create failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise resource_id: Optional[str] = None @@ -442,7 +504,11 @@ async def _create_document_reference( location = response.headers.get("Location", "") if location: history_marker = location.find("/_history/") - resource_id = location[:history_marker].split("/")[-1] if history_marker != -1 else location.split("/")[-1] + resource_id = ( + location[:history_marker].split("/")[-1] + if history_marker != -1 + else location.split("/")[-1] + ) if not resource_id: content_length = response.headers.get("content-length", "0") @@ -459,12 +525,14 @@ async def _create_document_reference( f"Status: {response.status_code}, Location: {location!r}, Body: {response.text[:200]!r}" ) - logger.info("FHIR DocumentReference create completed | resource_id=%s", resource_id, extra={"trace_id": trace_id}) - return FhirDocumentReferenceCreateOutput(resource_id=resource_id, resource=body if body else None) - - # ------------------------------------------------------------------ - # Action: search_document_reference - # ------------------------------------------------------------------ + logger.info( + "FHIR DocumentReference create completed | resource_id=%s", + resource_id, + extra={"trace_id": trace_id}, + ) + return FhirDocumentReferenceCreateOutput( + resource_id=resource_id, resource=body if body else None + ) async def _search_document_reference( self, params: FhirDocumentReferenceSearchInput, *, trace_id: str @@ -472,19 +540,35 @@ async def _search_document_reference( base_url = self._get_base_url() auth_header = await self._get_auth_header() - logger.info("FHIR DocumentReference search", extra={"trace_id": trace_id, "search_params": params.search_params}) + logger.info( + "FHIR DocumentReference search", + extra={"trace_id": trace_id, "search_params": params.search_params}, + ) try: async with httpx.AsyncClient() as client: response = await client.get( - 
f"{base_url}/DocumentReference", headers=auth_header, params=params.search_params, timeout=30.0, + f"{base_url}/DocumentReference", + headers=auth_header, + params=params.search_params, + timeout=30.0, ) response.raise_for_status() except httpx.HTTPStatusError as exc: - logger.error("FHIR DocumentReference search failed | status=%s | body=%s", exc.response.status_code, exc.response.text, extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | status=%s | body=%s", + exc.response.status_code, + exc.response.text, + extra={"trace_id": trace_id}, + ) raise except Exception as exc: - logger.error("FHIR DocumentReference search failed | error=%s: %s", type(exc).__name__, str(exc), extra={"trace_id": trace_id}) + logger.error( + "FHIR DocumentReference search failed | error=%s: %s", + type(exc).__name__, + str(exc), + extra={"trace_id": trace_id}, + ) raise data = response.json() @@ -498,4 +582,4 @@ async def _search_document_reference( len(resources), extra={"trace_id": trace_id}, ) - return FhirDocumentReferenceSearchOutput(resources=resources, total=total) \ No newline at end of file + return FhirDocumentReferenceSearchOutput(resources=resources, total=total) diff --git a/src/connectors/fhir_epic/schema.py b/src/connectors/fhir_epic/schema.py index eaeef26..55d8103 100644 --- a/src/connectors/fhir_epic/schema.py +++ b/src/connectors/fhir_epic/schema.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, RootModel +from pydantic import BaseModel, Field # --------------------------------------------------------------------------- @@ -182,29 +182,11 @@ class FhirDocumentReferenceSearchOutput(BaseModel): """Total number of results reported by the Bundle.""" -# --------------------------------------------------------------------------- -# Unified operation input/output (one 
endpoint, multiple actions) -# --------------------------------------------------------------------------- - -_FhirEpicOperationUnion = Annotated[ - Union[ - FhirPatientReadInput, - FhirPatientSearchInput, - FhirEncounterSearchInput, - FhirDocumentReferenceCreateInput, - FhirDocumentReferenceSearchInput, - ], - Field(discriminator="action"), -] - -FhirEpicOperationInput = RootModel[_FhirEpicOperationUnion] - - class FhirEpicOperationOutput(BaseModel): """ - Combined output shape for schema documentation/manifest generation. + Unified output for all Epic FHIR actions (SDKConnector single output_model). - Individual handlers still return their specific output models. + Fields are populated depending on the action; unused fields are None. """ resource: Optional[Dict[str, Any]] = None diff --git a/src/connectors/google_drive/logic.py b/src/connectors/google_drive/logic.py index a4b2b3d..36e9107 100644 --- a/src/connectors/google_drive/logic.py +++ b/src/connectors/google_drive/logic.py @@ -1,17 +1,18 @@ from __future__ import annotations import asyncio -import json import base64 +import json import logging -from typing import Any, Union +from typing import Any from google.oauth2 import service_account from googleapiclient.discovery import build from googleapiclient.errors import HttpError from googleapiclient.http import MediaInMemoryUpload -from runtime import BaseConnector +from runtime import SDKConnector, sdk_action +from runtime.models import ErrorCategory from .exceptions import ( GoogleDriveAuthError, @@ -26,192 +27,47 @@ FilesListOperation, FilesUpdateOperation, FilesUploadOperation, - GoogleDriveOperationInput, GoogleDriveOperationOutput, PermissionsCreateOperation, ) logger = logging.getLogger("connectors.google_drive") -# Performant default for files.list so the API returns only needed metadata. 
DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" -_OperationUnion = Union[ - FilesCreateOperation, - FilesListOperation, - PermissionsCreateOperation, - FilesGetOperation, - FilesUpdateOperation, - FilesUploadOperation, - FilesDeleteOperation, -] - -class GoogleDriveConnector( - BaseConnector[GoogleDriveOperationInput, GoogleDriveOperationOutput], -): +class GoogleDriveConnector(SDKConnector): """ - Google Drive connector for files and permissions operations. + Google Drive connector: each Drive operation is an @sdk_action method. """ connector_id = "google_drive" action = "execute" + output_model = GoogleDriveOperationOutput - async def internal_execute( - self, params: GoogleDriveOperationInput, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Executing Google Drive operation", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "action_type": params.root.action, - }, - ) - - drive = self._build_client() + error_map = { + GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), + GoogleDriveRateLimitError: (ErrorCategory.RETRYABLE, "GDRIVE_RATE_LIMIT"), + GoogleDriveBusinessError: (ErrorCategory.BUSINESS, "GDRIVE_BUSINESS_RULE"), + GoogleDriveFatalError: (ErrorCategory.FATAL, "GDRIVE_FATAL"), + } + def build_client(self) -> Any: + raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") try: - response = await asyncio.to_thread( - self._dispatch_to_sdk, drive, params.root - ) - return GoogleDriveOperationOutput( - raw=response, - description=f"Successfully executed {params.root.action}", - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - except Exception as exc: # noqa: BLE001 - logger.error( - "Unexpected SDK failure", - extra={ - "trace_id": trace_id, - "connector_id": self.connector_id, - "action": self.action, - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) - raise GoogleDriveFatalError(str(exc)) from 
exc - - def _dispatch_to_sdk( - self, drive: Any, params: _OperationUnion - ) -> dict[str, Any]: - """Routes the strictly validated model to the correct SDK method.""" - if params.action == "files.create": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - return drive.files().create(body=body, fields='id, name, webViewLink', - supportsAllDrives=True, - ).execute() - - if params.action == "files.list": - fields = params.fields or DEFAULT_LIST_FIELDS - result = drive.files().list( - pageSize=params.page_size, - q=params.query, - fields=fields, - supportsAllDrives=True, - includeItemsFromAllDrives=True, - ).execute() - return result - - if params.action == "permissions.create": - body = { - "role": params.role, - "type": params.type, - "emailAddress": params.email_address, - } - return drive.permissions().create( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - ).execute() - - if params.action == "files.get": - fields = params.fields or "id,name,mimeType,parents" - return ( - drive.files() - .get( - fileId=params.file_id, - fields=fields, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.update": - body: dict[str, Any] = {} - if params.name is not None: - body["name"] = params.name - if params.mime_type is not None: - body["mimeType"] = params.mime_type - - kwargs: dict[str, Any] = {} - if params.add_parents: - kwargs["addParents"] = ",".join(params.add_parents) - if params.remove_parents: - kwargs["removeParents"] = ",".join(params.remove_parents) - - return ( - drive.files() - .update( - fileId=params.file_id, - body=body, - **kwargs, - supportsAllDrives=True, - ) - .execute() - ) - - if params.action == "files.upload": - body = { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - } - body = {k: v for k, v in body.items() if v is not None} - - if params.content_base64 is not None: 
- media_bytes = base64.b64decode(params.content_base64) - elif params.content is not None: - media_bytes = params.content.encode("utf-8") - else: - raise ValueError("Either content or content_base64 must be provided for files.upload") - - media = MediaInMemoryUpload( - media_bytes, - mimetype=params.mime_type, - resumable=False, + info = json.loads(raw_sa) + creds = service_account.Credentials.from_service_account_info( + info, + scopes=["https://www.googleapis.com/auth/drive"], ) - - return ( - drive.files() - .create( - body=body, - media_body=media, - fields='id, name, webViewLink', - supportsAllDrives=True, - ) - .execute() + except json.JSONDecodeError: + creds = service_account.Credentials.from_service_account_file( + raw_sa.strip(), + scopes=["https://www.googleapis.com/auth/drive"], ) - - if params.action == "files.delete": - drive.files().update(fileId=params.file_id, - body={'trashed': True}, - supportsAllDrives=True, - ).execute() - return {"file_id": params.file_id, "status": "deleted"} - - raise ValueError(f"Unmapped action router: {params.action}") + return build("drive", "v3", credentials=creds) def _translate_and_raise_http_error(self, exc: HttpError) -> None: - """Translates Google's dynamic HTTP errors into static taxonomy classes.""" status = exc.resp.status content_str = str(getattr(exc, "content", "") or "") @@ -220,9 +76,7 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: raise GoogleDriveRateLimitError( "Google Drive quota/rate limit exceeded" ) from exc - raise GoogleDriveAuthError( - "Authentication or permissions failure" - ) from exc + raise GoogleDriveAuthError("Authentication or permissions failure") from exc if status == 429 or status >= 500: raise GoogleDriveRateLimitError( @@ -231,26 +85,207 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: if status in (400, 404, 409): reason = getattr(exc, "reason", str(exc)) - raise GoogleDriveBusinessError( - f"Business logic failure: {reason}" - ) from exc 
+ raise GoogleDriveBusinessError(f"Business logic failure: {reason}") from exc - raise GoogleDriveFatalError( - f"Unhandled HttpError status {status}" - ) from exc + raise GoogleDriveFatalError(f"Unhandled HttpError status {status}") from exc - def _build_client(self) -> Any: - raw_sa = self.secret_provider.get_secret("GOOGLE_DRIVE_SA_JSON") + @sdk_action("files.create") + async def files_create( + self, params: FilesCreateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info("Google Drive files.create", extra={"trace_id": trace_id}) + drive = self.get_client() + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} try: - info = json.loads(raw_sa) - creds = service_account.Credentials.from_service_account_info( - info, - scopes=["https://www.googleapis.com/auth/drive"], + result = await asyncio.to_thread( + lambda: drive.files().create( + body=body, + fields="id, name, webViewLink", + supportsAllDrives=True, + ).execute() ) - except json.JSONDecodeError: - # Fallback: treat the secret as a file path - creds = service_account.Credentials.from_service_account_file( - raw_sa.strip(), - scopes=["https://www.googleapis.com/auth/drive"], + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.create" + ) + + @sdk_action("files.list") + async def files_list( + self, params: FilesListOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info("Google Drive files.list", extra={"trace_id": trace_id}) + drive = self.get_client() + fields = params.fields or DEFAULT_LIST_FIELDS + try: + result = await asyncio.to_thread( + lambda: drive.files().list( + pageSize=params.page_size, + q=params.query, + fields=fields, + pageToken=params.page_token, + supportsAllDrives=True, + includeItemsFromAllDrives=True, + ).execute() ) - return build("drive", "v3", 
credentials=creds) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.list" + ) + + @sdk_action("files.get") + async def files_get( + self, params: FilesGetOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.get", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + fields = params.fields or "id,name,mimeType,parents" + try: + result = await asyncio.to_thread( + lambda: drive.files().get( + fileId=params.file_id, + fields=fields, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.get" + ) + + @sdk_action("files.update") + async def files_update( + self, params: FilesUpdateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.update", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + body: dict[str, Any] = {} + if params.name is not None: + body["name"] = params.name + if params.mime_type is not None: + body["mimeType"] = params.mime_type + kwargs: dict[str, Any] = {} + if params.add_parents: + kwargs["addParents"] = ",".join(params.add_parents) + if params.remove_parents: + kwargs["removeParents"] = ",".join(params.remove_parents) + try: + result = await asyncio.to_thread( + lambda: drive.files().update( + fileId=params.file_id, + body=body, + supportsAllDrives=True, + **kwargs, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.update" + ) + + @sdk_action("files.upload") + async def files_upload( + self, params: FilesUploadOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + 
logger.info("Google Drive files.upload", extra={"trace_id": trace_id}) + drive = self.get_client() + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} + if params.content_base64 is not None: + media_bytes = base64.b64decode(params.content_base64) + elif params.content is not None: + media_bytes = params.content.encode("utf-8") + else: + raise ValueError( + "Either content or content_base64 must be provided for files.upload" + ) + media = MediaInMemoryUpload( + media_bytes, + mimetype=params.mime_type, + resumable=False, + ) + try: + result = await asyncio.to_thread( + lambda: drive.files().create( + body=body, + media_body=media, + fields="id, name, webViewLink", + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed files.upload" + ) + + @sdk_action("files.delete") + async def files_delete( + self, params: FilesDeleteOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive files.delete", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + try: + await asyncio.to_thread( + lambda: drive.files().update( + fileId=params.file_id, + body={"trashed": True}, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw={"file_id": params.file_id, "status": "deleted"}, + description="Successfully executed files.delete", + ) + + @sdk_action("permissions.create") + async def permissions_create( + self, params: PermissionsCreateOperation, *, trace_id: str + ) -> GoogleDriveOperationOutput: + logger.info( + "Google Drive permissions.create", + extra={"trace_id": trace_id, "file_id": params.file_id}, + ) + drive = self.get_client() + body: dict[str, Any] = { + "role": 
params.role, + "type": params.type, + } + if params.email_address: + body["emailAddress"] = params.email_address + if params.domain: + body["domain"] = params.domain + try: + result = await asyncio.to_thread( + lambda: drive.permissions().create( + fileId=params.file_id, + body=body, + supportsAllDrives=True, + ).execute() + ) + except HttpError as exc: + self._translate_and_raise_http_error(exc) + return GoogleDriveOperationOutput( + raw=result, description="Successfully executed permissions.create" + ) diff --git a/src/connectors/google_drive/schema.py b/src/connectors/google_drive/schema.py index a2f22e8..9d516e9 100644 --- a/src/connectors/google_drive/schema.py +++ b/src/connectors/google_drive/schema.py @@ -11,9 +11,6 @@ class BaseDriveOperation(BaseModel): model_config = ConfigDict(extra="forbid") -# --- Specific Operation Schemas --- - - class FilesCreateOperation(BaseDriveOperation): action: Literal["files.create"] name: str = Field(..., description="The name of the file.") @@ -32,6 +29,10 @@ class FilesListOperation(BaseDriveOperation): "uses a performant default: nextPageToken, files(id, name, mimeType, webViewLink)." ), ) + page_token: Optional[str] = Field( + None, + description="Token for the next page of results from a previous files.list response.", + ) class PermissionsCreateOperation(BaseDriveOperation): @@ -91,11 +92,7 @@ class FilesDeleteOperation(BaseDriveOperation): file_id: str -# --- The Envelope --- -# The runtime validates against this single type. Pydantic automatically -# routes the validation to the correct sub-model based on the "action" field. -# RootModel accepts **raw_input in __init__ so BaseConnector's _input_model_cls(**raw_input) works. 
-_OperationUnion = Annotated[ +_GoogleDriveOperationUnion = Annotated[ Union[ FilesCreateOperation, FilesListOperation, @@ -108,9 +105,10 @@ class FilesDeleteOperation(BaseDriveOperation): Field(discriminator="action"), ] -GoogleDriveOperationInput = RootModel[_OperationUnion] +# Discriminated union for tests/agents; must stay aligned with GoogleDriveConnector @sdk_action set. +GoogleDriveOperationInput = RootModel[_GoogleDriveOperationUnion] class GoogleDriveOperationOutput(BaseModel): raw: Dict[str, Any] - description: str \ No newline at end of file + description: str diff --git a/src/connectors/http_generic/logic.py b/src/connectors/http_generic/logic.py index 88afc67..2536cf6 100644 --- a/src/connectors/http_generic/logic.py +++ b/src/connectors/http_generic/logic.py @@ -38,8 +38,6 @@ async def internal_execute(self, params: HttpRequestInput, *, trace_id: str) -> }, ) - print(f"trace_id: {trace_id} from node-wire-connector") - try: async with httpx.AsyncClient() as client: response = await client.request( diff --git a/src/connectors/manifest.py b/src/connectors/manifest.py index ffb5cfa..56984d9 100644 --- a/src/connectors/manifest.py +++ b/src/connectors/manifest.py @@ -4,76 +4,34 @@ from pydantic import BaseModel -from runtime import BaseConnector +from runtime import BaseConnector, SDKConnector def _schema_for(model: Type[BaseModel]) -> Dict[str, Any]: return model.model_json_schema() -def _fhir_action_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]: - """Return per-action input model classes for FHIR connectors (lazy import).""" - from connectors.fhir_cerner.schema import ( - FhirCernerDocumentReferenceCreateInput, - FhirCernerDocumentReferenceSearchInput, - FhirCernerEncounterSearchInput, - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - ) - from connectors.fhir_epic.schema import ( - FhirDocumentReferenceCreateInput, - FhirDocumentReferenceSearchInput, - FhirEncounterSearchInput, - FhirPatientReadInput, - FhirPatientSearchInput, - ) - - 
return { - "fhir_cerner": { - "read_patient": FhirCernerPatientReadInput, - "search_patients": FhirCernerPatientSearchInput, - "search_encounter": FhirCernerEncounterSearchInput, - "create_document_reference": FhirCernerDocumentReferenceCreateInput, - "search_document_reference": FhirCernerDocumentReferenceSearchInput, - }, - "fhir_epic": { - "read_patient": FhirPatientReadInput, - "search_patients": FhirPatientSearchInput, - "search_encounter": FhirEncounterSearchInput, - "create_document_reference": FhirDocumentReferenceCreateInput, - "search_document_reference": FhirDocumentReferenceSearchInput, - }, - } - - def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, Any]]: """ - Build a simple manifest for discovery. - - Each entry describes a connector/action pair and includes JSON Schemas - for the input and output models. This is consumed by Layer C for - REST route generation and MCP tool manifests. + One manifest entry per SDK @sdk_action (specific input/output schemas), + or one entry per legacy BaseConnector. 
""" manifest: List[Dict[str, Any]] = [] - fhir_schemas: Dict[str, Dict[str, Type[BaseModel]]] | None = None - for connector in connectors: - output_model = connector._output_model_cls # type: ignore[attr-defined] cid = connector.connector_id - if getattr(connector, "action", None) == "execute" and cid in ("fhir_cerner", "fhir_epic"): - if fhir_schemas is None: - fhir_schemas = _fhir_action_schemas() - for sub_action, input_cls in fhir_schemas[cid].items(): + if isinstance(connector, SDKConnector): + for action_name, meta in type(connector).sdk_action_metas().items(): manifest.append( { "connector_id": cid, - "action": sub_action, - "input_schema": _schema_for(input_cls), - "output_schema": _schema_for(output_model), + "action": action_name, + "input_schema": _schema_for(meta.input_model), + "output_schema": _schema_for(meta.output_model), } ) else: input_model = connector._input_model_cls # type: ignore[attr-defined] + output_model = connector._output_model_cls # type: ignore[attr-defined] manifest.append( { "connector_id": cid, @@ -83,4 +41,3 @@ def build_manifest(connectors: List[BaseConnector[Any, Any]]) -> List[Dict[str, } ) return manifest - diff --git a/src/connectors/stripe/logic.py b/src/connectors/stripe/logic.py index 14e973f..aefa296 100644 --- a/src/connectors/stripe/logic.py +++ b/src/connectors/stripe/logic.py @@ -1,54 +1,67 @@ from __future__ import annotations +import asyncio import logging import stripe -from runtime import BaseConnector +from runtime import SDKConnector, sdk_action +from runtime.models import ErrorCategory from .schema import ChargeInput, ChargeOutput logger = logging.getLogger("connectors.stripe") -class StripeChargeConnector(BaseConnector[ChargeInput, ChargeOutput]): - """ - Stripe connector for creating charges using the official Stripe SDK. 
- """ +class StripeConnector(SDKConnector): + """Stripe connector: charges and future SDK operations as @sdk_action methods.""" connector_id = "stripe" action = "charge" + output_model = ChargeOutput - async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: - # API key is expected to be provided by SecretProvider. + error_map = { + stripe.error.RateLimitError: (ErrorCategory.RETRYABLE, "STRIPE_RATE_LIMIT"), + stripe.error.APIConnectionError: (ErrorCategory.RETRYABLE, "STRIPE_API_CONNECTION"), + stripe.error.CardError: (ErrorCategory.BUSINESS, "STRIPE_CARD_ERROR"), + stripe.error.InvalidRequestError: (ErrorCategory.BUSINESS, "STRIPE_INVALID_REQUEST"), + stripe.error.AuthenticationError: (ErrorCategory.AUTH, "STRIPE_AUTH_ERROR"), + stripe.error.StripeError: (ErrorCategory.FATAL, "STRIPE_ERROR"), + } + + @sdk_action("charge") + async def charge(self, params: ChargeInput, *, trace_id: str) -> ChargeOutput: api_key = self.secret_provider.get_secret("stripe_api_key") - stripe.api_key = api_key logger.info( "Creating Stripe charge", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "amount": params.amount, "currency": params.currency, }, ) - try: - charge = await stripe.Charge.create( # type: ignore[attr-defined] + def _create() -> stripe.Charge: + stripe.api_key = api_key + return stripe.Charge.create( amount=params.amount, currency=params.currency, source=params.source, description=params.description, ) - except Exception as exc: # noqa: BLE001 + + try: + charge = await asyncio.to_thread(_create) + except Exception as exc: logger.error( "Stripe charge creation failed", extra={ "trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "amount": params.amount, "currency": params.currency, "error_type": type(exc).__name__, @@ -62,7 +75,7 @@ async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> Charg extra={ 
"trace_id": trace_id, "connector_id": self.connector_id, - "action": self.action, + "action": "charge", "charge_id": charge.get("id"), }, ) @@ -71,4 +84,3 @@ async def internal_execute(self, params: ChargeInput, *, trace_id: str) -> Charg charge_id=charge.get("id"), receipt_url=charge.get("receipt_url"), ) - diff --git a/src/connectors/stripe/schema.py b/src/connectors/stripe/schema.py index bf7e6f6..e912829 100644 --- a/src/connectors/stripe/schema.py +++ b/src/connectors/stripe/schema.py @@ -1,9 +1,12 @@ from __future__ import annotations +from typing import Literal + from pydantic import BaseModel class ChargeInput(BaseModel): + action: Literal["charge"] = "charge" amount: int currency: str source: str @@ -13,4 +16,3 @@ class ChargeInput(BaseModel): class ChargeOutput(BaseModel): charge_id: str receipt_url: str | None = None - diff --git a/src/runtime/__init__.py b/src/runtime/__init__.py index 76d63e9..1e5c11f 100644 --- a/src/runtime/__init__.py +++ b/src/runtime/__init__.py @@ -3,6 +3,7 @@ from .base import BaseConnector from .secrets import SecretProvider from .policy import PolicyHook, PolicyDenied +from .sdk_connector import SDKConnector, sdk_action, _CONNECTOR_REGISTRY __all__ = [ "ConnectorResponse", @@ -12,4 +13,7 @@ "SecretProvider", "PolicyHook", "PolicyDenied", + "SDKConnector", + "sdk_action", + "_CONNECTOR_REGISTRY", ] diff --git a/src/runtime/base.py b/src/runtime/base.py index 25596d9..e470350 100644 --- a/src/runtime/base.py +++ b/src/runtime/base.py @@ -74,7 +74,6 @@ async def run( - Maps exceptions into the standard error taxonomy """ trace_id = str(uuid.uuid4()) - print(f"trace_id: {trace_id} from runtime.base") with tracer.start_as_current_span( "connector.run", @@ -97,7 +96,7 @@ async def run( try: try: - input_model = self._input_model_cls(**raw_input) + input_model = self._input_model_cls.model_validate(raw_input) except ValidationError as exc: logger.error( "Input validation failed", diff --git a/src/runtime/sdk_connector.py 
b/src/runtime/sdk_connector.py new file mode 100644 index 0000000..faa9688 --- /dev/null +++ b/src/runtime/sdk_connector.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import inspect +import logging +import uuid +from dataclasses import dataclass +from typing import ( + Annotated, + Any, + ClassVar, + Dict, + Optional, + Tuple, + Type, + Union, + get_type_hints, +) + +from pydantic import BaseModel, Field, RootModel + +from .base import BaseConnector +from .errors import ErrorMapper +from .models import ErrorCategory +from .secrets import SecretProvider + +logger = logging.getLogger("runtime.sdk_connector") + +# Populated by SDKConnector.__init_subclass__ +_CONNECTOR_REGISTRY: Dict[str, Type["SDKConnector"]] = {} + + +def sdk_action(name: str): + """ + Mark a connector method as a named, auto-discoverable SDK action. + + The decorated method must be async and have full type annotations for its + params (first arg after self) and return type. + """ + + def decorator(fn: Any) -> Any: + fn._sdk_action_name = name + return fn + + return decorator + + +@dataclass +class SdkActionMeta: + """Metadata for one @sdk_action method.""" + + name: str + fn_name: str + input_model: Type[BaseModel] + output_model: Type[BaseModel] + + +class SDKConnector(BaseConnector): + """ + Base class for SDK-backed connectors. + + Subclasses define: + - connector_id: str + - output_model: Type[BaseModel] (common output envelope for all actions) + - error_map: optional mapping of exception -> (ErrorCategory, code) + - build_client() / get_client() for vendor SDK lifecycle + + Actions are declared with @sdk_action("resource.operation") on async methods. 
+ """ + + connector_id: str + action: str = "execute" + + error_map: ClassVar[Dict[Type[BaseException], Tuple[ErrorCategory, str]]] = {} + output_model: ClassVar[Type[BaseModel]] + + _action_registry: ClassVar[Dict[str, SdkActionMeta]] + _union_input_model: ClassVar[Type[RootModel[Any]]] + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + + registry: Dict[str, SdkActionMeta] = {} + for attr_name in dir(cls): + method = getattr(cls, attr_name, None) + if not callable(method) or not hasattr(method, "_sdk_action_name"): + continue + + try: + hints = get_type_hints(method) + except Exception: + hints = {} + + try: + sig_params = [ + p + for p in inspect.signature(method).parameters.values() + if p.name not in ("self", "trace_id") + ] + input_param_name = sig_params[0].name if sig_params else None + except (ValueError, TypeError): + input_param_name = None + + if not input_param_name: + raise TypeError( + f"{cls.__name__}.{attr_name}: @sdk_action method must have a params argument " + "after self" + ) + + input_model = hints.get(input_param_name) + output_model = hints.get("return") + if input_model is None or not isinstance(input_model, type) or not issubclass( + input_model, BaseModel + ): + raise TypeError( + f"{cls.__name__}.{attr_name}: missing or invalid type hint for " + f"parameter {input_param_name!r}" + ) + if output_model is None or not isinstance(output_model, type) or not issubclass( + output_model, BaseModel + ): + raise TypeError( + f"{cls.__name__}.{attr_name}: missing or invalid return type hint" + ) + + action_name = method._sdk_action_name + registry[action_name] = SdkActionMeta( + name=action_name, + fn_name=attr_name, + input_model=input_model, + output_model=output_model, + ) + + cls._action_registry = registry + + valid_models = [m.input_model for m in registry.values()] + if not valid_models: + raise TypeError(f"{cls.__name__}: SDKConnector must define at least one @sdk_action") + + if len(valid_models) 
== 1: + root_type = valid_models[0] + else: + root_type = Annotated[ + Union[tuple(valid_models)], # type: ignore[arg-type] + Field(discriminator="action"), + ] + + cls._union_input_model = RootModel[root_type] # type: ignore[valid-type] + cls._union_input_model.model_rebuild() + + own_error_map = cls.__dict__.get("error_map", {}) + for exc_type, (category, code) in own_error_map.items(): + ErrorMapper.register(exc_type, category, code=code) + + if "connector_id" in cls.__dict__: + _CONNECTOR_REGISTRY[cls.connector_id] = cls + logger.debug( + "Registered SDKConnector subclass", + extra={"connector_id": cls.connector_id}, + ) + + def __init__(self, *, secret_provider: Optional[SecretProvider] = None) -> None: + cls = type(self) + super().__init__( + cls._union_input_model, + cls.output_model, + secret_provider=secret_provider, + ) + self._client: Any = None + + @classmethod + def sdk_action_metas(cls) -> Dict[str, SdkActionMeta]: + """Registry of action name -> metadata (for manifest).""" + return dict(cls._action_registry) + + def build_client(self) -> Any: + """Override in subclasses to build the vendor SDK client.""" + return None + + def get_client(self) -> Any: + if self._client is None: + self._client = self.build_client() + return self._client + + async def internal_execute(self, params: Any, *, trace_id: str) -> Any: + """Dispatch to the @sdk_action method matching the validated input.""" + root = params.root if hasattr(params, "root") else params + action_key = getattr(root, "action", None) + if action_key is None: + raise ValueError(f"Input model missing action discriminator: {type(root).__name__}") + + meta = self._action_registry.get(str(action_key)) + if meta is None: + raise ValueError( + f"Connector {self.connector_id!r} has no registered action {action_key!r}. 
" + f"Available: {list(self._action_registry)}" + ) + fn = getattr(self, meta.fn_name) + logger.debug( + "Dispatching sdk_action", + extra={ + "connector_id": self.connector_id, + "action": action_key, + "trace_id": trace_id, + }, + ) + return await fn(root, trace_id=trace_id) + + async def call_action(self, name: str, params_dict: Dict[str, Any]) -> Any: + """Invoke another action by name (for composite operations).""" + meta = self._action_registry.get(name) + if meta is None: + raise ValueError( + f"call_action: unknown action {name!r} on connector {self.connector_id!r}" + ) + validated = meta.input_model.model_validate(params_dict) + fn = getattr(self, meta.fn_name) + return await fn(validated, trace_id=str(uuid.uuid4())) diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index a4e633e..f5db5b7 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -8,8 +8,7 @@ from connectors.http_generic.schema import HttpRequestInput, HttpResponseOutput from connectors.smtp.logic import SmtpConnector from connectors.smtp.schema import SmtpSendInput, SmtpSendOutput -from connectors.stripe.logic import StripeChargeConnector -from connectors.stripe.schema import ChargeInput, ChargeOutput +from connectors.stripe.logic import StripeConnector from runtime import ConnectorResponse, ErrorCategory, SecretProvider from connectors import auto_register @@ -25,6 +24,7 @@ def get_secret(self, key: str) -> str: def test_auto_register_runs_without_error(): imported = auto_register() assert any("http_generic.registration" in name for name in imported) + assert any("google_drive.logic" in name for name in imported) def test_http_connector_instantiation_only(): @@ -40,7 +40,7 @@ def test_smtp_connector_instantiation_only(): def test_stripe_connector_instantiation_only(): - connector = StripeChargeConnector(ChargeInput, ChargeOutput, secret_provider=DummySecretProvider()) + connector = StripeConnector(secret_provider=DummySecretProvider()) 
assert connector.connector_id == "stripe" assert connector.action == "charge" diff --git a/tests/test_google_drive.py b/tests/test_google_drive.py index 286d7a2..d9768ae 100644 --- a/tests/test_google_drive.py +++ b/tests/test_google_drive.py @@ -34,11 +34,7 @@ def __init__(self, status: int, *, content: str = "", reason: str = "") -> None: def _connector() -> GoogleDriveConnector: - return GoogleDriveConnector( - input_model=GoogleDriveOperationInput, - output_model=GoogleDriveOperationOutput, - secret_provider=MockSecretProvider(), - ) + return GoogleDriveConnector(secret_provider=MockSecretProvider()) def test_google_drive_internal_execute_files_list_happy_path(): @@ -50,7 +46,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): list_call = files_api.list.return_value list_call.execute.return_value = {"files": [{"id": "f-1", "name": "Report"}]} - with patch.object(connector, "_build_client", return_value=drive): + with patch.object(connector, "get_client", return_value=drive): result = asyncio.run(connector.internal_execute(params, trace_id="test-trace")) assert result.raw == {"files": [{"id": "f-1", "name": "Report"}]} @@ -59,6 +55,7 @@ def test_google_drive_internal_execute_files_list_happy_path(): pageSize=5, q=None, fields=DEFAULT_LIST_FIELDS, + pageToken=None, supportsAllDrives=True, includeItemsFromAllDrives=True, ) diff --git a/tests/test_sdk_connector_manifest.py b/tests/test_sdk_connector_manifest.py new file mode 100644 index 0000000..504d5a1 --- /dev/null +++ b/tests/test_sdk_connector_manifest.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from bindings.factory import ConnectorFactory +from connectors import auto_register +from connectors.manifest import build_manifest +from connectors.stripe.schema import ChargeInput +from runtime import SDKConnector +from runtime.sdk_connector import _CONNECTOR_REGISTRY + + +def test_registry_contains_sdk_connectors(): + auto_register() + assert "google_drive" in _CONNECTOR_REGISTRY + 
assert "stripe" in _CONNECTOR_REGISTRY + assert "fhir_epic" in _CONNECTOR_REGISTRY + + +def test_manifest_emits_per_sdk_action(): + auto_register() + factory = ConnectorFactory() + factory.load() + rest_manifest = build_manifest(factory.list_for_protocol("rest")) + rest_actions = {(e["connector_id"], e["action"]) for e in rest_manifest} + assert ("google_drive", "files.list") in rest_actions + assert ("fhir_epic", "read_patient") in rest_actions + assert ("stripe", "charge") not in rest_actions # stripe is grpc/mcp only in config + + mcp_manifest = build_manifest(factory.list_for_protocol("mcp")) + mcp_actions = {(e["connector_id"], e["action"]) for e in mcp_manifest} + assert ("stripe", "charge") in mcp_actions + # Per-action input schema should not be the full union for SDK connectors + for entry in mcp_manifest: + if entry["connector_id"] == "stripe": + props = entry["input_schema"].get("properties", {}) + assert "amount" in props + + +def test_stripe_connector_is_sdk_and_accepts_charge_payload(): + auto_register() + factory = ConnectorFactory() + factory.load() + connector = factory.get_for_protocol("stripe", "grpc") + assert connector is not None + assert isinstance(connector, SDKConnector) + validated = ChargeInput.model_validate( + {"action": "charge", "amount": 100, "currency": "usd", "source": "tok_visa"} + ) + assert validated.action == "charge" + + +def test_mcp_tool_invoke_sets_action(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + names = {t["name"] for t in tools} + assert "google_drive.files.list" in names + assert "stripe.charge" in names From cab653b92c89e1da10cad5430cdcf55f46703d96 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 20:58:35 -0700 Subject: [PATCH 04/15] Updated architecture --- src/connectors/google_drive/README.md | 18 +- src/connectors/google_drive/action_spec.py | 188 +++++++++++++++++ 
src/connectors/google_drive/logic.py | 232 +++------------------ src/connectors/google_drive/schema.py | 9 +- src/runtime/__init__.py | 10 + src/runtime/sdk_action_spec.py | 123 +++++++++++ src/runtime/sdk_connector.py | 79 +++++++ tests/test_google_drive_action_spec.py | 132 ++++++++++++ 8 files changed, 580 insertions(+), 211 deletions(-) create mode 100644 src/connectors/google_drive/action_spec.py create mode 100644 src/runtime/sdk_action_spec.py create mode 100644 tests/test_google_drive_action_spec.py diff --git a/src/connectors/google_drive/README.md b/src/connectors/google_drive/README.md index 3b44409..8d644d0 100644 --- a/src/connectors/google_drive/README.md +++ b/src/connectors/google_drive/README.md @@ -2,7 +2,7 @@ > **Platform:** Node Wire > **Connector ID:** `google_drive` -> **Endpoint:** `POST /connectors/google_drive/execute` +> **REST:** One route per operation, e.g. `POST /connectors/google_drive/files.list` (the `action` field is still set on the body for `SDKConnector` dispatch). > **Discriminator:** `action` field (discriminated-union payload) > **Source:** `connectors/google_drive/` @@ -10,7 +10,21 @@ ## 1. Operations Overview -All requests go through a single `execute` endpoint. The `action` field determines which Google Drive operation runs. All responses share a common output shape and error taxonomy enforced by the runtime. +The runtime validates requests against the discriminated union in `schema.py`, then dispatches to `@sdk_action` handlers on `GoogleDriveConnector`. Each handler delegates to an **action spec** in `action_spec.py` that maps the validated model to the Google Drive API v3 client (`googleapiclient`). Shared concerns (thread offload, `HttpError` translation, logging) stay in `logic.py`. All responses share a common output shape and error taxonomy enforced by the runtime. 
+ +### Action-spec layout + +| Piece | Role | +|-------|------| +| [`action_spec.py`](action_spec.py) | `GOOGLE_DRIVE_ACTION_SPECS`: per-action `SdkActionSpec` (resource path, method, field/body mapping, constants, optional `build_kwargs` / `post_process`). | +| [`logic.py`](logic.py) | Client build, `_translate_and_raise_http_error`, `_execute_action_spec`; the per-action handlers are auto-generated from `action_specs`, so no hand-written `@sdk_action` methods live here. | +| [`runtime/sdk_action_spec.py`](../../runtime/sdk_action_spec.py) | Reusable primitives: `SdkActionSpec`, `default_build_kwargs`, `execute_spec_in_thread`. | + +**Adding a new operation:** Add a Pydantic variant in `schema.py` (with an `action` discriminator literal), extend the `GoogleDriveOperationInput` union, and add an entry to `GOOGLE_DRIVE_ACTION_SPECS` in `action_spec.py` (or a `build_kwargs` hook for non-generic cases such as multipart upload). `SDKConnector.__init_subclass__` auto-generates the handler — do **not** also add an `@sdk_action` method for the same action name, as that will raise a `TypeError` at class-definition time. + +### Migrating other SDK connectors + +Use the same pattern: put declarative mapping in a connector-local `*_action_spec` module; `SDKConnector.__init_subclass__` auto-generates `@sdk_action`-equivalent handlers from `action_specs`, so no manual `@sdk_action` decorators are needed for spec-driven actions. Use `SdkActionSpec.build_kwargs` when the vendor API needs custom assembly (uploads, explicit `None` args, etc.). ### Available Operations diff --git a/src/connectors/google_drive/action_spec.py b/src/connectors/google_drive/action_spec.py new file mode 100644 index 0000000..4b32fee --- /dev/null +++ b/src/connectors/google_drive/action_spec.py @@ -0,0 +1,188 @@ +""" +Google Drive action specs: mapping from validated Pydantic inputs to Drive API v3 calls. + +Used by GoogleDriveConnector to reduce per-action boilerplate while preserving +behavior (defaults, field masks, shared drives flags). 
+""" + +from __future__ import annotations + +import base64 +from typing import Any, Dict + +from googleapiclient.http import MediaInMemoryUpload +from pydantic import BaseModel + +from runtime.sdk_action_spec import SdkActionSpec + +from .schema import ( + FilesCreateOperation, + FilesDeleteOperation, + FilesGetOperation, + FilesListOperation, + FilesUpdateOperation, + FilesUploadOperation, + PermissionsCreateOperation, +) + +DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" + +# Action name -> SdkActionSpec (matches @sdk_action("...") strings) +GOOGLE_DRIVE_ACTION_SPECS: Dict[str, SdkActionSpec] = {} + + +def _register_files_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.create"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + body_from_model={ + "name": "name", + "mime_type": "mimeType", + "parents": "parents", + }, + constant_kwargs={ + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + }, + input_model=FilesCreateOperation, + ) + + +def _build_files_list_kwargs(_drive: Any, model: BaseModel) -> Dict[str, Any]: + """Match legacy behavior: pass q/pageToken explicitly even when None.""" + p = model if isinstance(model, FilesListOperation) else FilesListOperation.model_validate( + model + ) + return { + "pageSize": p.page_size, + "q": p.query, + "fields": p.fields or DEFAULT_LIST_FIELDS, + "pageToken": p.page_token, + "supportsAllDrives": True, + "includeItemsFromAllDrives": True, + } + + +def _register_files_list() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.list"] = SdkActionSpec( + resource_segments=("files",), + method_name="list", + build_kwargs=_build_files_list_kwargs, + input_model=FilesListOperation, + ) + + +def _register_files_get() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.get"] = SdkActionSpec( + resource_segments=("files",), + method_name="get", + kwargs_from_model={"file_id": "fileId"}, + computed_kwargs={ + "fields": lambda p: p.fields or "id,name,mimeType,parents", + }, 
+ constant_kwargs={"supportsAllDrives": True}, + input_model=FilesGetOperation, + ) + + +def _register_files_update() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.update"] = SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "name": "name", + "mime_type": "mimeType", + }, + computed_kwargs={ + "addParents": lambda p: ",".join(p.add_parents) if p.add_parents else None, + "removeParents": lambda p: ",".join(p.remove_parents) if p.remove_parents else None, + }, + constant_kwargs={"supportsAllDrives": True}, + include_empty_body=True, + input_model=FilesUpdateOperation, + ) + + +def _build_upload_kwargs(drive: Any, model: BaseModel) -> Dict[str, Any]: + params = model if isinstance(model, FilesUploadOperation) else FilesUploadOperation.model_validate( + model + ) + body = {k: v for k, v in { + "name": params.name, + "mimeType": params.mime_type, + "parents": params.parents, + }.items() if v is not None} + if params.content_base64 is not None: + media_bytes = base64.b64decode(params.content_base64) + elif params.content is not None: + media_bytes = params.content.encode("utf-8") + else: + raise ValueError( + "Either content or content_base64 must be provided for files.upload" + ) + media = MediaInMemoryUpload( + media_bytes, + mimetype=params.mime_type, + resumable=False, + ) + return { + "body": body, + "media_body": media, + "fields": "id, name, webViewLink", + "supportsAllDrives": True, + } + + +def _register_files_upload() -> None: + GOOGLE_DRIVE_ACTION_SPECS["files.upload"] = SdkActionSpec( + resource_segments=("files",), + method_name="create", + build_kwargs=_build_upload_kwargs, + input_model=FilesUploadOperation, + ) + + +def _register_files_delete() -> None: + def _post_delete(_result: Any, model: BaseModel) -> Dict[str, Any]: + file_id = getattr(model, "file_id", None) + return {"file_id": file_id, "status": "deleted"} + + GOOGLE_DRIVE_ACTION_SPECS["files.delete"] = 
SdkActionSpec( + resource_segments=("files",), + method_name="update", + kwargs_from_model={"file_id": "fileId"}, + body_constant={"trashed": True}, + constant_kwargs={"supportsAllDrives": True}, + post_process=_post_delete, + input_model=FilesDeleteOperation, + ) + + +def _register_permissions_create() -> None: + GOOGLE_DRIVE_ACTION_SPECS["permissions.create"] = SdkActionSpec( + resource_segments=("permissions",), + method_name="create", + kwargs_from_model={"file_id": "fileId"}, + body_from_model={ + "role": "role", + "type": "type", + "email_address": "emailAddress", + "domain": "domain", + }, + constant_kwargs={"supportsAllDrives": True}, + input_model=PermissionsCreateOperation, + ) + + +def _init_specs() -> None: + _register_files_create() + _register_files_list() + _register_files_get() + _register_files_update() + _register_files_upload() + _register_files_delete() + _register_permissions_create() + + +_init_specs() diff --git a/src/connectors/google_drive/logic.py b/src/connectors/google_drive/logic.py index 36e9107..f3a5408 100644 --- a/src/connectors/google_drive/logic.py +++ b/src/connectors/google_drive/logic.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio -import base64 import json import logging from typing import Any @@ -9,41 +7,36 @@ from google.oauth2 import service_account from googleapiclient.discovery import build from googleapiclient.errors import HttpError -from googleapiclient.http import MediaInMemoryUpload -from runtime import SDKConnector, sdk_action +from runtime import SDKConnector from runtime.models import ErrorCategory +from runtime.sdk_action_spec import execute_spec_in_thread +from .action_spec import DEFAULT_LIST_FIELDS, GOOGLE_DRIVE_ACTION_SPECS from .exceptions import ( GoogleDriveAuthError, GoogleDriveBusinessError, GoogleDriveFatalError, GoogleDriveRateLimitError, ) -from .schema import ( - FilesCreateOperation, - FilesDeleteOperation, - FilesGetOperation, - FilesListOperation, - FilesUpdateOperation, - 
FilesUploadOperation, - GoogleDriveOperationOutput, - PermissionsCreateOperation, -) +from .schema import GoogleDriveOperationOutput logger = logging.getLogger("connectors.google_drive") -DEFAULT_LIST_FIELDS = "nextPageToken, files(id, name, mimeType, webViewLink)" +# Re-export for tests and callers that imported from logic. +__all__ = ["DEFAULT_LIST_FIELDS", "GoogleDriveConnector"] class GoogleDriveConnector(SDKConnector): """ - Google Drive connector: each Drive operation is an @sdk_action method. + Google Drive connector: Drive API v3 operations are driven by action specs + (see action_spec.py) and thin @sdk_action handlers for logging and dispatch. """ connector_id = "google_drive" action = "execute" output_model = GoogleDriveOperationOutput + action_specs = GOOGLE_DRIVE_ACTION_SPECS error_map = { GoogleDriveAuthError: (ErrorCategory.AUTH, "GDRIVE_AUTH"), @@ -89,203 +82,26 @@ def _translate_and_raise_http_error(self, exc: HttpError) -> None: raise GoogleDriveFatalError(f"Unhandled HttpError status {status}") from exc - @sdk_action("files.create") - async def files_create( - self, params: FilesCreateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.create", extra={"trace_id": trace_id}) - drive = self.get_client() - body = {k: v for k, v in { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - }.items() if v is not None} - try: - result = await asyncio.to_thread( - lambda: drive.files().create( - body=body, - fields="id, name, webViewLink", - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.create" - ) - - @sdk_action("files.list") - async def files_list( - self, params: FilesListOperation, *, trace_id: str + async def _execute_action_spec( + self, + action_name: str, + params: Any, + *, + trace_id: str, + log_extra: dict[str, 
Any] | None = None, ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.list", extra={"trace_id": trace_id}) + spec = GOOGLE_DRIVE_ACTION_SPECS.get(action_name) + if spec is None: + raise ValueError(f"No action spec registered for {action_name!r}") drive = self.get_client() - fields = params.fields or DEFAULT_LIST_FIELDS + extra = {"trace_id": trace_id, **(log_extra or {})} + logger.info("Google Drive %s", action_name, extra=extra) try: - result = await asyncio.to_thread( - lambda: drive.files().list( - pageSize=params.page_size, - q=params.query, - fields=fields, - pageToken=params.page_token, - supportsAllDrives=True, - includeItemsFromAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.list" - ) - - @sdk_action("files.get") - async def files_get( - self, params: FilesGetOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.get", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - fields = params.fields or "id,name,mimeType,parents" - try: - result = await asyncio.to_thread( - lambda: drive.files().get( - fileId=params.file_id, - fields=fields, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.get" - ) - - @sdk_action("files.update") - async def files_update( - self, params: FilesUpdateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.update", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - body: dict[str, Any] = {} - if params.name is not None: - body["name"] = params.name - if params.mime_type is not None: - body["mimeType"] = params.mime_type - kwargs: dict[str, Any] = 
{} - if params.add_parents: - kwargs["addParents"] = ",".join(params.add_parents) - if params.remove_parents: - kwargs["removeParents"] = ",".join(params.remove_parents) - try: - result = await asyncio.to_thread( - lambda: drive.files().update( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - **kwargs, - ).execute() - ) + raw = await execute_spec_in_thread(drive, spec, params) except HttpError as exc: self._translate_and_raise_http_error(exc) return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.update" + raw=raw, + description=f"Successfully executed {action_name}", ) - @sdk_action("files.upload") - async def files_upload( - self, params: FilesUploadOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info("Google Drive files.upload", extra={"trace_id": trace_id}) - drive = self.get_client() - body = {k: v for k, v in { - "name": params.name, - "mimeType": params.mime_type, - "parents": params.parents, - }.items() if v is not None} - if params.content_base64 is not None: - media_bytes = base64.b64decode(params.content_base64) - elif params.content is not None: - media_bytes = params.content.encode("utf-8") - else: - raise ValueError( - "Either content or content_base64 must be provided for files.upload" - ) - media = MediaInMemoryUpload( - media_bytes, - mimetype=params.mime_type, - resumable=False, - ) - try: - result = await asyncio.to_thread( - lambda: drive.files().create( - body=body, - media_body=media, - fields="id, name, webViewLink", - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed files.upload" - ) - - @sdk_action("files.delete") - async def files_delete( - self, params: FilesDeleteOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive files.delete", - extra={"trace_id": trace_id, "file_id": 
params.file_id}, - ) - drive = self.get_client() - try: - await asyncio.to_thread( - lambda: drive.files().update( - fileId=params.file_id, - body={"trashed": True}, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw={"file_id": params.file_id, "status": "deleted"}, - description="Successfully executed files.delete", - ) - - @sdk_action("permissions.create") - async def permissions_create( - self, params: PermissionsCreateOperation, *, trace_id: str - ) -> GoogleDriveOperationOutput: - logger.info( - "Google Drive permissions.create", - extra={"trace_id": trace_id, "file_id": params.file_id}, - ) - drive = self.get_client() - body: dict[str, Any] = { - "role": params.role, - "type": params.type, - } - if params.email_address: - body["emailAddress"] = params.email_address - if params.domain: - body["domain"] = params.domain - try: - result = await asyncio.to_thread( - lambda: drive.permissions().create( - fileId=params.file_id, - body=body, - supportsAllDrives=True, - ).execute() - ) - except HttpError as exc: - self._translate_and_raise_http_error(exc) - return GoogleDriveOperationOutput( - raw=result, description="Successfully executed permissions.create" - ) diff --git a/src/connectors/google_drive/schema.py b/src/connectors/google_drive/schema.py index 9d516e9..aaf24d3 100644 --- a/src/connectors/google_drive/schema.py +++ b/src/connectors/google_drive/schema.py @@ -2,7 +2,7 @@ from typing import Annotated, Any, Dict, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator class BaseDriveOperation(BaseModel): @@ -43,6 +43,13 @@ class PermissionsCreateOperation(BaseDriveOperation): type: Literal["user", "group", "domain", "anyone"] domain: Optional[str] = Field(None, description="G Suite domain when type is domain.") 
+ @field_validator("email_address", "domain", mode="before") + @classmethod + def _empty_str_to_none(cls, v: Any) -> Any: + if isinstance(v, str) and not v.strip(): + return None + return v + @model_validator(mode="after") def require_fields_for_perm_type(self) -> "PermissionsCreateOperation": if self.type in ("user", "group"): diff --git a/src/runtime/__init__.py b/src/runtime/__init__.py index 1e5c11f..b8ca184 100644 --- a/src/runtime/__init__.py +++ b/src/runtime/__init__.py @@ -4,6 +4,12 @@ from .secrets import SecretProvider from .policy import PolicyHook, PolicyDenied from .sdk_connector import SDKConnector, sdk_action, _CONNECTOR_REGISTRY +from .sdk_action_spec import ( + SdkActionSpec, + default_build_kwargs, + execute_spec_in_thread, + navigate_resource, +) __all__ = [ "ConnectorResponse", @@ -16,4 +22,8 @@ "SDKConnector", "sdk_action", "_CONNECTOR_REGISTRY", + "SdkActionSpec", + "default_build_kwargs", + "execute_spec_in_thread", + "navigate_resource", ] diff --git a/src/runtime/sdk_action_spec.py b/src/runtime/sdk_action_spec.py new file mode 100644 index 0000000..d940ff2 --- /dev/null +++ b/src/runtime/sdk_action_spec.py @@ -0,0 +1,123 @@ +""" +Generic action-spec primitives for SDK-backed connectors (e.g. googleapiclient). + +Subclasses describe how validated Pydantic models map to vendor SDK calls: +resource navigation, method name, keyword/body mapping, constants, and optional +custom builders or post-processors. 
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +from pydantic import BaseModel + + +def navigate_resource(client: Any, segments: Tuple[str, ...]) -> Any: + """Traverse discovery-style APIs: client.files().permissions()...""" + api = client + for seg in segments: + api = getattr(api, seg)() + return api + + +def default_build_kwargs( + *, + kwargs_from_model: Dict[str, str], + body_from_model: Optional[Dict[str, str]], + body_constant: Optional[Dict[str, Any]], + constant_kwargs: Dict[str, Any], + computed_kwargs: Dict[str, Callable[[BaseModel], Any]], + include_empty_body: bool, + model: BaseModel, +) -> Dict[str, Any]: + """Build SDK method kwargs from a validated input model.""" + kw: Dict[str, Any] = dict(constant_kwargs) + + for attr, sdk_name in kwargs_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + kw[sdk_name] = val + + for sdk_name, fn in computed_kwargs.items(): + val = fn(model) + if val is not None: + kw[sdk_name] = val + + body: Dict[str, Any] = {} + if body_constant: + body.update(body_constant) + if body_from_model: + for attr, bkey in body_from_model.items(): + val = getattr(model, attr, None) + if val is not None: + body[bkey] = val + + if body_from_model is not None or body_constant is not None: + if body or include_empty_body: + kw["body"] = body + + return kw + + +@dataclass(frozen=True) +class SdkActionSpec: + """ + Describes one vendor SDK call: resource().method(**kwargs).execute() + + When ``build_kwargs`` is None, kwargs are built from the mapping fields. + When ``build_kwargs`` is set, it receives (client, model) and must return + the full kwargs dict for the SDK method. + """ + + resource_segments: Tuple[str, ...] 
+ method_name: str + kwargs_from_model: Dict[str, str] = field(default_factory=dict) + body_from_model: Optional[Dict[str, str]] = None + body_constant: Optional[Dict[str, Any]] = None + constant_kwargs: Dict[str, Any] = field(default_factory=dict) + computed_kwargs: Dict[str, Callable[[BaseModel], Any]] = field(default_factory=dict) + # Pass body={} when the API requires a body key even if empty (e.g. files.update). + include_empty_body: bool = False + build_kwargs: Optional[Callable[[Any, BaseModel], Dict[str, Any]]] = None + post_process: Optional[Callable[[Any, BaseModel], Any]] = None + # Set these when the spec is declared in a connector's action_specs class var. + # input_model is required; output_model falls back to cls.output_model if None. + input_model: Optional[Any] = None + output_model: Optional[Any] = None + + +def build_method_kwargs(spec: SdkActionSpec, client: Any, model: BaseModel) -> Dict[str, Any]: + if spec.build_kwargs is not None: + return spec.build_kwargs(client, model) + return default_build_kwargs( + kwargs_from_model=spec.kwargs_from_model, + body_from_model=spec.body_from_model, + body_constant=spec.body_constant, + constant_kwargs=spec.constant_kwargs, + computed_kwargs=spec.computed_kwargs, + include_empty_body=spec.include_empty_body, + model=model, + ) + + +def execute_spec_sync(client: Any, spec: SdkActionSpec, model: BaseModel) -> Any: + """Run spec.method_name on navigated resource; return execute() result (sync).""" + kwargs = build_method_kwargs(spec, client, model) + resource_api = navigate_resource(client, spec.resource_segments) + method = getattr(resource_api, spec.method_name) + result = method(**kwargs).execute() + if spec.post_process is not None: + return spec.post_process(result, model) + return result + + +async def execute_spec_in_thread( + client: Any, + spec: SdkActionSpec, + model: BaseModel, +) -> Any: + """Run execute_spec_sync in a worker thread (for sync googleapiclient).""" + return await 
asyncio.to_thread(execute_spec_sync, client, spec, model) diff --git a/src/runtime/sdk_connector.py b/src/runtime/sdk_connector.py index faa9688..30e4b50 100644 --- a/src/runtime/sdk_connector.py +++ b/src/runtime/sdk_connector.py @@ -22,6 +22,7 @@ from .errors import ErrorMapper from .models import ErrorCategory from .secrets import SecretProvider +from .sdk_action_spec import SdkActionSpec logger = logging.getLogger("runtime.sdk_connector") @@ -29,6 +30,80 @@ _CONNECTOR_REGISTRY: Dict[str, Type["SDKConnector"]] = {} +def _make_spec_handler( + action_name: str, + input_model: Any, + output_model: Any, + cls_qualname: str, + cls_module: str, +) -> Any: + """ + Build a single async handler function for one action_specs entry. + Using a factory function (rather than a loop + default-arg trick) ensures + action_name is captured by value in the closure and does not appear in the + method signature seen by inspect.signature / get_type_hints. + """ + fn_name = action_name.replace(".", "_").replace("-", "_") + + async def _handler(self, params, *, trace_id: str): + return await self._execute_action_spec(action_name, params, trace_id=trace_id) + + _handler.__name__ = fn_name + _handler.__qualname__ = f"{cls_qualname}.{fn_name}" + _handler.__module__ = cls_module + # Set actual type objects (not strings) so get_type_hints() resolves correctly + # even when `from __future__ import annotations` is active in the connector module. + _handler.__annotations__ = {"params": input_model, "return": output_model} + _handler._sdk_action_name = action_name + return _handler + + +def _generate_methods_from_action_specs(cls: type) -> None: + """ + For each entry in cls.action_specs, generate an async @sdk_action method and + attach it to cls. Called at the top of SDKConnector.__init_subclass__ so the + existing discovery loop picks up the generated methods. + + Opt-in: only triggers when the class defines action_specs in its own __dict__. 
+ """ + specs = cls.__dict__.get("action_specs") + if specs is None: + return + + fallback_output = getattr(cls, "output_model", None) + + for action_name, spec in specs.items(): + if not isinstance(spec, SdkActionSpec): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] must be a SdkActionSpec instance" + ) + input_model = spec.input_model + if not (isinstance(input_model, type) and issubclass(input_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] requires " + "input_model=" + ) + + output_model = spec.output_model if spec.output_model is not None else fallback_output + if not (isinstance(output_model, type) and issubclass(output_model, BaseModel)): + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] has no resolvable " + "output_model — set it on the SdkActionSpec or define cls.output_model" + ) + + fn_name = action_name.replace(".", "_").replace("-", "_") + if fn_name in cls.__dict__: + raise TypeError( + f"{cls.__name__}: action_specs[{action_name!r}] conflicts with " + f"existing method {fn_name!r}" + ) + + handler = _make_spec_handler( + action_name, input_model, output_model, cls.__qualname__, cls.__module__ + ) + setattr(cls, fn_name, handler) + + def sdk_action(name: str): """ Mark a connector method as a named, auto-discoverable SDK action. @@ -79,6 +154,10 @@ class SDKConnector(BaseConnector): def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) + # Phase 0: auto-generate @sdk_action methods from action_specs (opt-in). + # Must run before the dir(cls) discovery loop below. 
+ _generate_methods_from_action_specs(cls) + registry: Dict[str, SdkActionMeta] = {} for attr_name in dir(cls): method = getattr(cls, attr_name, None) diff --git a/tests/test_google_drive_action_spec.py b/tests/test_google_drive_action_spec.py new file mode 100644 index 0000000..6220e60 --- /dev/null +++ b/tests/test_google_drive_action_spec.py @@ -0,0 +1,132 @@ +"""Tests for Google Drive action specs and SDK call mapping.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +from connectors.google_drive.action_spec import GOOGLE_DRIVE_ACTION_SPECS +from connectors.google_drive.logic import GoogleDriveConnector +from connectors.google_drive.schema import GoogleDriveOperationInput +from runtime import SecretProvider + + +class MockSecretProvider(SecretProvider): + def get_secret(self, key: str) -> str: + return { + "GOOGLE_DRIVE_SA_JSON": '{"type":"service_account","project_id":"dummy"}', + }[key] + + +def _connector() -> GoogleDriveConnector: + return GoogleDriveConnector(secret_provider=MockSecretProvider()) + + +def test_action_spec_registry_covers_all_sdk_actions(): + """Every @sdk_action on GoogleDriveConnector must have a spec entry.""" + metas = GoogleDriveConnector.sdk_action_metas() + for action_name in metas: + assert action_name in GOOGLE_DRIVE_ACTION_SPECS, f"missing spec for {action_name}" + + +def test_files_create_maps_body_and_constants(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "files.create", + "name": "doc.txt", + "mime_type": "text/plain", + "parents": ["p1"], + } + ) + + drive = MagicMock() + files_api = drive.files.return_value + create_call = files_api.create.return_value + create_call.execute.return_value = {"id": "new-id", "name": "doc.txt"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "new-id", "name": 
"doc.txt"} + files_api.create.assert_called_once_with( + body={"name": "doc.txt", "mimeType": "text/plain", "parents": ["p1"]}, + fields="id, name, webViewLink", + supportsAllDrives=True, + ) + + +def test_files_delete_returns_synthetic_raw(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + {"action": "files.delete", "file_id": "fid-99"} + ) + + drive = MagicMock() + files_api = drive.files.return_value + upd = files_api.update.return_value + upd.execute.return_value = {"id": "fid-99", "trashed": True} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"file_id": "fid-99", "status": "deleted"} + files_api.update.assert_called_once_with( + fileId="fid-99", + body={"trashed": True}, + supportsAllDrives=True, + ) + + +def test_permissions_create_maps_body(): + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "f1", + "role": "reader", + "type": "user", + "email_address": "a@b.com", + } + ) + + drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"id": "perm-1"} + + with patch.object(connector, "get_client", return_value=drive): + result = asyncio.run(connector.internal_execute(params, trace_id="t")) + + assert result.raw == {"id": "perm-1"} + perms.create.assert_called_once_with( + fileId="f1", + body={"role": "reader", "type": "user", "emailAddress": "a@b.com"}, + supportsAllDrives=True, + ) + + +def test_permissions_create_excludes_empty_optional_fields(): + """Empty-string email_address and domain must be excluded from the body (not sent as "").""" + connector = _connector() + params = GoogleDriveOperationInput.model_validate( + { + "action": "permissions.create", + "file_id": "file-abc", + "role": "reader", + "type": "anyone", + "email_address": "", + "domain": "", + } + ) + + 
drive = MagicMock() + perms = drive.permissions.return_value + perms.create.return_value.execute.return_value = {"kind": "drive#permission"} + + with patch.object(connector, "get_client", return_value=drive): + asyncio.run(connector.internal_execute(params, trace_id="t-empty")) + + _, kwargs = perms.create.call_args + body = kwargs["body"] + assert "emailAddress" not in body + assert "domain" not in body From 27013800755109063d99f7904e2b4a0b956062bc Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:44:57 -0700 Subject: [PATCH 05/15] Update Playground to work with the new architecture --- playground/scenarios.py | 132 +++++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 63 deletions(-) diff --git a/playground/scenarios.py b/playground/scenarios.py index 89ded8a..584b9a4 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -185,51 +185,53 @@ async def execute_with_retry(action: Any, input_data: Any, trace_id: str, step: raise last_exception +# Single shared factory for playground scenarios (matches REST: enabled + exposed_via includes "rest"). 
+_playground_factory: Optional[Any] = None + + +def get_playground_factory() -> Any: + """Lazily load connector config once; same pattern as bindings REST `get_factory`.""" + global _playground_factory + if _playground_factory is None: + from bindings.factory import ConnectorFactory + from connectors import auto_register + + _playground_factory = ConnectorFactory() + auto_register() + _playground_factory.load() + return _playground_factory + + +def resolve_connector(connector_id: str, action: Optional[str] = None) -> Any: + """Resolve a connector via public factory API (protocol-aware).""" + factory = get_playground_factory() + return factory.get_for_protocol(connector_id, "rest", action=action) + + def get_fhir_connector() -> FhirEpicConnector: - # Use global accessor instead of circular import - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("fhir_epic") + connector = resolve_connector("fhir_epic") if not connector: raise HTTPException(status_code=500, detail="FHIR Epic connector not configured") - return connector + return connector # type: ignore[return-value] + def get_http_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - - connector = factory._connectors.get("http_generic") + # Manifest action for http_generic is "request"; pass it for parity with REST routing. 
+ connector = resolve_connector("http_generic", action="request") if not connector: raise HTTPException(status_code=500, detail="Generic HTTP connector not configured") return connector -def get_cerner_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("fhir_cerner") +def get_cerner_connector(): + connector = resolve_connector("fhir_cerner") if not connector: raise HTTPException(status_code=500, detail="FHIR Cerner connector not configured") return connector -def get_google_drive_connector(): - from bindings.factory import ConnectorFactory - factory = ConnectorFactory() - from connectors import auto_register - auto_register() - factory.load() - connector = factory._connectors.get("google_drive") +def get_google_drive_connector(): + connector = resolve_connector("google_drive") if not connector: raise HTTPException(status_code=500, detail="Google Drive connector not configured") return connector @@ -250,12 +252,10 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Performing direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, + connector, FhirPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] @@ -269,7 +269,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", } logger.info(f"Searching for patient: {patient_search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirPatientReadInput(search_params=patient_search_params), trace_id, steps[-1] @@ -297,20 +297,19 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", enc_status = "verified" 
else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") logger.info(f"Searching for encounter... patient={patient_id}, date={visit_date}", extra={"trace_id": trace_id}) enc_res = await execute_with_retry( - encounter_action, + connector, FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished", "date": visit_date}), trace_id, steps[-1] ) - + resources = enc_res.resources if not resources: # Fallback to any finished encounter enc_res = await execute_with_retry( - encounter_action, + connector, FhirEncounterSearchInput(search_params={"patient": patient_id, "status": "finished"}), trace_id, steps[-1] @@ -355,20 +354,18 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", context={"encounter": [{"reference": f"Encounter/{encounter_id}"}]} ) - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) - + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) + steps[-1].status = "success" steps[-1].details = f"EHR Updated. 
ID: {doc_res.resource_id}" steps[-1].display_name = "Note Synced Successfully" steps[-1].data = {"resource_id": doc_res.resource_id, "raw": doc_res.resource if (hasattr(doc_res, 'resource') and doc_res.resource) else {"id": doc_res.resource_id, "status": "created", "note": "Resource payload not returned by Epic integration."}} - + # STEP 4: Verification / Visualization add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, + connector, FhirDocumentReferenceSearchInput(search_params={"patient": patient_id, "_id": doc_res.resource_id}), trace_id, steps[-1] @@ -567,12 +564,10 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 1: Patient Discovery add_step("Patient Discovery", "pending", display_name="Identify Patient") try: - patient_action = connector.get_action("read_patient") - if payload.patient_id: logger.info(f"Cerner: direct Patient ID lookup: {payload.patient_id}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(resource_id=payload.patient_id), trace_id, steps[-1] @@ -586,7 +581,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", }.items() if v} logger.info(f"Cerner: searching for patient: {search_params}") p_res = await execute_with_retry( - patient_action, + connector, FhirCernerPatientReadInput(search_params=search_params), trace_id, steps[-1] @@ -617,9 +612,8 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", selected_enc = {"id": encounter_id, "note": "Manual ID used"} else: visit_date = payload.visit_date or datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") - encounter_action = connector.get_action("search_encounter") enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": 
patient_id, "status": "finished", "date": visit_date} ), @@ -631,7 +625,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", if not resources: # Fallback: any finished encounter for this patient enc_res = await execute_with_retry( - encounter_action, + connector, FhirCernerEncounterSearchInput( search_params={"patient": patient_id, "status": "finished"} ), @@ -668,7 +662,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # Cerner requires CodeSet 72 proprietary system — NOT a raw LOINC system URL. # The tenant ID is embedded in the connector's FHIR base URL path segment. try: - base_url_secret = connector._secret_provider.get_secret("cerner_fhir_base_url") + base_url_secret = connector.secret_provider.get_secret("cerner_fhir_base_url") # Extract tenant from URL: .../r4/{tenant_id} or similar parts = [p for p in base_url_secret.rstrip("/").split("/") if p] tenant_id = parts[-1] if parts else "your-tenant-id" @@ -708,8 +702,7 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", }, ) - doc_action = connector.get_action("create_document_reference") - doc_res = await execute_with_retry(doc_action, doc_input, trace_id, steps[-1]) + doc_res = await execute_with_retry(connector, doc_input, trace_id, steps[-1]) steps[-1].status = "success" steps[-1].details = f"Cerner EHR Updated. 
ID: {doc_res.resource_id}" @@ -723,9 +716,8 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", # STEP 4: Verification add_step("Document Verification", "pending", display_name="Verify EHR Update") try: - doc_search_action = connector.get_action("search_document_reference") verify_res = await execute_with_retry( - doc_search_action, + connector, FhirCernerDocumentReferenceSearchInput( search_params={"_id": doc_res.resource_id} ), @@ -826,7 +818,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields=fields, ) list_input = GoogleDriveOperationInput.model_validate(list_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, list_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, list_input, trace_id, steps[-1] + ) n = len(res.raw.get("files") or []) steps[-1].status = "success" steps[-1].details = f"Retrieved {n} file(s) (page_size={page_size})" @@ -855,7 +849,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields=gf, ) get_input = GoogleDriveOperationInput.model_validate(get_op.model_dump(exclude_none=True)) - res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) got_id = res.raw.get("id") or fid name = res.raw.get("name", "") steps[-1].status = "success" @@ -916,7 +912,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", add_step("Drive Update", "pending", display_name="Apply file update") try: - res = await execute_with_retry(connector, upd_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, upd_input, trace_id, steps[-1] + ) except Exception as e: return _safe_error_return(e, steps, trace_id, "files.update failed") @@ -937,7 +935,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", get_input = 
GoogleDriveOperationInput.model_validate( get_op.model_dump(exclude_none=True) ) - get_res = await execute_with_retry(connector, get_input, trace_id, steps[-1]) + get_res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) except Exception as e: return _safe_error_return(e, steps, trace_id, "files.update verify failed") @@ -1000,7 +1000,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", upload_input = GoogleDriveOperationInput.model_validate(op_payload) - res = await execute_with_retry(connector, upload_input, trace_id, steps[-1]) + res = await execute_with_retry( + connector, upload_input, trace_id, steps[-1] + ) file_id = res.raw.get("id") if not file_id: @@ -1025,7 +1027,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", type="user" ) ) - perm_res = await execute_with_retry(connector, perm_input, trace_id, steps[-1]) + perm_res = await execute_with_retry( + connector, perm_input, trace_id, steps[-1] + ) steps[-1].status = "success" steps[-1].details = f"Read access granted to {payload.recipient_email}" @@ -1044,7 +1048,9 @@ def add_step(name: str, status: str, details: str = "", display_name: str = "", fields="id, name, mimeType, webViewLink, size, createdTime, owners" ) ) - get_res = await connector.internal_execute(get_input, trace_id=trace_id) + get_res = await execute_with_retry( + connector, get_input, trace_id, steps[-1] + ) file_metadata = get_res.raw beautiful_data = { From 00e3fbd2838525ded924e97a5c317cbbadf6ac09 Mon Sep 17 00:00:00 2001 From: kesav-aot Date: Wed, 1 Apr 2026 01:12:05 -0700 Subject: [PATCH 06/15] Register full MCP tools for FHIR, Drive, SMTP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expand MCP server entrypoints to dynamically register and expose all connector actions for fhir_epic and fhir_cerner, and to expose full Google Drive operations and improved SMTP behavior. 
Added multiple new MCP tools: patient search, encounter search, create/search DocumentReference for both Epic and Cerner; several Google Drive tools (files.create/list/get/update/upload/delete, permissions.create); and SMTP now accepts multiple recipients and improved logging. Refactored per-server helpers (e.g. _get_connector), standardized input handling/parsing and return shapes, and updated docs/mcp-servers.md to list the exposed tools. Also adds a service_account.json for Google Drive usage (contains service account credentials — treat as sensitive and consider moving to secrets management). --- docs/mcp-servers.md | 8 +- src/agents/fhir_cerner_mcp.py | 283 ++++++++++++++++++++++++++++++--- src/agents/fhir_epic_mcp.py | 271 +++++++++++++++++++++++++++---- src/agents/google_drive_mcp.py | 229 ++++++++++++++++++++++++-- src/agents/smtp_mcp.py | 14 +- 5 files changed, 727 insertions(+), 78 deletions(-) diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 1dfe8de..cf761b1 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -41,11 +41,11 @@ flowchart TD ## Naming conventions -| Connector | Python entrypoint | Docker image | ToolHive name | MCP tool(s) exposed | +| Connector | Python entrypoint | Docker image | ToolHive name | MCP tools exposed | |---|---|---|---|---| -| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_upload_file` | -| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient` | -| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | +| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_files_create`, `google_drive_files_list`, `google_drive_permissions_create`, `google_drive_files_get`, `google_drive_files_update`, `google_drive_files_upload`, 
`google_drive_files_delete` | +| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient`, `fhir_epic_search_patients`, `fhir_epic_search_encounter`, `fhir_epic_create_document_reference`, `fhir_epic_search_document_reference` | +| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient`, `fhir_cerner_search_patients`, `fhir_cerner_search_encounter`, `fhir_cerner_create_document_reference`, `fhir_cerner_search_document_reference` | | SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp_send_email` | --- diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index 5628bd6..e03192c 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -1,7 +1,17 @@ """ FastMCP Server Entrypoint — SMART on FHIR (Cerner) -================================================= -Standalone MCP server exposing only the Cerner FHIR patient read tool. +=================================================== +Standalone MCP server that exposes the following fhir_cerner connector +actions as MCP tools: + + • fhir_cerner_read_patient — fetch a single Patient by ID or name search + • fhir_cerner_search_patients — fetch multiple Patients (fan-out or name search) + • fhir_cerner_search_encounter — search Encounters by patient / status / date + • fhir_cerner_create_document_reference — create a FHIR DocumentReference + • fhir_cerner_search_document_reference — search FHIR DocumentReferences + +Tools are registered explicitly in _make_server(); exposing a new connector +action requires adding a matching @mcp.tool definition to this file.
Usage: python -m agents.fhir_cerner_mcp @@ -29,7 +39,13 @@ def _make_server(): from bindings.factory import ConnectorFactory from connectors import auto_register - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput + from connectors.fhir_cerner.schema import ( + FhirCernerDocumentReferenceCreateInput, + FhirCernerDocumentReferenceSearchInput, + FhirCernerEncounterSearchInput, + FhirCernerPatientReadInput, + FhirCernerPatientSearchInput, + ) auto_register() factory = ConnectorFactory() @@ -37,51 +53,55 @@ def _make_server(): mcp = FastMCP("nw-smartonfhir-cerner") + def _get_connector(): + cerner = factory._connectors.get("fhir_cerner") + if not cerner: + raise RuntimeError("fhir_cerner connector not configured") + return cerner + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_read_patient + # ------------------------------------------------------------------ @mcp.tool( name="fhir_cerner_read_patient", description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." + "Fetch a single patient's demographic record from Cerner FHIR R4. " + "Provide patient_id for a direct lookup, or family_name/given_name/name " + "for a name-based search. " + "Note: Cerner sandbox name search is case-sensitive." 
), ) async def fhir_cerner_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - action = cerner.get_action("read_patient") + action = _get_connector().get_action("read_patient") if patient_id: params = FhirCernerPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirCernerPatientReadInput(search_params=search) + elif family_name or given_name or name: + params = FhirCernerPatientReadInput( + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least one of family_name / given_name / name") result = await action.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - addr = resource.get("address", [{}])[0] full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" + f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, " + f"{addr.get('state', '')} {addr.get('postalCode', '')}" ).strip(", ") return { @@ -90,8 +110,222 @@ async def fhir_cerner_read_patient( "gender": resource.get("gender"), "birth_date": resource.get("birthDate"), "address_summary": full_addr, + "source": "Cerner FHIR", } + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_patients + # 
------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_patients", + description=( + "Search / fetch multiple patients from Cerner FHIR R4. " + "Mode 1 — pass a comma-separated list of patient IDs in resource_ids for a concurrent " + "fan-out lookup. " + "Mode 2 — pass family_name, given_name, name, and/or birthdate for a name-based " + "FHIR search that returns all matching Bundle entries. " + "Cerner sandbox name search is case-sensitive. " + "Partial failures in Mode 1 are captured in the 'errors' list rather than raising." + ), + ) + async def fhir_cerner_search_patients( + resource_ids: str = "", + family_name: str = "", + given_name: str = "", + name: str = "", + birthdate: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_patients") + + ids_list = [i.strip() for i in resource_ids.split(",") if i.strip()] if resource_ids else None + + params = FhirCernerPatientSearchInput( + resource_ids=ids_list, + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "errors": result.errors, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_encounter + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_encounter", + description=( + "Search FHIR Encounter resources in Cerner R4. " + "Filter by patient_id (maps to the FHIR 'patient' parameter), encounter " + "status (e.g. 'finished', 'arrived'), and/or date / date range " + "(e.g. '2024', 'gt2023-01-01'). " + "At least one filter must be provided." 
+ ), + ) + async def fhir_cerner_search_encounter( + patient_id: str = "", + status: str = "", + date: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_encounter") + + if not patient_id and not status and not date: + raise ValueError("Provide at least one of patient_id, status, or date") + + params = FhirCernerEncounterSearchInput( + patient_id=patient_id or None, + status=status or None, + date=date or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_create_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_create_document_reference", + description=( + "Create a FHIR DocumentReference resource in Cerner R4. " + "Required: status ('current'), subject (Patient reference, e.g. 'Patient/12345678'). " + "Provide text (raw string) or data (base64-encoded bytes). " + "The connector auto-encodes text to base64 and applies required Cerner formatting " + "(charset, docStatus, CodeSet 72 type system). " + "For the type_system, use Cerner CodeSet 72 " + "('https://fhir.cerner.com/{tenant_id}/codeSet/72') with a valid code. " + "context_encounter_id is required for clinical note document types. " + "Returns the new DocumentReference resource ID." 
+ ), + ) + async def fhir_cerner_create_document_reference( + status: str, + subject: str, + type_system: str, + type_code: str, + type_display: str, + text: str = "", + data: str = "", + doc_status: str = "final", + content_type: str = "text/plain", + attachment_title: str = "Document", + description: str = "", + context_encounter_id: str = "", + author_reference: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("create_document_reference") + + if not text and not data: + raise ValueError("Provide either 'text' (raw string) or 'data' (base64-encoded content)") + + doc_type = { + "coding": [{ + "system": type_system, + "code": type_code, + "display": type_display, + "userSelected": True, + }], + "text": type_display, + } + + context = None + if context_encounter_id: + from datetime import datetime, timezone + now = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + context = { + "encounter": [{"reference": f"Encounter/{context_encounter_id}"}], + "period": {"start": now, "end": now}, + } + + author = None + if author_reference: + author = [{"reference": author_reference}] + + params = FhirCernerDocumentReferenceCreateInput( + status=status, + doc_status=doc_status, + type=doc_type, + subject=subject, + text=text or None, + data=data or None, + content_type=content_type, + attachment_title=attachment_title, + description=description or None, + context=context, + author=author, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resource_id": result.resource_id, + "resource": result.resource, + "source": "Cerner FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_cerner_search_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_cerner_search_document_reference", + description=( + "Search FHIR DocumentReference resources in Cerner R4. 
" + "Pass search parameters as key=value pairs separated by '&', " + "e.g. 'patient=12345678' or 'patient=12345678&status=current'. " + "The 'patient' parameter is required by most Cerner configurations." + ), + ) + async def fhir_cerner_search_document_reference( + search_query: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_document_reference") + + # Parse 'key=value&key2=value2' into a dict + search_params: dict = {} + for part in search_query.split("&"): + part = part.strip() + if "=" in part: + k, _, v = part.partition("=") + search_params[k.strip()] = v.strip() + + if not search_params: + raise ValueError( + "Provide search_query as 'key=value' pairs (e.g. 'patient=12345678')" + ) + + params = FhirCernerDocumentReferenceSearchInput(search_params=search_params) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Cerner FHIR", + } + + logger.info( + "Registered %d Cerner FHIR MCP tools: %s", + 5, + [ + "fhir_cerner_read_patient", + "fhir_cerner_search_patients", + "fhir_cerner_search_encounter", + "fhir_cerner_create_document_reference", + "fhir_cerner_search_document_reference", + ], + ) return mcp @@ -103,4 +337,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index d7f6335..72ef12a 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -1,7 +1,17 @@ """ FastMCP Server Entrypoint — SMART on FHIR (Epic) -=============================================== -Standalone MCP server exposing only the Epic FHIR patient read tool. 
+================================================= +Standalone MCP server that exposes the following fhir_epic connector +actions as MCP tools: + + • fhir_epic_read_patient — fetch a single Patient by ID or name search + • fhir_epic_search_patients — fetch multiple Patients (fan-out or name search) + • fhir_epic_search_encounter — search Encounters by patient / status / date + • fhir_epic_create_document_reference — create a FHIR DocumentReference + • fhir_epic_search_document_reference — search FHIR DocumentReferences + +Tools are registered explicitly in _make_server(); exposing a new connector +action requires adding a matching @mcp.tool definition to this file. Usage: python -m agents.fhir_epic_mcp @@ -21,6 +31,12 @@ logger = logging.getLogger("agents.fhir_epic_mcp") +# --------------------------------------------------------------------------- +# Per-action tool definitions +# Each tool is declared inside _make_server() with @mcp.tool(name=..., description=...). +# Each handler receives its arguments from FastMCP and returns a dict.
+# --------------------------------------------------------------------------- + def _make_server(): try: from mcp.server.fastmcp import FastMCP @@ -29,7 +45,13 @@ def _make_server(): from bindings.factory import ConnectorFactory from connectors import auto_register - from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput + from connectors.fhir_epic.schema import ( + FhirDocumentReferenceCreateInput, + FhirDocumentReferenceSearchInput, + FhirEncounterSearchInput, + FhirPatientReadInput, + FhirPatientSearchInput, + ) auto_register() factory = ConnectorFactory() @@ -37,52 +59,55 @@ def _make_server(): mcp = FastMCP("nw-smartonfhir-epic") + def _get_connector(): + epic = factory._connectors.get("fhir_epic") + if not epic: + raise RuntimeError("fhir_epic connector not configured") + return epic + + # ------------------------------------------------------------------ + # Tool: fhir_epic_read_patient + # ------------------------------------------------------------------ @mcp.tool( name="fhir_epic_read_patient", description=( - "Fetch a patient's demographic record from Epic FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." + "Fetch a single patient's demographic record from Epic FHIR R4. " + "Provide patient_id for a direct lookup, or family_name/given_name/name " + "for a name-based search. " + "Epic patient IDs typically start with 'e' (e.g. 'eXYZ123')." 
), ) async def fhir_epic_read_patient( patient_id: str = "", family_name: str = "", given_name: str = "", + name: str = "", birthdate: str = "", ) -> dict: trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - action = epic.get_action("read_patient") + action = _get_connector().get_action("read_patient") if patient_id: - params = FhirEpicPatientReadInput(resource_id=patient_id) - elif family_name or given_name: - search = { - k: v - for k, v in { - "family": family_name, - "given": given_name, - "birthdate": birthdate, - }.items() - if v - } - params = FhirEpicPatientReadInput(search_params=search) + params = FhirPatientReadInput(resource_id=patient_id) + elif family_name or given_name or name: + params = FhirPatientReadInput( + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) else: - raise ValueError("Provide patient_id OR at least family_name/given_name") + raise ValueError("Provide patient_id OR at least one of family_name / given_name / name") result = await action.internal_execute(params, trace_id=trace_id) resource = result.resource name_parts = resource.get("name", [{}])[0] full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - addr = resource.get("address", [{}])[0] full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" + f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, " + f"{addr.get('state', '')} {addr.get('postalCode', '')}" ).strip(", ") return { @@ -94,6 +119,199 @@ async def fhir_epic_read_patient( "source": "Epic FHIR", } + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_patients + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_patients", + description=( + 
"Search / fetch multiple patients from Epic FHIR R4. " + "Mode 1 — pass a comma-separated list of patient IDs in resource_ids for a concurrent " + "fan-out lookup. " + "Mode 2 — pass family_name, given_name, name, and/or birthdate for a name-based " + "FHIR search that returns all matching Bundle entries. " + "Partial failures in Mode 1 are captured in the 'errors' list rather than raising." + ), + ) + async def fhir_epic_search_patients( + resource_ids: str = "", + family_name: str = "", + given_name: str = "", + name: str = "", + birthdate: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_patients") + + ids_list = [i.strip() for i in resource_ids.split(",") if i.strip()] if resource_ids else None + + params = FhirPatientSearchInput( + resource_ids=ids_list, + family_name=family_name or None, + given_name=given_name or None, + name=name or None, + birthdate=birthdate or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "errors": result.errors, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_encounter + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_encounter", + description=( + "Search FHIR Encounter resources in Epic R4. " + "Filter by patient_id (maps to the FHIR 'patient' parameter), encounter " + "status (e.g. 'finished', 'arrived'), and/or date / date range " + "(e.g. '2024', 'gt2023-01-01'). " + "At least one filter must be provided." 
+ ), + ) + async def fhir_epic_search_encounter( + patient_id: str = "", + status: str = "", + date: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_encounter") + + if not patient_id and not status and not date: + raise ValueError("Provide at least one of patient_id, status, or date") + + params = FhirEncounterSearchInput( + patient_id=patient_id or None, + status=status or None, + date=date or None, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_create_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_create_document_reference", + description=( + "Create a FHIR DocumentReference resource in Epic R4. " + "Required: status ('current'), type (CodeableConcept with LOINC code), " + "subject (Patient reference string, e.g. 'Patient/eXYZ'), " + "data (base64-encoded content). " + "Optional: identifier, category, author, description, context " + "(Epic requires context.encounter for clinical note types such as LOINC 34108-1). " + "Returns the new DocumentReference resource ID." 
+ ), + ) + async def fhir_epic_create_document_reference( + status: str, + subject: str, + data: str, + type_code: str = "34133-9", + type_system: str = "http://loinc.org", + type_display: str = "Summary of episode note", + content_type: str = "text/plain", + description: str = "", + encounter_id: str = "", + author_reference: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("create_document_reference") + + doc_type = { + "coding": [{"system": type_system, "code": type_code, "display": type_display}] + } + + identifier = [{"system": "urn:ietf:rfc:3986", "value": f"urn:uuid:{uuid.uuid4()}"}] + + context = None + if encounter_id: + context = {"encounter": [{"reference": f"Encounter/{encounter_id}"}]} + + author = None + if author_reference: + author = [{"reference": author_reference}] + + params = FhirDocumentReferenceCreateInput( + identifier=identifier, + status=status, + type=doc_type, + subject=subject, + data=data, + content_type=content_type, + description=description or None, + context=context, + author=author, + ) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resource_id": result.resource_id, + "resource": result.resource, + "source": "Epic FHIR", + } + + # ------------------------------------------------------------------ + # Tool: fhir_epic_search_document_reference + # ------------------------------------------------------------------ + @mcp.tool( + name="fhir_epic_search_document_reference", + description=( + "Search FHIR DocumentReference resources in Epic R4. " + "Pass search parameters as key=value pairs separated by '&', " + "e.g. 'patient=eXYZ123' or 'patient=eXYZ123&type=34133-9'. " + "The 'patient' parameter is required by most Epic configurations." 
+ ), + ) + async def fhir_epic_search_document_reference( + search_query: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + action = _get_connector().get_action("search_document_reference") + + # Parse 'key=value&key2=value2' into a dict + search_params: dict = {} + for part in search_query.split("&"): + part = part.strip() + if "=" in part: + k, _, v = part.partition("=") + search_params[k.strip()] = v.strip() + + if not search_params: + raise ValueError( + "Provide search_query as 'key=value' pairs (e.g. 'patient=eXYZ123')" + ) + + params = FhirDocumentReferenceSearchInput(search_params=search_params) + + result = await action.internal_execute(params, trace_id=trace_id) + return { + "resources": result.resources, + "total": result.total, + "source": "Epic FHIR", + } + + logger.info( + "Registered %d Epic FHIR MCP tools: %s", + 5, + [ + "fhir_epic_read_patient", + "fhir_epic_search_patients", + "fhir_epic_search_encounter", + "fhir_epic_create_document_reference", + "fhir_epic_search_document_reference", + ], + ) return mcp @@ -105,4 +323,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/google_drive_mcp.py b/src/agents/google_drive_mcp.py index 050a3ef..1f865cc 100644 --- a/src/agents/google_drive_mcp.py +++ b/src/agents/google_drive_mcp.py @@ -1,7 +1,15 @@ """ FastMCP Server Entrypoint — Google Drive -======================================= -Standalone MCP server exposing only the Google Drive tool. 
+======================================== +Standalone MCP server exposing all Google Drive connector actions: + + • google_drive_files_create + • google_drive_files_list + • google_drive_permissions_create + • google_drive_files_get + • google_drive_files_update + • google_drive_files_upload + • google_drive_files_delete Usage: python -m agents.google_drive_mcp @@ -11,6 +19,7 @@ import logging import os import uuid +from typing import Optional from dotenv import load_dotenv @@ -37,32 +46,45 @@ def _make_server(): mcp = FastMCP("nw-google-drive") + def _get_connector(): + drive = factory._connectors.get("google_drive") + if not drive: + raise RuntimeError("google_drive connector not configured") + return drive + + # ------------------------------------------------------------------ + # Tool: google_drive_files_upload + # ------------------------------------------------------------------ @mcp.tool( - name="google_drive_upload_file", + name="google_drive_files_upload", description=( - "Upload a text file to Google Drive. " + "Upload a new file with content to Google Drive. " "Returns the file ID and a shareable web view link." 
), ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), + async def google_drive_files_upload( + name: str, mime_type: str = "text/plain", + content: str = "", + content_base64: str = "", + parents: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), ) -> dict: trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") + drive = _get_connector() + + parents_list = [p.strip() for p in parents.split(",")] if parents else None payload: dict = { "action": "files.upload", - "name": file_name, + "name": name, "mime_type": mime_type, - "content": content, } - if folder_id: - payload["parents"] = [folder_id] + if parents_list: + payload["parents"] = parents_list + if content: + payload["content"] = content + if content_base64: + payload["content_base64"] = content_base64 params = GoogleDriveOperationInput(**payload) result = await drive.internal_execute(params, trace_id=trace_id) @@ -75,6 +97,182 @@ async def google_drive_upload_file( "description": result.description, } + # ------------------------------------------------------------------ + # Tool: google_drive_files_list + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_list", + description="List or search for files in Google Drive.", + ) + async def google_drive_files_list( + query: str = "", + page_size: int = 10, + fields: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.list", + "page_size": page_size, + } + if query: + payload["query"] = query + if fields: + payload["fields"] = fields + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: 
google_drive_files_create + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_create", + description="Create an empty file or folder in Google Drive.", + ) + async def google_drive_files_create( + name: str, + mime_type: str = "application/vnd.google-apps.folder", + parents: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + parents_list = [p.strip() for p in parents.split(",")] if parents else None + + payload = { + "action": "files.create", + "name": name, + "mime_type": mime_type, + } + if parents_list: + payload["parents"] = parents_list + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_get + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_get", + description="Get a file's metadata by its ID in Google Drive.", + ) + async def google_drive_files_get( + file_id: str, + fields: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.get", + "file_id": file_id, + } + if fields: + payload["fields"] = fields + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_update + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_update", + description="Update a file's metadata (e.g. 
rename or move folders) in Google Drive.", + ) + async def google_drive_files_update( + file_id: str, + name: str = "", + mime_type: str = "", + add_parents: str = "", + remove_parents: str = "", + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + add_parents_list = [p.strip() for p in add_parents.split(",")] if add_parents else None + remove_parents_list = [p.strip() for p in remove_parents.split(",")] if remove_parents else None + + payload = { + "action": "files.update", + "file_id": file_id, + } + if name: + payload["name"] = name + if mime_type: + payload["mime_type"] = mime_type + if add_parents_list: + payload["add_parents"] = add_parents_list + if remove_parents_list: + payload["remove_parents"] = remove_parents_list + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_files_delete + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_files_delete", + description="Trash a file in Google Drive by its ID.", + ) + async def google_drive_files_delete( + file_id: str, + ) -> dict: + trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "files.delete", + "file_id": file_id, + } + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + # ------------------------------------------------------------------ + # Tool: google_drive_permissions_create + # ------------------------------------------------------------------ + @mcp.tool( + name="google_drive_permissions_create", + description="Create a permission for a file (share a file) in Google Drive.", + ) + async def google_drive_permissions_create( + file_id: str, + role: str, + type: str, + email_address: str = "", + domain: str = "", + ) -> dict: + 
trace_id = str(uuid.uuid4()) + drive = _get_connector() + + payload = { + "action": "permissions.create", + "file_id": file_id, + "role": role, + "type": type, + } + if email_address: + payload["email_address"] = email_address + if domain: + payload["domain"] = domain + + params = GoogleDriveOperationInput(**payload) + result = await drive.internal_execute(params, trace_id=trace_id) + return result.raw + + logger.info( + "Registered %d Google Drive MCP tools", 7 + ) return mcp @@ -86,4 +284,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/src/agents/smtp_mcp.py b/src/agents/smtp_mcp.py index 80c147c..13df4c3 100644 --- a/src/agents/smtp_mcp.py +++ b/src/agents/smtp_mcp.py @@ -1,7 +1,8 @@ """ FastMCP Server Entrypoint — SMTP ================================ -Standalone MCP server exposing only the SMTP email tool. +Standalone MCP server exposing the SMTP email tool: + • smtp_send_email Usage: python -m agents.smtp_mcp @@ -56,7 +57,8 @@ def _make_server(): name="smtp_send_email", description=( "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." + "Credentials are picked up from environment variables. " + "You can specify multiple recipients mapped to a single comma separated string." 
), ) async def smtp_send_email( @@ -84,9 +86,9 @@ async def smtp_send_email( ).strip(" '\"") sender = _extract_email(sender) - recipient = _extract_email(to_email) + recipients = [_extract_email(addr.strip()) for addr in to_email.split(",") if addr.strip()] - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) + logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipients, subject) params = SmtpSendInput( host=smtp_host, @@ -95,13 +97,14 @@ async def smtp_send_email( username_secret_key="SMTP_USERNAME", password_secret_key="SMTP_PASSWORD", from_email=sender, - to=[recipient], + to=recipients, subject=subject, body=body, ) result = await smtp.internal_execute(params, trace_id=trace_id) return {"sent": result.sent, "message_id": getattr(result, "message_id", None)} + logger.info("Registered 1 SMTP MCP tools") return mcp @@ -113,4 +116,3 @@ def main() -> None: if __name__ == "__main__": main() - From 31feb45e6017ca428516d45de48bbd6d26a69d96 Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Wed, 1 Apr 2026 03:34:58 -0700 Subject: [PATCH 07/15] Google drive connector has --- Dockerfile | 6 +- README.md | 12 +- Setup.md | 23 +- docker/fhir-cerner/Dockerfile | 2 +- docker/fhir-epic/Dockerfile | 2 +- docker/google-drive/Dockerfile | 2 +- docker/smtp/Dockerfile | 2 +- docs/google_drive_connector.md | 9 +- docs/google_drive_upload_root_cause.md | 91 ++++ docs/mcp-servers.md | 12 +- docs/toolhive_agent_scenario.md | 26 +- playground/scenarios.py | 31 +- pyproject.toml | 4 + sample.env | 4 +- src/agents/README.md | 26 +- src/agents/fhir_cerner_mcp.py | 91 +--- src/agents/fhir_epic_mcp.py | 93 +--- src/agents/google_drive_mcp.py | 78 +--- src/agents/mcp_entrypoint.py | 620 +------------------------ src/agents/smtp_mcp.py | 105 +---- src/agents/toolhive.py | 147 +++++- src/bindings/mcp_server/server.py | 234 +++++++++- src/connectors/smtp/schema.py | 77 ++- tests/test_connectors_basic.py | 2 
+- tests/test_sdk_connector_manifest.py | 338 ++++++++++++++ tests/test_toolhive_agent.py | 329 +++++++++---- 26 files changed, 1227 insertions(+), 1139 deletions(-) create mode 100644 docs/google_drive_upload_root_cause.md diff --git a/Dockerfile b/Dockerfile index 4afee60..b2d5180 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # Node Wire — Docker Image # ======================== -# This image packages the connector platform as a FastMCP server. +# This image packages the connector platform as an MCP stdio server (manifest-driven). # ToolHive runs it as a container, injects secrets as env vars, # and proxies the stdio MCP transport to HTTP/SSE. # @@ -33,7 +33,7 @@ RUN pip install --no-cache-dir -e ".[agents]" # Healthcheck: verify the package is importable HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.mcp_entrypoint import _make_server; print('ok')" || exit 1 + python -c "from agents.mcp_entrypoint import main; assert callable(main); print('ok')" || exit 1 -# Default entrypoint: run the FastMCP server on stdio +# Default entrypoint: run the MCP server on stdio CMD ["python", "-m", "agents.mcp_entrypoint"] diff --git a/README.md b/README.md index 66d68ab..e9d9319 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,12 @@ For dependency management use any tool that understands `pyproject.toml` (e.g. ` Each connector can run as its own independent MCP server (Docker image). 
-| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | +| Image | MCP tools (manifest) | Docker image | +| ----------------------- | -------------------- | -------------------------------- | +| `nw-google-drive` | All `google_drive.` (e.g. `google_drive.files.upload`) | `docker/google-drive/Dockerfile` | +| `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | +| `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | +| `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | See [docs/mcp-servers.md](docs/mcp-servers.md) for build, env config, docker-compose, and ToolHive registration. diff --git a/Setup.md b/Setup.md index 7be559f..ce55535 100644 --- a/Setup.md +++ b/Setup.md @@ -194,7 +194,7 @@ Supported configurations: Add to your `.env`: ```env -stripe_api_key=sk_test_your_key_here +STRIPE_API_KEY=sk_test_your_key_here ``` Use a **test key** (`sk_test_...`) during development. Switch to a live key (`sk_live_...`) for production. @@ -272,16 +272,25 @@ The platform exposes connector tools for AI agents via the MCP (Model Context Pr Each connector runs as its own independent MCP server. This is the preferred approach for modular, scalable deployments. 
-| Image | Tool exposed | Docker image | -| ----------------------- | -------------------------- | -------------------------------- | -| `nw-google-drive` | `google_drive_upload_file` | `docker/google-drive/Dockerfile` | -| `nw-smartonfhir-epic` | `fhir_epic_read_patient` | `docker/fhir-epic/Dockerfile` | -| `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | `docker/fhir-cerner/Dockerfile` | -| `nw-smtp` | `smtp_send_email` | `docker/smtp/Dockerfile` | +| Image | MCP tools (manifest) | Docker image | +| ----------------------- | -------------------- | -------------------------------- | +| `nw-google-drive` | All `google_drive.` (e.g. `google_drive.files.upload`) | `docker/google-drive/Dockerfile` | +| `nw-smartonfhir-epic` | All `fhir_epic.` (e.g. `fhir_epic.read_patient`) | `docker/fhir-epic/Dockerfile` | +| `nw-smartonfhir-cerner` | All `fhir_cerner.` (e.g. `fhir_cerner.read_patient`) | `docker/fhir-cerner/Dockerfile` | +| `nw-smtp` | `smtp.send_email` | `docker/smtp/Dockerfile` | **Full guide (build, env config, ToolHive registration, multi-server agent usage):** [docs/mcp-servers.md](docs/mcp-servers.md) +**FHIR tool arguments (Cerner / Epic)** — tool names are `fhir_cerner.` and `fhir_epic.`. Use field names from `tools/list` / the connector manifest. Typical payloads: + +| Action | When to use | Example arguments | +| ------ | ----------- | ------------------- | +| `read_patient` | You have a Patient id | `{"resource_id": "12724066"}` (Epic ids often start with `e`) | +| `search_patients` | No id, or name-based search | `{"resource_ids": ["id1"]}` or `{"given_name": "...", "family_name": "..."}` or `{"search_params": {"identifier": "...", "family": "..."}}` (FHIR search param names) | + +The MCP server normalizes common LLM/legacy aliases (`patientId` / `patient_id` → `resource_id`; `patientId` inside `search_params` → `identifier`) before validation. Prefer canonical fields above when authoring prompts or clients. 
+ Quick start: ```bash diff --git a/docker/fhir-cerner/Dockerfile b/docker/fhir-cerner/Dockerfile index f53bb53..5e8fbdc 100644 --- a/docker/fhir-cerner/Dockerfile +++ b/docker/fhir-cerner/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_cerner_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_cerner_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_cerner_mcp"] diff --git a/docker/fhir-epic/Dockerfile b/docker/fhir-epic/Dockerfile index 633f031..3ff3036 100644 --- a/docker/fhir-epic/Dockerfile +++ b/docker/fhir-epic/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.fhir_epic_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.fhir_epic_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.fhir_epic_mcp"] diff --git a/docker/google-drive/Dockerfile b/docker/google-drive/Dockerfile index 43cbe2b..196e02a 100644 --- a/docker/google-drive/Dockerfile +++ b/docker/google-drive/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ - python -c "from agents.google_drive_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.google_drive_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.google_drive_mcp"] diff --git a/docker/smtp/Dockerfile b/docker/smtp/Dockerfile index c4d725b..8b7f8fc 100644 --- a/docker/smtp/Dockerfile +++ b/docker/smtp/Dockerfile @@ -17,7 +17,7 @@ COPY config/ ./config/ RUN pip install --no-cache-dir -e ".[agents]" HEALTHCHECK --interval=30s --timeout=5s --start-period=10s CMD \ 
- python -c "from agents.smtp_mcp import _make_server; print('ok')" || exit 1 + python -c "from agents.smtp_mcp import main; assert callable(main); print('ok')" || exit 1 CMD ["python", "-m", "agents.smtp_mcp"] diff --git a/docs/google_drive_connector.md b/docs/google_drive_connector.md index a66f116..5f2b4e5 100644 --- a/docs/google_drive_connector.md +++ b/docs/google_drive_connector.md @@ -5,7 +5,7 @@ This document covers the Google Drive connector under `connectors/google_drive` 1. **[Google Drive service account setup](#google-drive-service-account-setup)** — Create a GCP service account, enable the Drive API, configure `.env`, share a folder, and verify connectivity. 2. **[REST API reference](#rest-api-reference)** — The `execute` action, all seven operations, request/response shapes, and the platform error taxonomy. -For **MCP** (e.g. ToolHive), the connector is exposed as the `google_drive_upload_file` tool. End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). +For **MCP** (e.g. ToolHive), tools are named `google_drive.` from the connector manifest (e.g. `google_drive.files.upload`). End-to-end agent setup is documented in [docs/toolhive_agent_scenario.md](toolhive_agent_scenario.md). --- @@ -339,7 +339,7 @@ The service account must have edit permission on the file. #### files.upload -Create a new file with text content. +Create a new file with content (text or binary). Request body: @@ -358,9 +358,10 @@ Fields: - `name` (string, required). - `mime_type` (string, required). - `parents` (array of string, optional). -- `content` (string, required): UTF-8 text content that will be uploaded. +- `content` (string, optional): UTF-8 text content that will be uploaded. +- `content_base64` (string, optional): base64-encoded binary content (e.g. PDFs, images). -Content is uploaded using `MediaInMemoryUpload`; this is suitable for small text payloads. 
+Exactly one of `content` or `content_base64` must be provided.\n+\n+Content is uploaded using `MediaInMemoryUpload`; this is suitable for small payloads.\n+\n+> For MCP callers (e.g. ToolHive): use canonical fields (`content` / `content_base64`). Legacy `media` / `media_body` shapes are not part of the public schema and should not be relied upon. #### files.delete diff --git a/docs/google_drive_upload_root_cause.md b/docs/google_drive_upload_root_cause.md new file mode 100644 index 0000000..97a6c56 --- /dev/null +++ b/docs/google_drive_upload_root_cause.md @@ -0,0 +1,91 @@ +# Google Drive `files.upload` — root cause analysis + +## Summary verdict + +| Layer | Verdict | +|--------|---------| +| **Connector** | **Not at fault** for the observed errors. It validates and executes `FilesUploadOperation` as documented. | +| **MCP server** | **Behaves as designed**: injects `action` only when absent (`setdefault`); does not override a wrong `action` from the caller. | +| **Agent / LLM** | **Primary fault**: tool arguments did not match the published JSON Schema (`mimeType` vs `mime_type`, `action: "upload"` vs `files.upload`, missing fields). | +| **Groq 429** | **Secondary**: rate limits after many failed retries increased token usage and ended the run. | + +**Overall:** **Agent-side** (LLM tool-call payload), not a connector bug. 
+ +--- + +## Evidence from production logs (`terminals/11.txt`) + +| Step | Observed `google_drive.files.upload` args (excerpt) | Error | +|------|------------------------------------------------------|--------| +| 1 | `mimeType`, `name`, `parents`, `content` | Extra property `mimeType`; wrong field name for MIME type | +| 2 | `name`, `parents`, `content` (no `mime_type`) | `action` required (schema lists it as required) | +| 3 | `action: "upload"`, … | `mime_type` required / union mismatch | +| 4 | `mime_type` without correct `action` | `action` required | +| 5 | `action: "upload"`, `mime_type`, … | **`'files.upload' was expected`** — wrong discriminator | + +These align with **strict Pydantic validation** on `FilesUploadOperation` (`extra="forbid"`, discriminator `action`). + +--- + +## MCP contract (`tools/list`) + +For `google_drive.files.upload`, the manifest exposes **per-action** input schema (`FilesUploadOperation`), not the full union: + +- **`required`:** `action`, `name`, `mime_type` +- **`action`:** JSON Schema `const: "files.upload"` +- **No `mimeType`** property — only `mime_type` + +Source: [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) (`list_tools` + `invoke_tool`), [`src/connectors/manifest.py`](../src/connectors/manifest.py), [`src/connectors/google_drive/schema.py`](../src/connectors/google_drive/schema.py). + +--- + +## Server dispatch behavior + +In `McpServer.invoke_tool`: + +```python +run_args = normalize_mcp_tool_arguments(connector_id, action, arguments) +if isinstance(connector, SDKConnector): + run_args.setdefault("action", action) +``` + +- If the LLM **omits** `action`, the server sets `action` to the suffix from the tool name (`files.upload`) → valid for minimal calls. +- If the LLM sends **`action: "upload"`**, `setdefault` **does not** replace it → validation fails (`union_tag_invalid`), matching log **`'files.upload' was expected`**. 
+ +--- + +## Reproduction (local `invoke_tool`) + +| Payload | Result | +|---------|--------| +| `name`, `mime_type`, `parents`, `content` only (no `action`) | **Success** (server adds `action`) — assumes valid Drive credentials | +| `mimeType` instead of `mime_type` | `VALIDATION_ERROR`: `mime_type` missing, `mimeType` extra forbidden | +| `action: "upload"` + valid other fields | `VALIDATION_ERROR`: `union_tag_invalid` (expected tags include `files.upload`, not `upload`) | + +--- + +## Payload matrix + +| Issue | Owner | Notes | +|-------|--------|------| +| `mimeType` vs `mime_type` | Agent | Schema only defines `mime_type` | +| Missing `action` when schema says required | Agent / schema UX | Server can still inject `action` if omitted; LLM may omit and still work | +| `action: "upload"` | Agent | Must be literal `files.upload` | +| Nested `file` object | Agent | Not in schema | +| Connector rejects valid `files.upload` payload | N/A | Not observed | + +--- + +## Recommendations (optional follow-ups) + +1. **Agent prompt / tool-calling**: Implemented in [`src/agents/toolhive.py`](../src/agents/toolhive.py) — step 2 now states flat JSON, `mime_type`, and correct `action` / no nested `file`. +2. **Normalization** (server): Implemented in [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) — `_normalize_google_drive_files_upload` maps `mimeType` → `mime_type`, coerces `action: "upload"` → `files.upload`, merges a nested `file` dict when canonical keys are absent, and strips `mimeType`. +3. **Groq**: Operational — smaller context, higher TPM tier, or fewer agent steps still help if the model ignores schema; normalization reduces validation failure loops. 
+ +--- + +## References + +- [`src/connectors/google_drive/schema.py`](../src/connectors/google_drive/schema.py) — `FilesUploadOperation` +- [`src/bindings/mcp_server/server.py`](../src/bindings/mcp_server/server.py) — `normalize_mcp_tool_arguments`, `invoke_tool` +- [`src/agents/toolhive.py`](../src/agents/toolhive.py) — sends tool args to MCP as returned by the LLM; the MCP server normalizes Google Drive upload aliases before `connector.run`. diff --git a/docs/mcp-servers.md b/docs/mcp-servers.md index 1dfe8de..5885075 100644 --- a/docs/mcp-servers.md +++ b/docs/mcp-servers.md @@ -43,10 +43,12 @@ flowchart TD | Connector | Python entrypoint | Docker image | ToolHive name | MCP tool(s) exposed | |---|---|---|---|---| -| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | `google_drive_upload_file` | -| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | `fhir_epic_read_patient` | -| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | `fhir_cerner_read_patient` | -| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp_send_email` | +| Google Drive | `python -m agents.google_drive_mcp` | `nw-google-drive` | `nw-google-drive` | All manifest actions for `google_drive` (names `google_drive.`, e.g. `google_drive.files.upload`) | +| SMART on FHIR (Epic) | `python -m agents.fhir_epic_mcp` | `nw-smartonfhir-epic` | `nw-smartonfhir-epic` | All manifest actions for `fhir_epic` (e.g. `fhir_epic.read_patient`) | +| SMART on FHIR (Cerner) | `python -m agents.fhir_cerner_mcp` | `nw-smartonfhir-cerner` | `nw-smartonfhir-cerner` | All manifest actions for `fhir_cerner` (e.g. 
`fhir_cerner.read_patient`) | +| SMTP | `python -m agents.smtp_mcp` | `nw-smtp` | `nw-smtp` | `smtp.send_email` | + +The unified server (`python -m agents.mcp_entrypoint`) exposes **every** connector enabled for MCP in `config/connectors.yaml` (e.g. `http_generic.request`, `stripe.charge`, plus the rows above). --- @@ -118,7 +120,7 @@ Register your application at the [Cerner Developer Portal](https://code.cerner.c #### `nw-smtp` -The SMTP MCP server exposes one tool: `smtp_send_email`. When running under ToolHive, inject these as secrets: +The SMTP MCP server exposes one tool: `smtp.send_email`. When running under ToolHive, inject these as secrets: | Variable | Description | |---|---| diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index 666c39b..1d4026e 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -36,10 +36,10 @@ This guide walks you through running the platform as an MCP server using ToolHiv ``` ToolHive UI ────────────────────────────────────────────────────── │ MCP Server (Docker): node-wire │ -│ ├── Tool: fhir_cerner_read_patient ← fetch patient from Cerner │ -│ ├── Tool: fhir_epic_read_patient ← fetch patient from Epic │ -│ ├── Tool: google_drive_upload_file ← write file to Drive │ -│ └── Tool: smtp_send_email ← email the summary │ +│ ├── Tool: fhir_cerner.read_patient ← fetch patient from Cerner │ +│ ├── Tool: fhir_epic.read_patient ← fetch patient from Epic │ +│ ├── Tool: google_drive.files.upload ← write file to Drive │ +│ └── Tool: smtp.send_email ← email the summary │ │ ↕ stdio → HTTP proxy │ ────────────────────────────────────────────────────────────────── ↕ MCP JSON-RPC over HTTP @@ -86,10 +86,10 @@ When running as an MCP server, the platform exposes 4 tools that AI agents can d | Tool | Description | |---|---| -| `fhir_cerner_read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | -| `fhir_epic_read_patient` | Fetch a patient's record from an Epic FHIR R4 
endpoint | -| `google_drive_upload_file` | Create and upload a text file to Google Drive | -| `smtp_send_email` | Send an email via SMTP | +| `fhir_cerner.read_patient` | Fetch a patient's record from a Cerner FHIR R4 endpoint | +| `fhir_epic.read_patient` | Fetch a patient's record from an Epic FHIR R4 endpoint | +| `google_drive.files.upload` | Create and upload a text file to Google Drive | +| `smtp.send_email` | Send an email via SMTP | The agent uses an LLM's tool-calling capability to decide which tools to call, in what order, and with what parameters. @@ -317,7 +317,7 @@ In the ToolHive UI under **Installed**, you should see: |---|---| | Name | `node-wire-connectors` | | Status | `Running` | -| Tools | `fhir_cerner_read_patient`, `fhir_epic_read_patient`, `google_drive_upload_file`, `smtp_send_email` | +| Tools | Manifest-driven `<connector_id>.<action>` (e.g. `fhir_cerner.read_patient`, `fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`; unified server also lists Stripe, HTTP generic, and other MCP-enabled connectors) | | Endpoint | `http://localhost:<PORT>/sse` | --- @@ -404,11 +404,11 @@ I have completed all three steps: 3. Sent a summary email to your-email@example.com with a link to the file. 
Steps executed (3): - ✓ Step 1: fhir_cerner_read_patient + ✓ Step 1: fhir_cerner.read_patient result : {"patient_id": "123*****", "full_name": "Nancy Smart", ...} - ✓ Step 2: google_drive_upload_file + ✓ Step 2: google_drive.files.upload result : {"file_id": "1XYZ...", "web_view_link": "https://docs.google.com/..."} - ✓ Step 3: smtp_send_email + ✓ Step 3: smtp.send_email result : {"sent": true} ``` @@ -545,7 +545,7 @@ connector-platform/ └── src/ └── agents/ ├── __init__.py - ├── mcp_entrypoint.py ← FastMCP server (4 tools) + ├── mcp_entrypoint.py ← MCP stdio server (manifest; all MCP connectors) ├── toolhive.py ← ReAct agent + CLI ├── llm_factory.py ← Provider factory └── providers/ diff --git a/playground/scenarios.py b/playground/scenarios.py index 584b9a4..6185319 100644 --- a/playground/scenarios.py +++ b/playground/scenarios.py @@ -1108,13 +1108,19 @@ class AgentChatResponse(BaseModel): AGENT_GUARDRAIL_PROMPT = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `<connector_id>.<action>` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" " 1. 
First turn: Obtain patient demographics from the EHR.\n" + " - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{\"resource_id\": \"\"}` (use Epic when the ID starts with 'e'). Do NOT use search_patients for a known ID.\n" + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. `given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -1124,7 +1130,7 @@ class AgentChatResponse(BaseModel): " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. 
ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" "GUARDRAILS:\n" @@ -1172,6 +1178,7 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: ToolHiveMcpClient, StdioMcpClient, resolve_mcp_urls, + resolve_max_tool_failures, ) provider_name = os.environ.get("LLM_PROVIDER", "groq") @@ -1206,7 +1213,12 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: mcp_client = ToolHiveMcpClient(urls[0]) else: mcp_client = MultiMcpClient([ToolHiveMcpClient(u) for u in urls]) - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) # Fallback to local stdio if: @@ -1232,7 +1244,12 @@ async def agent_chat(payload: AgentChatInput) -> AgentChatResponse: logger.info("Agent Chat | using local stdio MCP transport") cmd = [sys.executable, "-m", "agents.mcp_entrypoint"] async with StdioMcpClient(cmd) as mcp_client: - agent = ToolHiveAgent(mcp_client, llm_provider, max_steps=10) + agent = ToolHiveAgent( + mcp_client, + llm_provider, + max_steps=10, + max_tool_failures=resolve_max_tool_failures(None), + ) agent._system_prompt = AGENT_GUARDRAIL_PROMPT run_result = await agent.run(task) diff --git a/pyproject.toml b/pyproject.toml index c275864..fcd7425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,7 @@ where = ["src"] requires = ["setuptools>=69.0.0", "wheel"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +pythonpath = 
["src"] +asyncio_mode = "auto" + diff --git a/sample.env b/sample.env index 996e670..2620ce7 100644 --- a/sample.env +++ b/sample.env @@ -24,13 +24,15 @@ SMTP_USERNAME=your-email@gmail.com SMTP_PASSWORD=your-gmail-app-password # Stripe (optional / legacy demo) -stripe_api_key=sk_test_your_key_here +STRIPE_API_KEY=sk_test_your_key_here # ToolHive # Single-server (backward compatible) TOOLHIVE_MCP_URL=http://localhost:PORT/mcp # Multi-server (preferred for per-connector MCP servers) TOOLHIVE_MCP_URLS= +# Cap MCP tool JSON size sent back to the LLM (Groq on-demand TPM); default 12000 +# TOOLHIVE_MAX_TOOL_RESULT_CHARS=12000 # LLM Provider LLM_PROVIDER=groq diff --git a/src/agents/README.md b/src/agents/README.md index 22834ab..ec45f44 100644 --- a/src/agents/README.md +++ b/src/agents/README.md @@ -16,10 +16,10 @@ The `agents` module transforms static connectors (EHR, Google Drive, SMTP) into ## 🏗️ Core Architecture ### 1. **MCP Server (`mcp_entrypoint.py`)** -A high-performance server built on the [FastMCP](https://github.com/modelcontextprotocol/python-sdk) framework. -- **Dynamic Bindings**: Uses the `ConnectorFactory` to load platform connectors and expose them as MCP tools. -- **Data Protection**: Automatically extracts and summarizes raw FHIR resources to protect patient privacy and reduce LLM token consumption. -- **Flexible Transport**: Defaults to `stdio` transport for seamless integration with ToolHive, Claude Desktop, or custom proxies. +Stdio MCP server using the official [Model Context Protocol Python SDK](https://github.com/modelcontextprotocol/python-sdk). +- **Manifest-driven tools**: `McpServer` builds the tool list from connector metadata (`<connector_id>.<action>`) and dispatches via `connector.run()`. +- **Unified entrypoint**: `python -m agents.mcp_entrypoint` exposes every connector enabled for MCP in `config/connectors.yaml`. 
+- **Per-connector images**: `fhir_cerner_mcp`, `fhir_epic_mcp`, `google_drive_mcp`, and `smtp_mcp` run the same server with a `connector_ids` filter. ### 2. **ToolHive Agent (`toolhive.py`)** A reference implementation of a ReAct-style agent designed for the **ToolHive** ecosystem. @@ -35,14 +35,18 @@ A modular factory system supporting diverse LLM backends: --- -## 🛠️ Available MCP Tools +## 🛠️ MCP tool naming -| Tool Name | Description | Connector | -| :--- | :--- | :--- | -| `fhir_cerner_read_patient` | Fetches patient demographics (Name, DOB, ID) from Cerner FHIR R4. | `fhir_cerner` | -| `fhir_epic_read_patient` | Fetches patient demographics from Epic FHIR R4. (IDs usually start with 'e'). | `fhir_epic` | -| `google_drive_upload_file` | Securely uploads text summaries or reports to a designated folder. | `google_drive` | -| `smtp_send_email` | Dispatches notifications or clinical summaries via secure SMTP. | `smtp` | +Tools are named **`{connector_id}.{action}`** as defined by each connector’s manifest (see `connectors/manifest.py` and `bindings/mcp_server/server.py`). Examples: + +| Example tool name | Connector | +| :--- | :--- | +| `fhir_cerner.read_patient` | Cerner FHIR | +| `fhir_epic.read_patient` | Epic FHIR | +| `google_drive.files.upload` | Google Drive | +| `smtp.send_email` | SMTP | + +Use **`tools/list`** for the exact names and JSON Schemas your deployment exposes. --- diff --git a/src/agents/fhir_cerner_mcp.py b/src/agents/fhir_cerner_mcp.py index fd2067c..e9ac4cd 100644 --- a/src/agents/fhir_cerner_mcp.py +++ b/src/agents/fhir_cerner_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Cerner) -================================================= -Standalone MCP server exposing only the Cerner FHIR patient read tool. - -Usage: - python -m agents.fhir_cerner_mcp -""" +"""MCP Server — Cerner FHIR connector only. 
Usage: python -m agents.fhir_cerner_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,82 +13,15 @@ logger = logging.getLogger("agents.fhir_cerner_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import FhirCernerPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-cerner") - - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), - ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_id: - params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirCernerPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await cerner.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', 
'')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-cerner MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smartonfhir-cerner MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smartonfhir-cerner", + connector_ids=["fhir_cerner"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/fhir_epic_mcp.py b/src/agents/fhir_epic_mcp.py index 5e6798e..c9fb60b 100644 --- a/src/agents/fhir_epic_mcp.py +++ b/src/agents/fhir_epic_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — SMART on FHIR (Epic) -=============================================== -Standalone MCP server exposing only the Epic FHIR patient read tool. - -Usage: - python -m agents.fhir_epic_mcp -""" +"""MCP Server — Epic FHIR connector only. Usage: python -m agents.fhir_epic_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,84 +13,15 @@ logger = logging.getLogger("agents.fhir_epic_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smartonfhir-epic") - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. 
" - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_id: - params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirEpicPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await epic.internal_execute(params, trace_id=trace_id) - resource = result.resource - - name_parts = resource.get("name", [{}])[0] - full_name = " ".join(name_parts.get("given", []) + [name_parts.get("family", "")]).strip() - - addr = resource.get("address", [{}])[0] - full_addr = ( - f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}" - ).strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smartonfhir-epic MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smartonfhir-epic MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smartonfhir-epic", + connector_ids=["fhir_epic"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/google_drive_mcp.py 
b/src/agents/google_drive_mcp.py index 050a3ef..6591717 100644 --- a/src/agents/google_drive_mcp.py +++ b/src/agents/google_drive_mcp.py @@ -1,16 +1,8 @@ -""" -FastMCP Server Entrypoint — Google Drive -======================================= -Standalone MCP server exposing only the Google Drive tool. - -Usage: - python -m agents.google_drive_mcp -""" +"""MCP Server — Google Drive connector only. Usage: python -m agents.google_drive_mcp""" from __future__ import annotations import logging import os -import uuid from dotenv import load_dotenv @@ -21,69 +13,15 @@ logger = logging.getLogger("agents.google_drive_mcp") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.google_drive.schema import GoogleDriveOperationInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-google-drive") - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." 
- ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-google-drive MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-google-drive MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-google-drive", + connector_ids=["google_drive"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/mcp_entrypoint.py b/src/agents/mcp_entrypoint.py index ba9ac46..e506a94 100644 --- a/src/agents/mcp_entrypoint.py +++ b/src/agents/mcp_entrypoint.py @@ -1,40 +1,11 @@ -""" -FastMCP Server Entrypoint -========================= -This module is the main entrypoint for the Node Wire MCP server. 
-When run, it exposes healthcare workflow tools via the MCP stdio transport: - - • fhir_cerner_read_patient — fetch a patient from Cerner FHIR R4 - • fhir_cerner_search_patients — search multiple patients in Cerner - • fhir_cerner_search_encounters — search encounters in Cerner - • fhir_epic_read_patient — fetch a patient from Epic FHIR R4 - • fhir_epic_search_patients — search multiple patients in Epic - • fhir_epic_search_encounters — search encounters in Epic - • google_drive_upload_file — write a file to Google Drive - • smtp_send_email — send an email via SMTP - -ToolHive manages the container lifecycle, injects secrets as environment -variables, and proxies the stdio MCP stream to HTTP/SSE for clients. - -Usage (run directly by ToolHive): - python -m agents.mcp_entrypoint - -Environment variables (injected by ToolHive via --secret flags): - CERNER_FHIR_BASE_URL, CERNER_CLIENT_ID, CERNER_KID, - CERNER_PRIVATE_KEY, CERNER_TOKEN_URL, CERNER_SCOPES - GOOGLE_DRIVE_SA_JSON - SMTP_USERNAME, SMTP_PASSWORD, SMTP_HOST, SMTP_PORT -""" +"""MCP Server — all connectors exposed via MCP. Usage: python -m agents.mcp_entrypoint""" from __future__ import annotations -import json import logging import os -import uuid + from dotenv import load_dotenv -# Load .env variables for local stdio transport -# Try both CWD and script's own folder to be safe load_dotenv() load_dotenv(os.path.join(os.path.dirname(__file__), ".env")) @@ -42,590 +13,11 @@ logger = logging.getLogger("agents.mcp_entrypoint") -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError( - "mcp SDK not installed. 
Run: pip install 'node-wire[agents]'" - ) from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.fhir_cerner.schema import ( - FhirCernerPatientReadInput, - FhirCernerPatientSearchInput, - FhirCernerEncounterSearchInput, - ) - from connectors.fhir_epic.schema import ( - FhirPatientReadInput as FhirEpicPatientReadInput, - FhirPatientSearchInput as FhirEpicPatientSearchInput, - FhirEncounterSearchInput as FhirEpicEncounterSearchInput, - ) - from connectors.google_drive.schema import GoogleDriveOperationInput - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("Node Wire") - - # ------------------------------------------------------------------ - # Tool 1: Fetch patient from Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_read_patient", - description=( - "Fetch a patient's demographic record from Cerner FHIR R4. " - "Returns name, date of birth, gender, identifiers, and contact details." - ), - ) - async def fhir_cerner_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (direct lookup — use this if you have it). - family_name : str - Patient family/last name (used for search when no ID is known). - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. 
- """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_id: - params = FhirCernerPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirCernerPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await cerner.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Extract a clean summary for the LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - # Drastically simplify to keep token count low - ids = ", ".join([f"{i.get('system')}: {i.get('value')}" for i in resource.get("identifier", [])]) - phones = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "phone"]) - emails = ", ".join([t.get("value") for t in resource.get("telecom", []) if t.get("system") == "email"]) - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - } - - # ------------------------------------------------------------------ - # Tool 2: Fetch patient from Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_read_patient", - description=( - "Fetch a patient's demographic record from Epic FHIR R4. 
" - "Returns name, date of birth, gender, identifiers, and contact details. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_read_patient( - patient_id: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - FHIR Patient resource ID (Epic specific, usually starts with 'e'). - family_name : str - Patient family/last name. - given_name : str - Patient given/first name. - name : str - Full or partial patient name (convenience — use when you only have a - single combined name string and no split given/family available). - birthdate : str - Patient date of birth in YYYY-MM-DD format. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_id: - params = FhirEpicPatientReadInput(action="read_patient", resource_id=patient_id) - elif family_name or given_name or name: - params = FhirEpicPatientReadInput( - action="read_patient", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError("Provide patient_id OR at least family_name / given_name / name") - - result = await epic.internal_execute(params, trace_id=trace_id) - resource = result.resource - - # Clean extract for LLM - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - - addr = resource.get("address", [{}])[0] - full_addr = f"{addr.get('line', [''])[0]}, {addr.get('city', '')}, {addr.get('state', '')} {addr.get('postalCode', '')}".strip(", ") - - return { - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "address_summary": full_addr, - "source": "Epic FHIR", - } - - # 
------------------------------------------------------------------ - # Tool 3: Search patients in Cerner (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_patients", - description=( - "Search for multiple patients in Cerner FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches." - ), - ) - async def fhir_cerner_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. '12345678,87654321'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). 
- """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirCernerPatientSearchInput(action="search_patients", resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirCernerPatientSearchInput( - action="search_patients", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await cerner.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 4: Search patients in Epic (multi-ID or name-based) - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_patients", - description=( - "Search for multiple patients in Epic FHIR R4. " - "Pass a comma-separated list of Patient IDs for concurrent lookup, " - "or supply name/birthdate fields for a name-based search returning all matches. " - "Epic IDs typically start with 'e' (e.g. 'e12345')." 
- ), - ) - async def fhir_epic_search_patients( - patient_ids: str = "", - family_name: str = "", - given_name: str = "", - name: str = "", - birthdate: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_ids : str - Comma-separated Patient IDs for concurrent multi-ID lookup - (e.g. 'eABC,eDEF'). Takes priority over name fields. - family_name : str - Patient family/last name (name-search mode). - given_name : str - Patient given/first name (name-search mode). - name : str - Full or partial name string — FHIR 'name' token search. - birthdate : str - Date of birth in YYYY-MM-DD format (name-search mode). - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if patient_ids.strip(): - ids = [i.strip() for i in patient_ids.split(",") if i.strip()] - params = FhirEpicPatientSearchInput(action="search_patients", resource_ids=ids) - elif family_name or given_name or name or birthdate: - params = FhirEpicPatientSearchInput( - action="search_patients", - given_name=given_name or None, - family_name=family_name or None, - name=name or None, - birthdate=birthdate or None, - ) - else: - raise ValueError( - "Provide patient_ids (comma-separated) OR at least one of " - "family_name / given_name / name / birthdate" - ) - - result = await epic.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - name_parts = resource.get("name", [{}])[0] - full_name = " ".join( - name_parts.get("given", []) + [name_parts.get("family", "")] - ).strip() - summaries.append({ - "patient_id": resource.get("id"), - "full_name": full_name or "Unknown", - "gender": resource.get("gender"), - "birth_date": resource.get("birthDate"), - "source": "Epic FHIR", - }) - - return { - "patients": summaries, - "total": result.total, - "errors": result.errors, - } - - # ------------------------------------------------------------------ - # Tool 5: 
Search encounters in Cerner FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_cerner_search_encounters", - description=( - "Search for encounters in Cerner FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter." - ), - ) - async def fhir_cerner_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Cerner Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished', 'in-progress'). - date : str - Filter by date or date range (e.g. '2024', 'ge2023-01-01'). - """ - trace_id = str(uuid.uuid4()) - cerner = factory._connectors.get("fhir_cerner") - if not cerner: - raise RuntimeError("fhir_cerner connector not configured") - - if not (patient_id or status or date): - raise ValueError("Provide at least one of patient_id / status / date") - - params = FhirCernerEncounterSearchInput( - action="search_encounter", - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await cerner.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - } - - # ------------------------------------------------------------------ - # Tool 6: Search encounters in Epic FHIR R4 - # ------------------------------------------------------------------ - - @mcp.tool( - name="fhir_epic_search_encounters", - description=( - "Search for encounters in Epic FHIR R4. " - "Returns a list of encounter summaries for a given patient or filter. 
" - "Epic IDs typically start with 'e' (e.g. 'e12345')." - ), - ) - async def fhir_epic_search_encounters( - patient_id: str = "", - status: str = "", - date: str = "", - ) -> dict: - """ - Parameters - ---------- - patient_id : str - Epic Patient ID to find encounters for. - status : str - Filter by encounter status (e.g. 'finished'). - date : str - Filter by date or date range. - """ - trace_id = str(uuid.uuid4()) - epic = factory._connectors.get("fhir_epic") - if not epic: - raise RuntimeError("fhir_epic connector not configured") - - if not (patient_id or status or date): - raise ValueError("Provide at least one of patient_id / status / date") - - params = FhirEpicEncounterSearchInput( - action="search_encounter", - patient_id=patient_id or None, - status=status or None, - date=date or None, - ) - - result = await epic.internal_execute(params, trace_id=trace_id) - - summaries = [] - for resource in result.resources: - summaries.append({ - "encounter_id": resource.get("id"), - "status": resource.get("status"), - "class": resource.get("class", {}).get("display"), - "period_start": resource.get("period", {}).get("start"), - "period_end": resource.get("period", {}).get("end"), - "type": resource.get("type", [{}])[0].get("text"), - }) - - return { - "encounters": summaries, - "total": result.total, - "source": "Epic FHIR", - } - - # ------------------------------------------------------------------ - # Tool 7: Upload a file to Google Drive - # ------------------------------------------------------------------ - - @mcp.tool( - name="google_drive_upload_file", - description=( - "Upload a text file to Google Drive. " - "Returns the file ID and a shareable web view link." - ), - ) - async def google_drive_upload_file( - file_name: str, - content: str, - folder_id: str = os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), - mime_type: str = "text/plain", - ) -> dict: - """ - Parameters - ---------- - file_name : str - Name for the file in Google Drive (e.g. 
'patient_summary_12345.txt'). - content : str - UTF-8 text content to write into the file. - folder_id : str - Optional Google Drive folder ID to place the file in. - mime_type : str - MIME type (default: text/plain). - """ - trace_id = str(uuid.uuid4()) - drive = factory._connectors.get("google_drive") - if not drive: - raise RuntimeError("google_drive connector not configured") - - payload: dict = { - "action": "files.upload", - "name": file_name, - "mime_type": mime_type, - "content": content, - } - if folder_id: - payload["parents"] = [folder_id] - - params = GoogleDriveOperationInput(**payload) - result = await drive.internal_execute(params, trace_id=trace_id) - - raw = result.raw - return { - "file_id": raw.get("id"), - "file_name": raw.get("name"), - "web_view_link": raw.get("webViewLink"), - "description": result.description, - } - - # ------------------------------------------------------------------ - # Tool 8: Send email via SMTP - # ------------------------------------------------------------------ - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." - ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - """ - Parameters - ---------- - to_email : str - Recipient email address. - subject : str - Email subject line. - body : str - Plain-text email body. - from_email : str - Sender address — defaults to SMTP_USERNAME env var if empty. 
- """ - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - # Guardrail: Handle placeholder strings from LLM or empty input - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = (os.environ.get("FROM_EMAIL") or os.environ.get("SMTP_USERNAME") or "noreply@node-wire.local").strip(" '\"") - - # Pydantic EmailStr does not like "Name " - # Extract just the email part if needed - import re - def _extract_email(s: str) -> str: - match = re.search(r"<(.+?)>", s) - return match.group(1) if match else s.strip() - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": result.message_id} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting Node Wire MCP server (stdio transport)") - server.run() # stdio — ToolHive proxies this to HTTP/SSE + from bindings.mcp_server.server import McpServer + + logger.info("Starting Node Wire MCP server (stdio, manifest-driven)") + McpServer(server_name="node-wire").run_stdio() if __name__ == "__main__": diff --git a/src/agents/smtp_mcp.py b/src/agents/smtp_mcp.py index 80c147c..eb86d54 100644 --- a/src/agents/smtp_mcp.py +++ b/src/agents/smtp_mcp.py @@ -1,25 +1,8 @@ -""" 
-FastMCP Server Entrypoint — SMTP -================================ -Standalone MCP server exposing only the SMTP email tool. - -Usage: - python -m agents.smtp_mcp - -Environment variables: - SMTP_HOST (default: smtp.gmail.com) - SMTP_PORT (default: 587) - SMTP_USE_TLS (default: true) - SMTP_USERNAME - SMTP_PASSWORD - FROM_EMAIL (optional; fallback sender address) -""" +"""MCP Server — SMTP connector only. Usage: python -m agents.smtp_mcp""" from __future__ import annotations import logging import os -import re -import uuid from dotenv import load_dotenv @@ -30,87 +13,15 @@ logger = logging.getLogger("agents.smtp_mcp") -def _extract_email(value: str) -> str: - # Pydantic EmailStr does not like "Name " - match = re.search(r"<(.+?)>", value) - return (match.group(1) if match else value).strip() - - -def _make_server(): - try: - from mcp.server.fastmcp import FastMCP - except ImportError as exc: - raise ImportError("mcp SDK not installed. Run: pip install 'node-wire[agents]'") from exc - - from bindings.factory import ConnectorFactory - from connectors import auto_register - from connectors.smtp.schema import SmtpSendInput - - auto_register() - factory = ConnectorFactory() - factory.load() - - mcp = FastMCP("nw-smtp") - - @mcp.tool( - name="smtp_send_email", - description=( - "Send an email to a recipient via SMTP. " - "Credentials are picked up from environment variables." 
- ), - ) - async def smtp_send_email( - to_email: str, - subject: str, - body: str, - from_email: str = "", - ) -> dict: - trace_id = str(uuid.uuid4()) - smtp = factory._connectors.get("smtp") - if not smtp: - raise RuntimeError("smtp connector not configured") - - smtp_host = os.environ.get("SMTP_HOST", "smtp.gmail.com").strip(" '\"") - smtp_port_raw = os.environ.get("SMTP_PORT", "587").strip(" '\"") - smtp_port = int(smtp_port_raw) - smtp_use_tls = os.environ.get("SMTP_USE_TLS", "true").lower() == "true" - - sender = from_email.strip(" '\"") - if not sender or "@" not in sender or "system_default" in sender: - sender = ( - os.environ.get("FROM_EMAIL") - or os.environ.get("SMTP_USERNAME") - or "noreply@node-wire.local" - ).strip(" '\"") - - sender = _extract_email(sender) - recipient = _extract_email(to_email) - - logger.info("SMTP Tool | from=%s to=%s subject=%s", sender, recipient, subject) - - params = SmtpSendInput( - host=smtp_host, - port=smtp_port, - use_tls=smtp_use_tls, - username_secret_key="SMTP_USERNAME", - password_secret_key="SMTP_PASSWORD", - from_email=sender, - to=[recipient], - subject=subject, - body=body, - ) - result = await smtp.internal_execute(params, trace_id=trace_id) - return {"sent": result.sent, "message_id": getattr(result, "message_id", None)} - - return mcp - - def main() -> None: - server = _make_server() - logger.info("Starting nw-smtp MCP server (stdio transport)") - server.run() + from bindings.mcp_server.server import McpServer + + logger.info("Starting nw-smtp MCP server (stdio, manifest-driven)") + McpServer( + server_name="nw-smtp", + connector_ids=["smtp"], + ).run_stdio() if __name__ == "__main__": main() - diff --git a/src/agents/toolhive.py b/src/agents/toolhive.py index 884f949..a3a3cdc 100644 --- a/src/agents/toolhive.py +++ b/src/agents/toolhive.py @@ -4,9 +4,9 @@ A ReAct-style AI agent that connects to an MCP server running in ToolHive, discovers its tools, and orchestrates a healthcare workflow: - 1. 
Fetch patient details via fhir_cerner_read_patient or fhir_epic_read_patient - 2. Write a patient summary file via google_drive_upload_file - 3. Email the summary via smtp_send_email + 1. Fetch patient details via fhir_cerner.read_patient / fhir_epic.read_patient (or search_* tools) + 2. Write a patient summary file via google_drive.files.upload + 3. Email the summary via smtp.send_email The LLM backend is fully configurable via the LLM_PROVIDER env var. @@ -23,6 +23,7 @@ Environment variables: TOOLHIVE_MCP_URL : MCP proxy URL from ToolHive UI (e.g. http://localhost:PORT/mcp) TOOLHIVE_MCP_URLS: Comma-separated MCP proxy URLs (multi-server) + TOOLHIVE_MAX_TOOL_FAILURES: Stop after this many failed invocations per tool name (default: 2) LLM_PROVIDER : groq | openai | gemini | anthropic (default: groq) GROQ_API_KEY : (when using groq) OPENAI_API_KEY : (when using openai) @@ -52,6 +53,77 @@ logger = logging.getLogger("agents.toolhive") +def truncate_tool_result_for_llm(text: str) -> str: + """ + Cap tool output size sent to the LLM so providers with strict limits (e.g. Groq + on-demand TPM) do not fail with 413 / oversized requests after large FHIR payloads. + + Full raw output remains in AgentStep.tool_result for logging; only the message + passed back into the chat is truncated. + + Override with env TOOLHIVE_MAX_TOOL_RESULT_CHARS (default 12000). Use 0 to disable. + """ + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_RESULT_CHARS") or "12000").strip() + try: + max_chars = int(raw) + except ValueError: + max_chars = 12000 + if max_chars <= 0 or len(text) <= max_chars: + return text + omitted = len(text) - max_chars + return ( + text[:max_chars] + + "\n\n[... truncated " + + str(omitted) + + " characters for LLM context limits; use visible fields for next steps.]" + ) + + +def resolve_max_tool_failures(override: Optional[int] = None) -> int: + """ + Max failed tool invocations per tool name before aborting the agent run. 
+ ``override`` wins; otherwise ``TOOLHIVE_MAX_TOOL_FAILURES`` (default 2). Minimum 1. + """ + if override is not None: + return max(1, int(override)) + raw = (os.environ.get("TOOLHIVE_MAX_TOOL_FAILURES") or "2").strip() + try: + n = int(raw) + except ValueError: + n = 2 + return max(1, n) + + +def _is_tool_failure(tool_result: str) -> bool: + """True if MCP/connector reported a failed tool outcome (not empty success).""" + if not tool_result or not tool_result.strip(): + return False + t = tool_result.strip() + if t.startswith("ERROR:"): + return True + low = t.lower() + if "input validation error" in low: + return True + if "validation error" in low and "input" in low: + return True + if t.startswith("{"): + try: + data = json.loads(t) + if isinstance(data, dict) and data.get("success") is False: + return True + except json.JSONDecodeError: + pass + return False + + +def _tool_failure_abort_message(tool_name: str, max_failures: int) -> str: + return ( + f'The tool "{tool_name}" failed {max_failures} times in a row. ' + "Please check the parameters against the schema from tools/list, " + "or tell me if I should use a different tool or approach." + ) + + # --------------------------------------------------------------------------- # Result model # --------------------------------------------------------------------------- @@ -306,7 +378,7 @@ class ToolHiveAgent: 2. Enters a ReAct loop: send task + tools to LLM → if tool call → invoke tool → append result → repeat. 3. Stops when the LLM returns a final answer (no tool calls) or - ``max_steps`` is reached. + ``max_steps`` is reached, or the same tool fails ``max_tool_failures`` times. 
""" def __init__( @@ -314,20 +386,30 @@ def __init__( mcp_client: McpClient, llm_provider: Any, # BaseLLMProvider max_steps: int = 10, + max_tool_failures: Optional[int] = None, ) -> None: self._mcp = mcp_client self._llm = llm_provider self._max_steps = max_steps + self._max_tool_failures = resolve_max_tool_failures(max_tool_failures) self._system_prompt: str = ( "You are a healthcare data assistant. You have access to tools for fetching " "patient data from Cerner FHIR and Epic FHIR, uploading files to Google Drive, and sending " - "emails via SMTP.\n\n" + "emails via SMTP.\n" + "Tool names are `.` (e.g. `fhir_cerner.read_patient`, " + "`fhir_epic.read_patient`, `google_drive.files.upload`, `smtp.send_email`). " + "Use exactly the names and JSON-schema arguments from tools/list.\n\n" "WORKFLOW (MUST EXECUTE SEQUENTIALLY, ONE STRICT STEP AT A TIME):\n" "When asked to 'Send patient summaries via email' or similar tasks, you MUST follow this exact flow in order. DO NOT parallelize these steps:\n" - " 1. First turn: Search for the patient. (If you have a Patient ID, you DO NOT need their name or birthdate).\n" - " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call the search tool with a guessed or hallucinated ID like '12345'.\n" + " 1. First turn: Obtain patient demographics from the EHR.\n" + " - If the user gave a Patient ID: call `fhir_cerner.read_patient` or `fhir_epic.read_patient` with JSON `{\"resource_id\": \"\"}` (use Epic when the ID starts with 'e'). Do NOT use search_patients for a known ID.\n" + " - If there is NO Patient ID but there IS a name: use name fields or `search_patients` per tools/list schema (e.g. 
`given_name`, `family_name`, `birthdate`, or valid `search_params`).\n" + " - Use `search_patients` only when you have no ID, or after `read_patient` failed and you need a fallback.\n" + " CRITICAL: If the user has NOT provided a patient ID or name in their message, you MUST ASK them for it. DO NOT call tools with a guessed or hallucinated ID like '12345'.\n" " 2. Second turn: Once you have the patient data from step 1, create a file on Google Drive containing the masked patient summary. Do NOT use placeholder content.\n" - " 3. Third turn: Once step 2 returns a 'web_view_link', send an email with that exact link. Do NOT call the email tool until you have the link.\n" + " For `google_drive.files.upload`, pass a flat JSON object: `name`, `mime_type` (snake_case — not `mimeType`), `parents`, and `content` (or `content_base64`). " + "If you include `action`, it must be exactly `files.upload`. Do not nest fields under a `file` object. Do NOT pass `media` / `media_body`.\n" + " 3. Third turn: Once step 2 returns a shareable Drive URL (see `data.raw.webViewLink` from tool `google_drive.files.upload`), send an email with that exact link. Do NOT call the email tool until you have the link.\n" " CRITICAL: You MUST ask the user for the recipient email address if they haven't provided it. DO NOT guess email addresses like 'recipient_email@example.com'.\n" " CRITICAL: In the email body, you MUST insert the actual URL string returned from step 2 (e.g. 'https://drive.google.com/...'). Do NOT literally write the text ''.\n\n" "DATA PRIVACY & MASKING — follow these strictly:\n" @@ -337,7 +419,7 @@ def __init__( " - NEVER use the placeholder values ('1990-05-12', '12724066', or 'Name') in your reports - always use the real patient data masked accordingly.\n" "- EMAIL WORKFLOW: When sending patient details to an email recipient:\n" " 1. ALWAYS upload the masked patient summary to Google Drive first.\n" - " 2. 
Use the 'web_view_link' returned by the google_drive_upload_file tool.\n" + " 2. Use `data.raw.webViewLink` from the `google_drive.files.upload` tool result.\n" " 3. In the email body, provide that link instead of the actual data.\n" " 4. The email body should be professional: 'Patient data summary from the EHR is available at the following secure link: [Link]'\n\n" "GUARDRAILS:\n" @@ -383,6 +465,9 @@ async def run(self, task: str) -> AgentRunResult: ] # 3. ReAct loop + tool_failures: Dict[str, int] = {} + abort_after_tool_failures = False + for step_num in range(1, self._max_steps + 1): logger.info("Agent step %d / %d", step_num, self._max_steps) @@ -428,12 +513,34 @@ async def run(self, task: str) -> AgentRunResult: agent_step.tool_result = tool_result_str result.steps.append(agent_step) + llm_tool_content = truncate_tool_result_for_llm(tool_result_str) + if len(llm_tool_content) < len(tool_result_str): + logger.info( + "Tool %s result truncated for LLM: %d -> %d chars", + tc.name, + len(tool_result_str), + len(llm_tool_content), + ) + messages.append(LLMMessage( role="tool", - content=tool_result_str, + content=llm_tool_content, tool_call_id=tc.id, name=tc.name, )) + + if _is_tool_failure(tool_result_str): + tool_failures[tc.name] = tool_failures.get(tc.name, 0) + 1 + if tool_failures[tc.name] >= self._max_tool_failures: + msg = _tool_failure_abort_message(tc.name, self._max_tool_failures) + result.error = msg + result.final_answer = msg + logger.warning("Stopping agent: %s", msg) + abort_after_tool_failures = True + break + + if abort_after_tool_failures: + break else: # Hit max_steps without a final answer result.error = f"Agent reached max_steps ({self._max_steps}) without completing the task." 
@@ -474,10 +581,20 @@ async def _run_agent(args: argparse.Namespace) -> None: # Use the client (handle async context for stdio) if isinstance(mcp_client_context, StdioMcpClient): async with mcp_client_context as mcp_client: - agent = ToolHiveAgent(mcp_client, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, "local-stdio") else: - agent = ToolHiveAgent(mcp_client_context, provider, max_steps=args.max_steps) + agent = ToolHiveAgent( + mcp_client_context, + provider, + max_steps=args.max_steps, + max_tool_failures=args.max_tool_failures, + ) await _execute_task(agent, args, llm_provider_name, ",".join(urls)) @@ -535,6 +652,12 @@ def main() -> None: parser.add_argument("--recipient-email", required=True, help="Email address to send the summary to") parser.add_argument("--drive-folder-id", default=os.environ.get("GOOGLE_DRIVE_FOLDER_ID", ""), help="Google Drive folder ID (optional)") parser.add_argument("--max-steps", type=int, default=10, help="Maximum agent steps (default: 10)") + parser.add_argument( + "--max-tool-failures", + type=int, + default=None, + help="Stop after this many failed calls per tool name (default: env TOOLHIVE_MAX_TOOL_FAILURES or 2)", + ) parser.add_argument("--local", action="store_true", help="Run against local server via stdio (no proxy)") args = parser.parse_args() diff --git a/src/bindings/mcp_server/server.py b/src/bindings/mcp_server/server.py index 5bbed57..29fb3a4 100644 --- a/src/bindings/mcp_server/server.py +++ b/src/bindings/mcp_server/server.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from bindings.factory import ConnectorFactory from connectors import auto_register @@ -12,17 +12,182 @@ logger = logging.getLogger("bindings.mcp_server") +def _split_ids(value: Any) -> List[str]: + """Turn 
comma-separated string or list into a list of non-empty IDs.""" + if value is None: + return [] + if isinstance(value, list): + return [str(x).strip() for x in value if str(x).strip()] + s = str(value).strip() + if not s: + return [] + return [p.strip() for p in s.split(",") if p.strip()] + + +def _normalize_search_params_keys(sp: Dict[str, Any]) -> Dict[str, Any]: + """Map legacy/LLM keys inside search_params to FHIR-friendly names.""" + if not sp: + return {} + out = dict(sp) + # patientId is not a standard FHIR Patient search param; identifier is typical for MRN-style lookup + if "patientId" in out and "identifier" not in out: + out["identifier"] = out.pop("patientId") + if "givenName" in out and "given" not in out: + out["given"] = out.pop("givenName") + if "familyName" in out and "family" not in out: + out["family"] = out.pop("familyName") + return out + + +def _is_missing_or_blank(value: Any) -> bool: + if value is None: + return True + if isinstance(value, str) and not value.strip(): + return True + return False + + +def _normalize_google_drive_files_upload(args: Dict[str, Any]) -> None: + """ + Map common LLM mistakes for files.upload to FilesUploadOperation fields. + Mutates args in place. Canonical keys already set on the root win over aliases/nesting. + """ + # Legacy alias: callers sometimes pass a `media` object/string (Google SDK-ish). + # Our connector schema is strict (extra=forbid); normalize `media` into canonical + # `content` (text) / `content_base64` (binary) + metadata, then drop it. 
+ media = args.get("media") + if media is not None: + # Metadata aliases under media + if isinstance(media, dict): + if _is_missing_or_blank(args.get("name")) and not _is_missing_or_blank( + media.get("name") + ): + args["name"] = media.get("name") + + if _is_missing_or_blank(args.get("mime_type")): + mt = media.get("mime_type") or media.get("mimeType") + if not _is_missing_or_blank(mt): + args["mime_type"] = mt + + if _is_missing_or_blank(args.get("parents")): + parents = media.get("parents") + if isinstance(parents, list) and parents: + args["parents"] = parents + elif isinstance(parents, str) and parents.strip(): + args["parents"] = _split_ids(parents) + + # Content aliases under media (prefer binary if provided) + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + b64 = ( + media.get("content_base64") + or media.get("base64") + or media.get("data") + ) + if not _is_missing_or_blank(b64): + args["content_base64"] = b64 + else: + text = media.get("content") or media.get("text") or media.get("body") + if not _is_missing_or_blank(text): + args["content"] = text + elif isinstance(media, str): + # Treat plain-string media as text content. + if _is_missing_or_blank(args.get("content_base64")) and _is_missing_or_blank( + args.get("content") + ): + if media.strip(): + args["content"] = media + + args.pop("media", None) + + # Some clients also try `media_body` (googleapiclient kwarg). It is never part of + # the MCP schema; drop it so canonical fields can validate. 
+ args.pop("media_body", None) + + nested = args.get("file") + if isinstance(nested, dict): + for key in ("name", "mime_type", "parents", "content", "content_base64"): + if key in nested and _is_missing_or_blank(args.get(key)): + args[key] = nested[key] + if _is_missing_or_blank(args.get("mime_type")) and nested.get("mimeType"): + args["mime_type"] = nested["mimeType"] + args.pop("file", None) + + if not _is_missing_or_blank(args.get("mimeType")) and _is_missing_or_blank( + args.get("mime_type") + ): + args["mime_type"] = args["mimeType"] + args.pop("mimeType", None) + + if args.get("action") == "upload": + args["action"] = "files.upload" + + +def normalize_mcp_tool_arguments( + connector_id: str, action: str, arguments: Dict[str, Any] +) -> Dict[str, Any]: + """ + Map legacy FastMCP / LLM aliases to canonical connector schema fields. + + Conservative: if canonical keys are already set, aliases are ignored. + """ + args = dict(arguments) + + if connector_id in ("fhir_cerner", "fhir_epic") and action == "read_patient": + if not (args.get("resource_id") or "").strip(): + pid = args.get("patient_id") or args.get("patientId") + if pid is not None and str(pid).strip(): + args["resource_id"] = str(pid).strip() + args.pop("patient_id", None) + args.pop("patientId", None) + if not args.get("family_name") and args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + elif connector_id in ("fhir_cerner", "fhir_epic") and action == "search_patients": + if not args.get("resource_ids"): + raw = args.get("patient_ids") or args.get("patientIds") + ids = _split_ids(raw) + if ids: + args["resource_ids"] = ids + args.pop("patient_ids", None) + args.pop("patientIds", None) + if not args.get("family_name") and 
args.get("familyName"): + args["family_name"] = args.pop("familyName") + if not args.get("given_name") and args.get("givenName"): + args["given_name"] = args.pop("givenName") + if args.get("search_params") and isinstance(args["search_params"], dict): + args["search_params"] = _normalize_search_params_keys(args["search_params"]) + + elif connector_id == "google_drive" and action == "files.upload": + _normalize_google_drive_files_upload(args) + + return args + + class McpServer: """ - Minimal MCP-style server abstraction for the POC. + Manifest-driven MCP server: tools come from connector metadata; execution + dispatches through ConnectorFactory and connector.run(). - This does not implement the full Model Context Protocol over JSON-RPC, - but exposes two conceptual operations: - - list_tools(): returns connector/actions manifest - - invoke_tool(name, arguments): executes the corresponding connector + Use list_tools() / invoke_tool() for programmatic access, or run_stdio() + for a full MCP stdio transport. 
""" - def __init__(self) -> None: + def __init__( + self, + *, + server_name: str = "node-wire", + connector_ids: Optional[List[str]] = None, + ) -> None: + self._server_name = server_name + self._connector_ids: Optional[frozenset[str]] = ( + None if connector_ids is None else frozenset(connector_ids) + ) auto_register() self._factory = ConnectorFactory() self._factory.load() @@ -32,11 +197,15 @@ def list_tools(self) -> List[Dict[str, Any]]: manifest = build_manifest(connectors) tools: List[Dict[str, Any]] = [] for entry in manifest: + cid = entry["connector_id"] + if self._connector_ids is not None and cid not in self._connector_ids: + continue tools.append( { - "name": f"{entry['connector_id']}.{entry['action']}", - "description": f"{entry['connector_id']} {entry['action']} connector action", + "name": f"{cid}.{entry['action']}", + "description": f"{cid} {entry['action']} connector action", "input_schema": entry["input_schema"], + "output_schema": entry["output_schema"], } ) return tools @@ -47,20 +216,63 @@ async def invoke_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, A except ValueError: raise ValueError("Tool name must be in the form '.'") + if self._connector_ids is not None and connector_id not in self._connector_ids: + raise ValueError( + f"Connector {connector_id!r} is not allowed on this MCP server." 
+ ) + connector = self._factory.get_for_protocol(connector_id, "mcp") if connector is None: raise ValueError(f"Connector {connector_id!r} is not available via MCP.") - run_args = dict(arguments) + run_args = normalize_mcp_tool_arguments(connector_id, action, arguments) if isinstance(connector, SDKConnector): run_args.setdefault("action", action) response = await connector.run(run_args) return response.model_dump() + async def _run_stdio_async(self) -> None: + from mcp.server import NotificationOptions, Server as LowLevelServer + from mcp.server.stdio import stdio_server + from mcp.types import Tool + + low = LowLevelServer(self._server_name) + + @low.list_tools() + async def handle_list_tools() -> list[Tool]: + out: list[Tool] = [] + for t in self.list_tools(): + kwargs: Dict[str, Any] = { + "name": t["name"], + "description": t["description"], + "inputSchema": t["input_schema"], + } + if t.get("output_schema") is not None: + kwargs["outputSchema"] = t["output_schema"] + out.append(Tool(**kwargs)) + return out + + @low.call_tool() + async def handle_call_tool(tool_name: str, arguments: dict) -> dict: + return await self.invoke_tool(tool_name, arguments or {}) + + async with stdio_server() as (read_stream, write_stream): + await low.run( + read_stream, + write_stream, + low.create_initialization_options( + notification_options=NotificationOptions() + ), + ) + + def run_stdio(self) -> None: + import anyio + + anyio.run(self._run_stdio_async) + if __name__ == "__main__": # Simple demo runner that prints tool list and exits. 
server = McpServer() print(json.dumps(server.list_tools(), indent=2)) - diff --git a/src/connectors/smtp/schema.py b/src/connectors/smtp/schema.py index 1698024..9724498 100644 --- a/src/connectors/smtp/schema.py +++ b/src/connectors/smtp/schema.py @@ -1,23 +1,86 @@ from __future__ import annotations -from typing import List, Optional +import os +import re +from typing import Any, List, Optional -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, model_validator + + +def _strip_env(s: str) -> str: + return s.strip(" '\"") + + +def _extract_email(value: str) -> str: + """Pydantic EmailStr does not accept 'Name '.""" + match = re.search(r"<(.+?)>", value) + return (match.group(1) if match else value).strip() class SmtpSendInput(BaseModel): - host: str - port: int + """ + SMTP send payload. Connection settings default from environment when omitted + so MCP/REST callers only need to, subject, body. + """ + + host: str = "" + port: int = 0 use_tls: bool = True - username_secret_key: str - password_secret_key: str + username_secret_key: str = "SMTP_USERNAME" + password_secret_key: str = "SMTP_PASSWORD" from_email: EmailStr to: List[EmailStr] subject: str body: str + @model_validator(mode="before") + @classmethod + def _fill_env_and_normalize(cls, values: Any) -> Any: + if not isinstance(values, dict): + return values + + if not (values.get("host") or "").strip(): + values["host"] = _strip_env(os.environ.get("SMTP_HOST", "smtp.gmail.com")) + port_raw = values.get("port") + if port_raw in (None, "", 0): + values["port"] = int(_strip_env(os.environ.get("SMTP_PORT", "587"))) + if "use_tls" not in values: + values["use_tls"] = ( + os.environ.get("SMTP_USE_TLS", "true").lower() == "true" + ) + if not values.get("username_secret_key"): + values["username_secret_key"] = "SMTP_USERNAME" + if not values.get("password_secret_key"): + values["password_secret_key"] = "SMTP_PASSWORD" + + fe = values.get("from_email") + if fe is None or not 
str(fe).strip(): + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + else: + values["from_email"] = _extract_email(_strip_env(str(fe))) + + # Guardrail: reject placeholder / invalid sender hints from callers + sender = str(values["from_email"]) + if not sender or "@" not in sender or "system_default" in sender: + values["from_email"] = _strip_env( + os.environ.get("FROM_EMAIL") + or os.environ.get("SMTP_USERNAME") + or "noreply@node-wire.local" + ) + + raw_to = values.get("to") + if isinstance(raw_to, str): + values["to"] = [_extract_email(raw_to)] + elif isinstance(raw_to, list): + values["to"] = [_extract_email(str(x)) for x in raw_to] + + return values + class SmtpSendOutput(BaseModel): sent: bool message_id: Optional[str] = None - diff --git a/tests/test_connectors_basic.py b/tests/test_connectors_basic.py index f5db5b7..527feb0 100644 --- a/tests/test_connectors_basic.py +++ b/tests/test_connectors_basic.py @@ -15,7 +15,7 @@ class DummySecretProvider(SecretProvider): def __init__(self) -> None: - self._store = {"stripe_api_key": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} + self._store = {"STRIPE_API_KEY": "sk_test_dummy", "smtp_user": "user", "smtp_pass": "pass"} def get_secret(self, key: str) -> str: return self._store[key] diff --git a/tests/test_sdk_connector_manifest.py b/tests/test_sdk_connector_manifest.py index 504d5a1..c637f60 100644 --- a/tests/test_sdk_connector_manifest.py +++ b/tests/test_sdk_connector_manifest.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from bindings.factory import ConnectorFactory from connectors import auto_register from connectors.manifest import build_manifest @@ -56,3 +58,339 @@ def test_mcp_tool_invoke_sets_action(): names = {t["name"] for t in tools} assert "google_drive.files.list" in names assert "stripe.charge" in names + + +def test_mcp_server_list_tools_includes_output_schema(): + from 
bindings.mcp_server.server import McpServer + + server = McpServer() + tools = server.list_tools() + assert tools + assert all("output_schema" in t for t in tools) + + +def test_mcp_server_connector_ids_filters_list_tools(): + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["fhir_cerner"]) + names = {t["name"] for t in server.list_tools()} + assert names + assert all(n.startswith("fhir_cerner.") for n in names) + assert "fhir_epic.read_patient" not in names + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_rejects_disallowed_connector() -> None: + from bindings.mcp_server.server import McpServer + + server = McpServer(connector_ids=["google_drive"]) + with pytest.raises(ValueError, match="not allowed"): + await server.invoke_tool( + "smtp.send_email", + {"to": ["doc@example.com"], "subject": "x", "body": "y"}, + ) + + +def test_mcp_server_run_stdio_smoke(): + from bindings.mcp_server.server import McpServer + + server = McpServer() + assert callable(server.run_stdio) + assert callable(server._run_stdio_async) + + +def test_normalize_mcp_tool_arguments_read_patient_maps_legacy_ids(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.fhir_cerner.schema import FhirCernerPatientReadInput + from connectors.fhir_epic.schema import FhirPatientReadInput as FhirEpicPatientReadInput + + for cid in ("fhir_cerner", "fhir_epic"): + out = normalize_mcp_tool_arguments( + cid, + "read_patient", + {"patientId": "12724066"}, + ) + assert out["resource_id"] == "12724066" + assert "patientId" not in out + model = FhirCernerPatientReadInput if cid == "fhir_cerner" else FhirEpicPatientReadInput + model.model_validate({**out, "action": "read_patient"}) + + # Canonical resource_id wins over alias + out2 = normalize_mcp_tool_arguments( + "fhir_cerner", + "read_patient", + {"resource_id": "111", "patient_id": "222"}, + ) + assert out2["resource_id"] == "111" + + out3 = normalize_mcp_tool_arguments( + 
"fhir_cerner", + "read_patient", + {"familyName": "Smith", "givenName": "John"}, + ) + assert out3["family_name"] == "Smith" + assert out3["given_name"] == "John" + + +def test_normalize_mcp_tool_arguments_search_patients_maps_legacy(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.fhir_cerner.schema import FhirCernerPatientSearchInput + + out = normalize_mcp_tool_arguments( + "fhir_cerner", + "search_patients", + {"patient_ids": "12724066,12724067"}, + ) + assert out["resource_ids"] == ["12724066", "12724067"] + + out2 = normalize_mcp_tool_arguments( + "fhir_cerner", + "search_patients", + {"search_params": {"patientId": "12724066"}}, + ) + assert out2["search_params"]["identifier"] == "12724066" + assert "patientId" not in out2["search_params"] + + FhirCernerPatientSearchInput.model_validate( + {**out2, "action": "search_patients"} + ) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_mime_type_alias(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mimeType": "text/plain", + "parents": ["folder1"], + "content": "hello", + }, + ) + assert out["mime_type"] == "text/plain" + assert "mimeType" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_action_upload(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "action": "upload", + "name": "a.txt", + "mime_type": "text/plain", + "content": "x", + }, + ) + assert out["action"] == "files.upload" + FilesUploadOperation.model_validate(out) + + +def 
test_normalize_mcp_tool_arguments_google_drive_files_upload_nested_file(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "content": "body", + "file": { + "mime_type": "text/plain", + "name": "nested.txt", + "parents": ["p1"], + }, + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1"] + assert "file" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_string_maps_to_content(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": "hello", + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_text_alias_maps_to_content(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "a.txt", + "mime_type": "text/plain", + "media": {"text": "hello"}, + }, + ) + assert out["content"] == "hello" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_object_base64_maps_to_content_base64(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + 
"google_drive", + "files.upload", + { + "name": "a.pdf", + "mime_type": "application/pdf", + "media": {"base64": "Zg=="}, + }, + ) + assert out["content_base64"] == "Zg==" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_media_metadata_aliases_are_used_when_missing(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "media": { + "name": "nested.txt", + "mimeType": "text/plain", + "parents": "p1,p2", + "content": "hi", + } + }, + ) + assert out["name"] == "nested.txt" + assert out["mime_type"] == "text/plain" + assert out["parents"] == ["p1", "p2"] + assert out["content"] == "hi" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_files_upload_canonical_content_wins_over_media_alias(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + from connectors.google_drive.schema import FilesUploadOperation + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "name": "root.txt", + "mime_type": "text/plain", + "content": "root", + "media": {"content": "ignored"}, + }, + ) + assert out["content"] == "root" + assert "media" not in out + FilesUploadOperation.model_validate({**out, "action": "files.upload"}) + + +def test_normalize_mcp_tool_arguments_google_drive_canonical_mime_type_wins_over_nested(): + from bindings.mcp_server.server import normalize_mcp_tool_arguments + + out = normalize_mcp_tool_arguments( + "google_drive", + "files.upload", + { + "mime_type": "text/plain", + "name": "root.txt", + "content": "c", + "file": {"mime_type": "application/json", "name": "ignored.txt"}, + }, + ) + assert out["mime_type"] == "text/plain" + assert 
out["name"] == "root.txt" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_tool_passes_normalized_payload_to_connector_run() -> None: + """invoke_tool should apply normalization before BaseConnector.run (SDK action).""" + from bindings.mcp_server.server import McpServer + from runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["fhir_cerner"]) + cerner = server._factory.get_for_protocol("fhir_cerner", "mcp") + assert cerner is not None + + captured: dict = {} + + async def fake_run(raw_input): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"resource": {"id": "12724066"}}, trace_id="t") + + orig_run = cerner.run + try: + cerner.run = fake_run + await server.invoke_tool("fhir_cerner.read_patient", {"patientId": "12724066"}) + finally: + cerner.run = orig_run + + assert captured["payload"]["resource_id"] == "12724066" + assert captured["payload"].get("action") == "read_patient" + + +@pytest.mark.asyncio +async def test_mcp_server_invoke_google_drive_files_upload_normalizes_payload() -> None: + """invoke_tool should normalize Drive upload aliases before connector.run.""" + from bindings.mcp_server.server import McpServer + from runtime.models import ConnectorResponse + + server = McpServer(connector_ids=["google_drive"]) + gdrive = server._factory.get_for_protocol("google_drive", "mcp") + assert gdrive is not None + + captured: dict = {} + + async def fake_run(raw_input): + captured["payload"] = dict(raw_input) + return ConnectorResponse(success=True, data={"raw": {}}, trace_id="t") + + orig_run = gdrive.run + try: + gdrive.run = fake_run + await server.invoke_tool( + "google_drive.files.upload", + { + "mimeType": "text/plain", + "name": "patient_summary.txt", + "parents": ["folder-id"], + "content": "summary", + "media": {"content": "ignored"}, + "action": "upload", + }, + ) + finally: + gdrive.run = orig_run + + assert captured["payload"]["mime_type"] == "text/plain" + assert 
captured["payload"]["action"] == "files.upload" + assert "mimeType" not in captured["payload"] + assert "media" not in captured["payload"] diff --git a/tests/test_toolhive_agent.py b/tests/test_toolhive_agent.py index 50aa61a..8cd10a8 100644 --- a/tests/test_toolhive_agent.py +++ b/tests/test_toolhive_agent.py @@ -22,7 +22,45 @@ LLMResponse, ToolCall, ) -from agents.toolhive import AgentRunResult, ToolHiveAgent, ToolHiveMcpClient +from agents.toolhive import ( + AgentRunResult, + ToolHiveAgent, + ToolHiveMcpClient, + _is_tool_failure, + resolve_max_tool_failures, + truncate_tool_result_for_llm, +) + + +def test_truncate_tool_result_for_llm_respects_limit(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "20") + long = "x" * 100 + out = truncate_tool_result_for_llm(long) + assert len(out) > 20 + assert out.startswith("x" * 20) + assert "truncated" in out + + +def test_truncate_tool_result_for_llm_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_RESULT_CHARS", "0") + long = "y" * 5000 + assert truncate_tool_result_for_llm(long) == long + + +def test_is_tool_failure_detects_validation_and_error_prefix() -> None: + assert _is_tool_failure("Input validation error: bad") + assert _is_tool_failure("ERROR: connection refused") + assert _is_tool_failure('{"success": false, "message": "x"}') + assert not _is_tool_failure("") + assert not _is_tool_failure('{"success": true, "data": {}}') + + +def test_resolve_max_tool_failures_env_and_override(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TOOLHIVE_MAX_TOOL_FAILURES", raising=False) + assert resolve_max_tool_failures(None) == 2 + monkeypatch.setenv("TOOLHIVE_MAX_TOOL_FAILURES", "5") + assert resolve_max_tool_failures(None) == 5 + assert resolve_max_tool_failures(3) == 3 # --------------------------------------------------------------------------- @@ -31,37 +69,38 @@ SAMPLE_TOOLS = [ { - "name": "fhir_cerner_read_patient", 
+ "name": "fhir_cerner.read_patient", "description": "Fetch a patient from Cerner FHIR", "input_schema": { "type": "object", - "properties": {"patient_id": {"type": "string"}}, - "required": ["patient_id"], + "properties": {"resource_id": {"type": "string"}}, + "required": ["resource_id"], }, }, { - "name": "google_drive_upload_file", + "name": "google_drive.files.upload", "description": "Upload a file to Google Drive", "input_schema": { "type": "object", "properties": { - "file_name": {"type": "string"}, + "name": {"type": "string"}, + "mime_type": {"type": "string"}, "content": {"type": "string"}, }, - "required": ["file_name", "content"], + "required": ["name", "mime_type", "content"], }, }, { - "name": "smtp_send_email", + "name": "smtp.send_email", "description": "Send an email via SMTP", "input_schema": { "type": "object", "properties": { - "to_email": {"type": "string"}, + "to": {"type": "array", "items": {"type": "string"}}, "subject": {"type": "string"}, "body": {"type": "string"}, }, - "required": ["to_email", "subject", "body"], + "required": ["to", "subject", "body"], }, }, ] @@ -142,19 +181,37 @@ async def test_agent_runs_three_tool_sequence() -> None: # Step 1: Call FHIR LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "12724066"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], stop_reason="tool_calls", ), # Step 2: Call Drive LLMResponse( content=None, - tool_calls=[_tool_call("google_drive_upload_file", {"file_name": "summary.txt", "content": "Patient: John"})], + tool_calls=[ + _tool_call( + "google_drive.files.upload", + { + "name": "summary.txt", + "mime_type": "text/plain", + "content": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Step 3: Send email LLMResponse( content=None, - tool_calls=[_tool_call("smtp_send_email", {"to_email": "doc@example.com", "subject": "Summary", "body": "Patient: John"})], + tool_calls=[ + _tool_call( + 
"smtp.send_email", + { + "to": ["doc@example.com"], + "subject": "Summary", + "body": "Patient: John", + }, + ) + ], stop_reason="tool_calls", ), # Final answer @@ -173,21 +230,47 @@ async def test_agent_runs_three_tool_sequence() -> None: assert result.success is True assert result.final_answer == "All 3 steps completed successfully." assert len(result.steps) == 3 - assert result.steps[0].tool_called == "fhir_cerner_read_patient" - assert result.steps[1].tool_called == "google_drive_upload_file" - assert result.steps[2].tool_called == "smtp_send_email" + assert result.steps[0].tool_called == "fhir_cerner.read_patient" + assert result.steps[1].tool_called == "google_drive.files.upload" + assert result.steps[2].tool_called == "smtp.send_email" # Verify MCP was called exactly 3 times assert mock_mcp.call_tool.await_count == 3 +@pytest.mark.asyncio +async def test_agent_id_first_turn_calls_read_patient_with_resource_id() -> None: + """Document ID-first flow: Cerner read uses canonical resource_id (not search_patients).""" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "12724066"})], + stop_reason="tool_calls", + ), + LLMResponse(content="Patient retrieved.", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = '{"success": true}' + + agent = ToolHiveAgent(mcp_client=mock_mcp, llm_provider=provider, max_steps=10) + result = await agent.run("Patient ID 12724066 — fetch from Cerner") + + assert result.success is True + mock_mcp.call_tool.assert_awaited_once() + call = mock_mcp.call_tool.await_args + assert call[0][0] == "fhir_cerner.read_patient" + assert call[0][1]["resource_id"] == "12724066" + + @pytest.mark.asyncio async def test_agent_respects_max_steps() -> None: """Agent should stop and return an error if max_steps is reached.""" # LLM 
always returns a tool call — never finishes infinite_response = LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "x"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "x"})], stop_reason="tool_calls", ) provider = _MockLLMProvider([infinite_response]) @@ -211,7 +294,7 @@ async def test_agent_handles_tool_error_gracefully() -> None: responses = [ LLMResponse( content=None, - tool_calls=[_tool_call("fhir_cerner_read_patient", {"patient_id": "bad"})], + tool_calls=[_tool_call("fhir_cerner.read_patient", {"resource_id": "bad"})], stop_reason="tool_calls", ), LLMResponse(content="Unable to fetch patient — error recorded.", tool_calls=[], stop_reason="stop"), @@ -244,100 +327,150 @@ async def test_agent_fails_when_mcp_unreachable() -> None: assert "Failed to list MCP tools" in (result.error or "") +@pytest.mark.asyncio +async def test_agent_stops_after_repeated_tool_failures() -> None: + """After max_tool_failures for the same tool, stop without further LLM steps.""" + fail_msg = "Input validation error: bad args" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {"name": "a.txt"})], + stop_reason="tool_calls", + ), + LLMResponse(content="should not run", tool_calls=[], stop_reason="stop"), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.return_value = fail_msg + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("Upload to Drive") + + assert result.success is False + assert len(result.steps) == 2 + assert "google_drive.files.upload" in (result.error or "") + assert "failed 2 times" in (result.final_answer or 
result.error or "").lower() + assert mock_mcp.call_tool.await_count == 2 + assert provider._call_count == 2 + + +@pytest.mark.asyncio +async def test_agent_success_then_two_failures_same_tool_aborts() -> None: + """Failures only increment on failed tool results; abort after second failure.""" + ok = '{"success": true, "data": {}}' + fail_msg = "Input validation error: x" + responses = [ + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + LLMResponse( + content=None, + tool_calls=[_tool_call("google_drive.files.upload", {})], + stop_reason="tool_calls", + ), + ] + provider = _MockLLMProvider(responses) + mock_mcp = AsyncMock(spec=ToolHiveMcpClient) + mock_mcp.list_tools.return_value = SAMPLE_TOOLS + mock_mcp.call_tool.side_effect = [ok, fail_msg, fail_msg] + + agent = ToolHiveAgent( + mcp_client=mock_mcp, + llm_provider=provider, + max_steps=10, + max_tool_failures=2, + ) + result = await agent.run("x") + + assert result.success is False + assert len(result.steps) == 3 + assert mock_mcp.call_tool.await_count == 3 + + # --------------------------------------------------------------------------- # MCP entrypoint smoke test # --------------------------------------------------------------------------- -def test_mcp_entrypoint_registers_eight_tools() -> None: - """The FastMCP server should expose the full FHIR + integration tool surface.""" - # We patch all external deps before importing the module to avoid side effects - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = { - "fhir_cerner": MagicMock(), - "fhir_epic": MagicMock(), - "google_drive": MagicMock(), - "smtp": MagicMock(), 
- } - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn # decorator passthrough - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - # Import inside the test to ensure it picks up the mocks - from agents.mcp_entrypoint import _make_server - _make_server() - - assert len(registered_tools) == 8 - assert "fhir_cerner_read_patient" in registered_tools - assert "fhir_cerner_search_patients" in registered_tools - assert "fhir_cerner_search_encounters" in registered_tools - assert "fhir_epic_read_patient" in registered_tools - assert "fhir_epic_search_patients" in registered_tools - assert "fhir_epic_search_encounters" in registered_tools - assert "google_drive_upload_file" in registered_tools - assert "smtp_send_email" in registered_tools +def test_mcp_entrypoint_exposes_manifest_tools() -> None: + """Unified MCP server lists all connectors enabled for MCP in config.""" + from bindings.mcp_server.server import McpServer + + server = McpServer(server_name="node-wire") + names = {t["name"] for t in server.list_tools()} + assert "fhir_cerner.read_patient" in names + assert "fhir_epic.read_patient" in names + assert "google_drive.files.upload" in names + assert "smtp.send_email" in names + assert "stripe.charge" in names + assert "http_generic.request" in names + # Broader surface than the old 8 FastMCP tools + assert len(names) >= 18 # --------------------------------------------------------------------------- -# Individual MCP server smoke tests +# Individual MCP entrypoint modules (thin wrappers) # --------------------------------------------------------------------------- -def _make_server_smoke(module_path: str, expected_tool: str) -> None: - """Helper: verify a per-connector _make_server() registers 
exactly one tool.""" - with ( - patch("bindings.factory.ConnectorFactory") as mock_factory_cls, - patch("connectors.auto_register"), - patch("mcp.server.fastmcp.FastMCP", autospec=False) as mock_fastmcp_cls, - ): - mock_factory = MagicMock() - mock_factory._connectors = {} - mock_factory_cls.return_value = mock_factory - - mock_mcp_instance = MagicMock() - registered_tools: List[str] = [] - - def fake_tool(*args: Any, **kwargs: Any): - name = kwargs.get("name") or (args[0] if args else "unknown") - registered_tools.append(name) - return lambda fn: fn - - mock_mcp_instance.tool = fake_tool - mock_fastmcp_cls.return_value = mock_mcp_instance - - import importlib - mod = importlib.import_module(module_path) - mod._make_server() - - assert registered_tools == [expected_tool], ( - f"{module_path}: expected [{expected_tool}], got {registered_tools}" - ) +def test_fhir_cerner_mcp_main_callable() -> None: + from agents.fhir_cerner_mcp import main + + assert callable(main) + + +def test_fhir_epic_mcp_main_callable() -> None: + from agents.fhir_epic_mcp import main + + assert callable(main) + + +def test_google_drive_mcp_main_callable() -> None: + from agents.google_drive_mcp import main + + assert callable(main) + + +def test_smtp_mcp_main_callable() -> None: + from agents.smtp_mcp import main + + assert callable(main) + + +def test_mcp_server_matches_per_connector_entrypoints() -> None: + """Per-connector scripts use connector_ids filter; tool prefixes must match.""" + from bindings.mcp_server.server import McpServer -def test_fhir_cerner_mcp_registers_one_tool() -> None: - """fhir_cerner_mcp._make_server() should expose exactly fhir_cerner_read_patient.""" - _make_server_smoke("agents.fhir_cerner_mcp", "fhir_cerner_read_patient") + full = {t["name"] for t in McpServer().list_tools()} + cerner = {t["name"] for t in McpServer(connector_ids=["fhir_cerner"]).list_tools()} + assert cerner == {n for n in full if n.startswith("fhir_cerner.")} -def 
test_fhir_epic_mcp_registers_one_tool() -> None: - """fhir_epic_mcp._make_server() should expose exactly fhir_epic_read_patient.""" - _make_server_smoke("agents.fhir_epic_mcp", "fhir_epic_read_patient") + epic = {t["name"] for t in McpServer(connector_ids=["fhir_epic"]).list_tools()} + assert epic == {n for n in full if n.startswith("fhir_epic.")} + drive = {t["name"] for t in McpServer(connector_ids=["google_drive"]).list_tools()} + assert drive == {n for n in full if n.startswith("google_drive.")} + assert "google_drive.files.upload" in drive -def test_google_drive_mcp_registers_one_tool() -> None: - """google_drive_mcp._make_server() should expose exactly google_drive_upload_file.""" - _make_server_smoke("agents.google_drive_mcp", "google_drive_upload_file") + smtp = {t["name"] for t in McpServer(connector_ids=["smtp"]).list_tools()} + assert smtp == {"smtp.send_email"} From 88de5424d988afd6fd5e209535872730d9e2b19f Mon Sep 17 00:00:00 2001 From: vinaayakh-aot <61819385+vinaayakh-aot@users.noreply.github.com> Date: Wed, 1 Apr 2026 04:05:15 -0700 Subject: [PATCH 08/15] Updated Docs and UI --- README.md | 12 ++- Setup.md | 38 +++++++++- docs/creating-a-connector.md | 128 ++++++++++++++++++++++++++++++++ docs/toolhive_agent_scenario.md | 9 +-- playground/README.md | 18 +++-- playground/index.html | 6 +- 6 files changed, 191 insertions(+), 20 deletions(-) create mode 100644 docs/creating-a-connector.md diff --git a/README.md b/README.md index 66d68ab..8eadd65 100644 --- a/README.md +++ b/README.md @@ -123,14 +123,14 @@ Examples: Google Drive has a full doc at `src/connectors/google_drive/README.md` ### gRPC / MCP - **gRPC:** Started when `MODE=GRPC`; server listens on port 50051. -- **MCP:** Started when `MODE=MCP`; server exposes tools for discovery and invocation. +- **MCP:** `MODE=MCP` starts a minimal MCP-style placeholder server (sufficient for local, manual inspection), but it is not the full stdio MCP server used for ToolHive and the agent layer. 
### Entrypoint - Run with `python -m bindings_entrypoint` (or the `node-wire` script after install). The **MODE** environment variable selects: - **API** (default) – REST API on port 8000. - **GRPC** – gRPC server on port 50051. - - **MCP** – MCP server. + - **MCP** – minimal MCP-style placeholder server (see note above). --- @@ -193,3 +193,11 @@ $env:GOOGLE_DRIVE_SA_JSON = Get-Content -Path $saPath -Raw ## Dependencies All dependencies are declared in `pyproject.toml` (Python >=3.11). They include: pydantic, FastAPI, uvicorn, tenacity, pybreaker, OpenTelemetry, grpcio, and connector-specific libraries (httpx, aiosmtplib, stripe, google-auth, google-api-python-client, etc.). See `pyproject.toml` for the full list and versions. + +--- + +## Setup and development docs + +- Platform setup (REST/gRPC/agents MCP): [Setup.md](Setup.md) +- Individual connector MCP servers (ToolHive): [docs/mcp-servers.md](docs/mcp-servers.md) +- Creating a new connector: [docs/creating-a-connector.md](docs/creating-a-connector.md) diff --git a/Setup.md b/Setup.md index 7be559f..069d512 100644 --- a/Setup.md +++ b/Setup.md @@ -23,10 +23,11 @@ Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP | Requirement | Version | Notes | | ----------- | ------- | --------------------------------------- | -| Python | 3.12+ | `python --version` to check | +| Python | 3.11+ | `python --version` to check | | pip or uv | Latest | `pip install --upgrade pip` | | Git | Any | To clone the repo | | Docker | Latest | Only needed for ToolHive MCP deployment | +| Node.js | Any LTS | Only needed for `npx @modelcontextprotocol/inspector` | --- @@ -36,7 +37,7 @@ Node Wire is a Python framework that runs connector adapters (Google Drive, SMTP ```bash # 1. Clone the repository git clone -cd connector-platform +cd node-wire # 2. 
Install dependencies (recommended: uv) uv sync --extra agents @@ -45,6 +46,8 @@ uv sync --extra agents uv run node-wire --help ``` +> **Install uv:** See the official installer docs at `https://docs.astral.sh/uv/`. +> > **REST/gRPC only** (no AI agent features): `uv sync` without the extra is sufficient. > > **Alternative (pip):** If you’re not using `uv`, install editable deps with pip: @@ -67,6 +70,8 @@ cp sample.env .env You only need to fill in the sections for the connectors you plan to use. The platform starts successfully even if some credentials are missing — those connectors will simply return an error when called. +> **Doc convention:** Environment variable names in the docs follow `sample.env`. Some legacy keys (like `stripe_api_key`) are intentionally lower-case because that is what the connector reads. + ### Environment Variable Sections @@ -74,7 +79,7 @@ You only need to fill in the sections for the connectors you plan to use. The pl | ---------------- | ------------------------------------------------------------------------------------------------------------------- | ---------------------- | | **FHIR Epic** | `EPIC_FHIR_BASE_URL`, `EPIC_TOKEN_URL`, `EPIC_CLIENT_ID`, `EPIC_KID`, `EPIC_PRIVATE_KEY` | Epic EHR integration | | **FHIR Cerner** | `CERNER_FHIR_BASE_URL`, `CERNER_TOKEN_URL`, `CERNER_CLIENT_ID`, `CERNER_KID`, `CERNER_PRIVATE_KEY`, `CERNER_SCOPES` | Cerner EHR integration | -| **Google Drive** | `google_drive_sa_json`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | +| **Google Drive** | `GOOGLE_DRIVE_SA_JSON`, `GOOGLE_DRIVE_FOLDER_ID` | Google Drive connector | | **SMTP** | `SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD` | Sending emails | | **LLM / Agent** | `LLM_PROVIDER`, `GROQ_API_KEY` (or other provider key) | AI agent / ToolHive | | **ToolHive** | `TOOLHIVE_MCP_URL` (single) or `TOOLHIVE_MCP_URLS` (comma-separated, multi-server) | ToolHive MCP proxy | @@ -95,6 +100,19 @@ The platform supports three modes. 
Set the `MODE` environment variable to switch | **gRPC** | `MODE=GRPC uv run node-wire` | `50051` | gRPC clients | | **MCP (stdio)** | `python -m agents.mcp_entrypoint` | stdio | AI agents, ToolHive, Claude Desktop | +> **Important:** `MODE=MCP` for `node-wire` / `python -m bindings_entrypoint` starts a minimal MCP-style placeholder server, not the full stdio MCP server used with ToolHive and the agent layer. For ToolHive/Inspector/agents, use `python -m agents.mcp_entrypoint` (or the per-connector MCP servers in `docs/mcp-servers.md`). + +### Configuration file (`config/connectors.yaml`) + +Connectors are loaded from `config/connectors.yaml`. Each connector has: + +- `enabled`: whether the connector is instantiated at startup +- `exposed_via`: which protocols can access it (`rest`, `grpc`, `mcp`) + +If a connector is disabled (or not exposed for a protocol), requests to it will fail with “not configured / not available” even if your `.env` is correct. + +For details on adding a new connector to the runtime, see `docs/creating-a-connector.md`. + ### REST API Quick Start @@ -106,6 +124,12 @@ uv run node-wire PORT=8001 uv run node-wire ``` +Equivalent entrypoint (without `uv`): + +```bash +MODE=API python -m bindings_entrypoint +``` + Once running: - **Health check:** `GET http://localhost:8000/health` @@ -216,7 +240,7 @@ Quick summary of what you'll need: Add to your `.env`: ```env -google_drive_sa_json=/absolute/path/to/service-account.json +GOOGLE_DRIVE_SA_JSON=/absolute/path/to/service-account.json GOOGLE_DRIVE_FOLDER_ID=your-folder-id-from-drive-url ``` @@ -315,6 +339,12 @@ npx @modelcontextprotocol/inspector python -m agents.google_drive_mcp npx @modelcontextprotocol/inspector python -m agents.mcp_entrypoint ``` +### Troubleshooting quick hits + +- **Port 8000 in use**: set `PORT=8001` (or any free port) when starting the REST API. +- **Connector “not configured”**: confirm it is `enabled: true` (and exposed for your protocol) in `config/connectors.yaml`. 
+- **ToolHive + Google Drive auth failure**: inside ToolHive, `GOOGLE_DRIVE_SA_JSON` must be the JSON **contents** (not a file path). Locally, it can be an absolute file path (see `docs/mcp-servers.md`). + --- ## Running Tests diff --git a/docs/creating-a-connector.md b/docs/creating-a-connector.md new file mode 100644 index 0000000..98f1bce --- /dev/null +++ b/docs/creating-a-connector.md @@ -0,0 +1,128 @@ +# Creating a connector in Node Wire + +This guide explains how to implement a new connector (Layer B) and make it available via REST/gRPC/MCP (Layer C). + +## How connectors plug into the platform + +- **Layer B (`src/connectors/`)**: connector implementations (schemas + logic). +- **Layer C (`src/bindings/`)**: protocol bindings and configuration-driven loading. + +At startup, the REST binding: + +- Imports connector `registration.py` modules via `connectors.auto_register()` so exceptions can be mapped. +- Loads and instantiates enabled connectors via `ConnectorFactory` using `config/connectors.yaml`. + +## Connector shape (single-action) + +Most connectors are a single `BaseConnector` subclass with: + +- `connector_id`: stable identifier (used in URLs and config) +- `action`: the action name (used in URLs and manifests) +- `schema.py`: Pydantic input/output models +- `logic.py`: connector implementation (`internal_execute`) +- `registration.py` (optional): register exception mappings for runtime error taxonomy + +Use the `http_generic` connector as a reference: + +- `src/connectors/http_generic/schema.py` +- `src/connectors/http_generic/logic.py` +- `src/connectors/http_generic/registration.py` (if present) + +### Minimal checklist (single-action) + +1. Create a new package: `src/connectors/{connector_id}/`. +2. Define request/response models in `schema.py`. +3. Implement `logic.py` with a `BaseConnector[...]` subclass: + - set `connector_id = "{connector_id}"` + - set `action = "{action}"` + - implement `internal_execute(self, params, *, trace_id)` +4.
If you raise connector-specific exceptions, add `registration.py` and register them with the runtime `ErrorMapper` so clients get stable `error_code`/`error_category`. + +## Connector shape (multi-action) + +Some connectors expose multiple actions from a single logical integration (e.g. FHIR). In that pattern, the factory stores **one** object under a `connector_id`, and that object exposes: + +- `list_actions() -> list[BaseConnector]` +- `get_action(name: str) -> BaseConnector | None` + +The factory uses these helpers for discovery and dispatch. + +See the Epic FHIR connector implementation for the pattern: + +- `src/connectors/fhir_epic/logic.py` (`FhirEpicConnector`, `_FhirAction`, `list_actions()`, `get_action()`) + +## Wire the connector into the runtime (required) + +There are two places to update so the platform can load and expose your connector. + +### 1) Add an entry to `config/connectors.yaml` + +Add a new block under `connectors:`: + +```yaml +connectors: + my_connector: + enabled: true + exposed_via: ["rest", "grpc", "mcp"] +``` + +- `enabled: true` controls whether the connector is instantiated. +- `exposed_via` controls which protocols can see it. + +If `enabled` is false, or if a protocol is missing from `exposed_via`, you will see “not configured / not available” errors even if your `.env` is correct. + +### 2) Add factory wiring in `src/bindings/factory.py` + +`ConnectorFactory` instantiates connectors via `_instantiate(connector_id)`. Add a branch for your `connector_id` that returns your connector instance and passes the `secret_provider`. + +For single-action connectors, the factory typically passes the input/output model classes too (example: `http_generic`, `google_drive`). + +For multi-action connectors, the factory stores one instance (example: `fhir_epic`, `fhir_cerner`), and `get_for_protocol()` uses `get_action()` when an action is requested. 
+ +## Registration (`registration.py`) + +`connectors.auto_register()` imports `registration` modules from connector subpackages automatically: + +- A connector package may omit `registration.py` if it doesn’t need custom exception mapping. +- If present, `registration.py` should register exception types with the runtime error taxonomy so clients get predictable categories (`BUSINESS`, `AUTH`, `RETRYABLE`, `FATAL`). + +See `src/connectors/__init__.py` for the auto-discovery behavior. + +## Secrets and configuration conventions + +- Connector secrets are read via the `SecretProvider` (`self.secret_provider.get_secret("KEY")`). +- For local development, secrets are typically defined in `.env` using the names in `sample.env`. +- The platform’s `EnvSecretProvider` is case-insensitive (it checks both `KEY` and `key`), but prefer **one canonical spelling** in documentation and config. + +## How exposure works per protocol + +The REST binding exposes: + +- `POST /connectors/{connector_id}/{action}` + +Routes and schemas come from the connector manifest built over the factory’s `list_for_protocol("rest")` output. + +The same `enabled` / `exposed_via` gating applies to gRPC and the built-in MCP-style manifest. + +## Optional: MCP tools for ToolHive / agents + +This repository also includes MCP servers under `src/agents/` (for ToolHive and other MCP clients). These are separate from the REST/gRPC bindings: + +- **Combined MCP server**: `python -m agents.mcp_entrypoint` +- **Per-connector MCP servers**: `python -m agents.<connector>_mcp` (see `docs/mcp-servers.md`) + +Adding a connector to the runtime (factory + YAML) does not automatically create a ToolHive-ready MCP server. If you need MCP tools, you typically add a small wrapper in `src/agents/` that calls into the connector via `ConnectorFactory`. 
+ +## Loading flow (simplified) + +```mermaid +flowchart LR + yamlFile[config/connectors.yaml] + factory[ConnectorFactory.load] + instantiate[_instantiate connector_id] + connector[Connector_instance] + yamlFile --> factory + factory --> instantiate + instantiate --> connector +``` + diff --git a/docs/toolhive_agent_scenario.md b/docs/toolhive_agent_scenario.md index 666c39b..40fea59 100644 --- a/docs/toolhive_agent_scenario.md +++ b/docs/toolhive_agent_scenario.md @@ -62,6 +62,7 @@ For modular deployments, each connector can be run as an independent MCP server - `nw-google-drive` (Google Drive) - `nw-smartonfhir-epic` (Epic SMART on FHIR) - `nw-smartonfhir-cerner` (Cerner SMART on FHIR) +- `nw-smtp` (SMTP email) When running multiple MCP servers, configure the agent with **`TOOLHIVE_MCP_URLS`** (comma-separated list of ToolHive proxy URLs). The agent will merge tools across servers. @@ -135,7 +136,7 @@ Below is the full set of environment variables used by the connector platform an | `GROQ_API_KEY` | LLM (Groq) | Your Groq API key | | `GROQ_MODEL` | LLM | Example: `openai/gpt-oss-120b` | | `MCP_TRANSPORT` | ToolHive / local | `stdio` when running in ToolHive container | -| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `d:\connector-platform\src` locally | +| `PYTHONPATH` | Runtime | e.g. `/app/src` for container; `**/node-wire/src` locally | | `SMTP_HOST` | SMTP connector | Example: `sandbox.smtp.mailtrap.io` | | `SMTP_PORT` | SMTP connector | Example: `2525` | | `SMTP_USERNAME` | SMTP connector | Mailtrap / SMTP user | @@ -160,7 +161,7 @@ Option A — Recommended: ToolHive UI (no code) Option B — Local quick run (Windows PowerShell) -Prerequisite: Install Python 3.10+ and Git. If you cannot install, ask an administrator to run Option A. +Prerequisite: Install Python 3.11+ and Git. If you cannot install, ask an administrator to run Option A. 1. Open PowerShell and clone or navigate to the project folder. 2. 
Create a simple `.env` file in the project root (replace placeholder values): @@ -204,8 +205,6 @@ Notes for non-developers: From the root of the repository: ```bash -cd connector-platform - docker build -t node-wire:latest . ``` @@ -538,7 +537,7 @@ tests/test_toolhive_agent.py::test_mcp_entrypoint_registers_three_to PASSED ## File layout (`agents`) ``` -connector-platform/ +node-wire/ ├── Dockerfile ← Docker image for ToolHive ├── pyproject.toml ← [agents] extras added ├── sample.env ← env var reference diff --git a/playground/README.md b/playground/README.md index 1ba1290..b7851ab 100644 --- a/playground/README.md +++ b/playground/README.md @@ -97,8 +97,8 @@ The demo is pre-configured with mock/sandbox endpoints for immediate use. To tes To test the Google Drive integration manually, follow these specialized setup steps: 1. **Service Account**: Create a Service Account in the Google Cloud Console with the **Google Drive API** enabled. Download the JSON key. 2. **Secret Configuration**: - * Place the JSON key file in your project directory (e.g., `D:\connector-platform\service_account.json`). - * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=D:\connector-platform\service_account.json`. + * Place the JSON key file somewhere safe on your machine (e.g., `/absolute/path/to/service_account.json`). + * Update your `.env` file: `GOOGLE_DRIVE_SA_JSON=/absolute/path/to/service_account.json`. * *Note: The platform now supports direct file paths for easier local configuration.* 3. **Permissions**: If using a specific **Vault Folder ID**, ensure that folder is shared with the Service Account's email address (found in the JSON) with "Editor" or "Manager" permissions. 4. **Workflow Verification**: @@ -110,7 +110,7 @@ To test the Google Drive integration manually, follow these specialized setup st To enable the AI Agent chat, you need to configure an LLM provider: 1. **Select Provider**: Set `LLM_PROVIDER` to `groq` (default) or `openai` in your `.env`. 2. 
**Add API Key**: Provide the corresponding key, e.g., `GROQ_API_KEY=your_key_here`. -3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USER`, `SMTP_PASS`) to enable the agent to send emails. +3. **SMTP Setup**: (Optional) Add SMTP credentials (`SMTP_HOST`, `SMTP_PORT`, `SMTP_USERNAME`, `SMTP_PASSWORD`) to enable the agent to send emails. 4. **MCP URL**: (Optional) If running the MCP server in a separate container, set `TOOLHIVE_MCP_URL` to point to the MCP proxy. --- @@ -119,8 +119,14 @@ To enable the AI Agent chat, you need to configure an LLM provider: 1. Navigate to the project root. 2. Start the FastAPI server: - ```bash - set MODE=API&& python -m bindings_entrypoint - ``` + +```bash +# Recommended +uv run node-wire + +# Equivalent (no uv) +MODE=API python -m bindings_entrypoint +``` + 3. Open your browser to `http://localhost:8000/playground/` (or the configured port). 4. Switch between **EHR**, **IT Ops**, **Cerner**, **Google Drive Vault**, and **AI Agent** tabs to explore the different workflows. diff --git a/playground/index.html b/playground/index.html index 46978e8..a18a660 100644 --- a/playground/index.html +++ b/playground/index.html @@ -4,7 +4,7 @@ - Node-wire Playground + node-wire Playground @@ -28,7 +28,7 @@
-

Node-Wire

+

node-wire

Autonomous Connector Orchestration Platform

@@ -93,7 +93,7 @@

Connectors

-

Node-Wire MCP via ToolHive

+

node-wire MCP via ToolHive

MCP Agent — Guardrailed