diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..a8c9bec638 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -20,6 +20,7 @@ from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, + INTERNAL_ERROR, INVALID_PARAMS, REQUEST_TIMEOUT, CancelledNotification, @@ -184,6 +185,7 @@ class BaseSession( _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] _response_routers: list[ResponseRouter] + _propagate_errors: dict[RequestId, BaseException] def __init__( self, @@ -201,6 +203,7 @@ def __init__( self._progress_callbacks = {} self._response_routers = [] self._exit_stack = AsyncExitStack() + self._propagate_errors = {} def add_response_router(self, router: ResponseRouter) -> None: """Register a response router to handle responses for non-standard requests. @@ -295,6 +298,11 @@ async def send_request( class_name = request.__class__.__name__ message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." raise MCPError(code=REQUEST_TIMEOUT, message=message) + except anyio.EndOfStream: + propagate = self._propagate_errors.pop(request_id, None) + if propagate is not None: + raise propagate from None + raise if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) @@ -374,7 +382,20 @@ async def _handle_session_message(message: SessionMessage) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - except Exception: + except Exception as e: + if getattr(e, "__mcp_propagate__", False): + error_response = JSONRPCError( + jsonrpc="2.0", + id=message.message.id, + error=ErrorData(code=INTERNAL_ERROR, message="Handler raised", data=""), + ) + await self._write_stream.send(SessionMessage(message=error_response)) + self._in_flight.pop(message.message.id, None) + for in_flight_id, stream in list(self._response_streams.items()): + self._propagate_errors[in_flight_id] = e + await stream.aclose() + return + # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server logging.warning("Failed to validate request", exc_info=True) @@ -451,7 +472,7 @@ async def _handle_session_message(message: SessionMessage) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5f..19331c3263 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -4,11 +4,13 @@ from mcp import Client, types from mcp.client.session import ClientSession from mcp.server import Server, ServerRequestContext +from mcp.shared._context import RequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( + INTERNAL_ERROR, PARSE_ERROR, CancelledNotification, CancelledNotificationParams, @@ -416,3 +418,115 @@ async def make_request(client_session: ClientSession): # Pending request completed successfully assert len(result_holder) == 1 assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_callback_exception_propagation(): + """Verify that exceptions raised in callbacks with __mcp_propagate__ = True + are propagated to the awaiter of send_request, and result in INTERNAL_ERROR to peer. + """ + + class CustomPropagatedException(Exception): + __mcp_propagate__ = True + + ev_server_received_error = anyio.Event() + server_error_holder: list[JSONRPCError] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + # Wait for client's ping request + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + + # Trigger list_roots callback on client by sending roots/list request + roots_request = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="roots/list", + ) + await server_write.send(SessionMessage(message=roots_request)) + + # Receive the client's response (which should be an error due to propagated exception) + response_msg = await server_read.receive() + assert isinstance(response_msg, SessionMessage) + assert isinstance(response_msg.message, JSONRPCError) + server_error_holder.append(response_msg.message) + ev_server_received_error.set() + + async def mock_list_roots(context: RequestContext[ClientSession]): + raise CustomPropagatedException("Callback error that should propagate") + + async def make_request(client_session: ClientSession): + # Send a ping request and assert that CustomPropagatedException propagates to it + with pytest.raises(CustomPropagatedException) as exc_info: + await client_session.send_ping() + assert "Callback error that should propagate" in str(exc_info.value) + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + list_roots_callback=mock_list_roots, + ) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_server_received_error.wait() + + assert len(server_error_holder) == 1 + assert server_error_holder[0].error.code == INTERNAL_ERROR + + +@pytest.mark.anyio +async def test_send_request_end_of_stream_without_propagated_error(): + """Ensure EndOfStream is surfaced when no propagated error is present.""" + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _server_write = server_streams + + async def mock_server(client_session: ClientSession): + message = await server_read.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, JSONRPCRequest) + response_stream = client_session._response_streams[message.message.id] + await response_stream.aclose() + + async def make_request(client_session: ClientSession): + with pytest.raises(anyio.EndOfStream): + await client_session.send_ping() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server, client_session) + tg.start_soon(make_request, client_session) + + +@pytest.mark.anyio +async def test_receive_loop_handles_closed_response_stream(): + """Cover receive loop cleanup when a response stream is already closed.""" + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + _server_read, server_write = server_streams + + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError]( + 1 + ) + await response_stream.aclose() + await response_stream_reader.aclose() + client_session._response_streams[0] = response_stream + + server_write.close() + + with anyio.fail_after(2): # pragma: no branch + while client_session._response_streams: + await anyio.sleep(0)