Skip to content
25 changes: 23 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mcp.shared.response_router import ResponseRouter
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_PARAMS,
REQUEST_TIMEOUT,
CancelledNotification,
Expand Down Expand Up @@ -184,6 +185,7 @@ class BaseSession(
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_response_routers: list[ResponseRouter]
_propagate_errors: dict[RequestId, BaseException]

def __init__(
self,
Expand All @@ -201,6 +203,7 @@ def __init__(
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
self._propagate_errors = {}

def add_response_router(self, router: ResponseRouter) -> None:
"""Register a response router to handle responses for non-standard requests.
Expand Down Expand Up @@ -295,6 +298,11 @@ async def send_request(
class_name = request.__class__.__name__
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
raise MCPError(code=REQUEST_TIMEOUT, message=message)
except anyio.EndOfStream:
propagate = self._propagate_errors.pop(request_id, None)
if propagate is not None:
raise propagate from None
raise

if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
Expand Down Expand Up @@ -374,7 +382,20 @@ async def _handle_session_message(message: SessionMessage) -> None:

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception:
except Exception as e:
if getattr(e, "__mcp_propagate__", False):
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.id,
error=ErrorData(code=INTERNAL_ERROR, message="Handler raised", data=""),
)
await self._write_stream.send(SessionMessage(message=error_response))
self._in_flight.pop(message.message.id, None)
for in_flight_id, stream in list(self._response_streams.items()):
self._propagate_errors[in_flight_id] = e
await stream.aclose()
return

# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning("Failed to validate request", exc_info=True)
Expand Down Expand Up @@ -451,7 +472,7 @@ async def _handle_session_message(message: SessionMessage) -> None:
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()
Expand Down
114 changes: 114 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from mcp import Client, types
from mcp.client.session import ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.shared._context import RequestContext
from mcp.shared.exceptions import MCPError
from mcp.shared.memory import create_client_server_memory_streams
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import (
INTERNAL_ERROR,
PARSE_ERROR,
CancelledNotification,
CancelledNotificationParams,
Expand Down Expand Up @@ -416,3 +418,115 @@ async def make_request(client_session: ClientSession):
# Pending request completed successfully
assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_callback_exception_propagation():
"""Verify that exceptions raised in callbacks with __mcp_propagate__ = True
are propagated to the awaiter of send_request, and result in INTERNAL_ERROR to peer.
"""

class CustomPropagatedException(Exception):
__mcp_propagate__ = True

ev_server_received_error = anyio.Event()
server_error_holder: list[JSONRPCError] = []

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
# Wait for client's ping request
msg = await server_read.receive()
assert isinstance(msg, SessionMessage)
assert isinstance(msg.message, JSONRPCRequest)

# Trigger list_roots callback on client by sending roots/list request
roots_request = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="roots/list",
)
await server_write.send(SessionMessage(message=roots_request))

# Receive the client's response (which should be an error due to propagated exception)
response_msg = await server_read.receive()
assert isinstance(response_msg, SessionMessage)
assert isinstance(response_msg.message, JSONRPCError)
server_error_holder.append(response_msg.message)
ev_server_received_error.set()

async def mock_list_roots(context: RequestContext[ClientSession]):
raise CustomPropagatedException("Callback error that should propagate")

async def make_request(client_session: ClientSession):
# Send a ping request and assert that CustomPropagatedException propagates to it
with pytest.raises(CustomPropagatedException) as exc_info:
await client_session.send_ping()
assert "Callback error that should propagate" in str(exc_info.value)

async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
list_roots_callback=mock_list_roots,
) as client_session,
):
tg.start_soon(mock_server)
tg.start_soon(make_request, client_session)

with anyio.fail_after(2): # pragma: no branch
await ev_server_received_error.wait()

assert len(server_error_holder) == 1
assert server_error_holder[0].error.code == INTERNAL_ERROR


@pytest.mark.anyio
async def test_send_request_end_of_stream_without_propagated_error():
"""Ensure EndOfStream is surfaced when no propagated error is present."""
async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, _server_write = server_streams

async def mock_server(client_session: ClientSession):
message = await server_read.receive()
assert isinstance(message, SessionMessage)
assert isinstance(message.message, JSONRPCRequest)
response_stream = client_session._response_streams[message.message.id]
await response_stream.aclose()

async def make_request(client_session: ClientSession):
with pytest.raises(anyio.EndOfStream):
await client_session.send_ping()

async with (
anyio.create_task_group() as tg,
ClientSession(read_stream=client_read, write_stream=client_write) as client_session,
):
tg.start_soon(mock_server, client_session)
tg.start_soon(make_request, client_session)


@pytest.mark.anyio
async def test_receive_loop_handles_closed_response_stream():
"""Cover receive loop cleanup when a response stream is already closed."""
async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
_server_read, server_write = server_streams

async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](
1
)
await response_stream.aclose()
await response_stream_reader.aclose()
client_session._response_streams[0] = response_stream

server_write.close()

with anyio.fail_after(2): # pragma: no branch
while client_session._response_streams:
await anyio.sleep(0)
Loading