diff --git a/nonebot_plugin_llmchat/config.py b/nonebot_plugin_llmchat/config.py index 6ecebf1..1a70a44 100755 --- a/nonebot_plugin_llmchat/config.py +++ b/nonebot_plugin_llmchat/config.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + from pydantic import BaseModel, Field @@ -21,6 +23,8 @@ class MCPServerConfig(BaseModel): env: dict[str, str] | None = Field({}, description="stdio模式下MCP命令环境变量") url: str | None = Field(None, description="sse模式下MCP服务器地址") headers: dict[str, str] | None = Field({}, description="sse模式下http请求头,用于认证或其他设置") + transport_type: str | None = Field(None, description="请求类型 sse、stdio 或 streamablehttp") + # 额外字段 friendly_name: str | None = Field(None, description="MCP服务器友好名称") @@ -55,3 +59,9 @@ class ScopedConfig(BaseModel): class Config(BaseModel): llmchat: ScopedConfig + +@dataclass +class transportType: + sse = "sse" + stdio = "stdio" + streamablehttp = "streamablehttp" \ No newline at end of file diff --git a/nonebot_plugin_llmchat/mcpclient.py b/nonebot_plugin_llmchat/mcpclient.py index 8861dd9..447750f 100644 --- a/nonebot_plugin_llmchat/mcpclient.py +++ b/nonebot_plugin_llmchat/mcpclient.py @@ -4,9 +4,13 @@ from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +try: + from mcp.client.streamable_http import streamable_http_client as streamablehttp_client +except: + from mcp.client.streamable_http import streamablehttp_client from nonebot import logger -from .config import MCPServerConfig +from .config import MCPServerConfig, transportType from .onebottools import OneBotTools @@ -85,18 +89,33 @@ class MCPClient: self.exit_stack = AsyncExitStack() async def __aenter__(self): - if config.url: - transport = await self.exit_stack.enter_async_context( - sse_client(url=config.url, headers=config.headers) - ) - elif config.command: - transport = await self.exit_stack.enter_async_context( - stdio_client(StdioServerParameters(**config.model_dump())) - ) - else: - raise ValueError("Server config must have either url or command") + if config.transport_type is None: + if config.url: + config.transport_type = transportType.sse + elif config.command: + config.transport_type = transportType.stdio + else: + raise ValueError("Server config must have either url or command") + + match config.transport_type: + case transportType.sse: + transport = await self.exit_stack.enter_async_context( + sse_client(url=config.url, headers=config.headers) + ) + read, write = transport + case transportType.stdio: + transport = await self.exit_stack.enter_async_context( + stdio_client(StdioServerParameters(**config.model_dump())) + ) + read, write = transport + case transportType.streamablehttp: + transport = await self.exit_stack.enter_async_context( + streamablehttp_client(url=config.url, headers=config.headers) + ) + read, write, session_callback = transport + case _: + raise ValueError("Unknown transport type") - read, write = transport self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await self.session.initialize() return self.session