diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..654b45efb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -118,6 +118,7 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + protocol_version: str | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, @@ -133,6 +134,7 @@ def __init__( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None + self._protocol_version = protocol_version or types.LATEST_PROTOCOL_VERSION # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() @@ -168,7 +170,7 @@ async def initialize(self) -> types.InitializeResult: result = await self.send_request( types.InitializeRequest( params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, + protocol_version=self._protocol_version, capabilities=types.ClientCapabilities( sampling=sampling, elicitation=elicitation, diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..a13c4ef50 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -606,6 +606,125 @@ async def mock_server(): assert result.protocol_version == LATEST_PROTOCOL_VERSION +@pytest.mark.anyio +async def test_client_session_custom_protocol_version(): + """Test that custom protocol_version is sent during initialization. + + This allows connecting to servers that require a specific protocol version, + such as Snowflake's managed MCP server which requires "2025-06-18". + See: https://github.com/modelcontextprotocol/python-sdk/issues/2307 + """ + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + custom_protocol_version = "2025-06-18" + received_protocol_version = None + + async def mock_server(): + nonlocal received_protocol_version + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + request = client_request_adapter.validate_python( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request, InitializeRequest) + received_protocol_version = request.params.protocol_version + + result = InitializeResult( + protocol_version=custom_protocol_version, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + protocol_version=custom_protocol_version, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + result = await session.initialize() + + # Assert that the custom protocol version was sent and received + assert received_protocol_version == custom_protocol_version + assert result.protocol_version == custom_protocol_version + + +@pytest.mark.anyio +async def test_client_session_default_protocol_version(): + """Test that LATEST_PROTOCOL_VERSION is used when protocol_version is not specified.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_protocol_version = None + + async def mock_server(): + nonlocal received_protocol_version + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + request = client_request_adapter.validate_python( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request, InitializeRequest) + received_protocol_version = request.params.protocol_version + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that the default (latest) protocol version was sent + assert received_protocol_version == LATEST_PROTOCOL_VERSION + + @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):