From e729d1b2330dfc6d57c90160d6749e757d074857 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sun, 14 Jun 2026 09:39:58 +0800 Subject: [PATCH] fix: respect session group capabilities --- src/mcp/client/session_group.py | 56 +++++++++++++++++------------- tests/client/test_session_group.py | 52 +++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..f6ea2d818c 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -344,36 +344,42 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + initialize_result = session.initialize_result + capabilities = initialize_result.capabilities if initialize_result is not None else None + # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is None or capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is None or capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") + if capabilities is None or capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..f21d17f115 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -125,6 +125,58 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_capabilities(mock_exit_stack: contextlib.AsyncExitStack): + server_info = types.Implementation(name="ToolsOnlyServer", version="1") + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = types.Tool(name="ping", input_schema={}) + mock_session.initialize_result = types.InitializeResult( + protocol_version="2025-03-26", + capabilities=types.ServerCapabilities(tools=types.ToolsCapability(list_changed=False)), + server_info=server_info, + ) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[mock_tool]) + mock_session.list_resources.return_value = types.ListResourcesResult(resources=[]) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object(group, "_establish_session", return_value=(server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + assert group.tools == {"ping": mock_tool} + assert group.resources == {} + assert group.prompts == {} + mock_session.list_tools.assert_awaited_once() + mock_session.list_resources.assert_not_awaited() + mock_session.list_prompts.assert_not_awaited() + + +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_tools(mock_exit_stack: contextlib.AsyncExitStack): + server_info = types.Implementation(name="PromptServer", version="1") + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_prompt = types.Prompt(name="explain") + mock_session.initialize_result = types.InitializeResult( + protocol_version="2025-03-26", + capabilities=types.ServerCapabilities(prompts=types.PromptsCapability(list_changed=False)), + server_info=server_info, + ) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[mock_prompt]) + mock_session.list_resources.return_value = types.ListResourcesResult(resources=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object(group, "_establish_session", return_value=(server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + assert group.prompts == {"explain": mock_prompt} + assert group.resources == {} + assert group.tools == {} + mock_session.list_prompts.assert_awaited_once() + mock_session.list_resources.assert_not_awaited() + mock_session.list_tools.assert_not_awaited() + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook."""