添加了对streamablehttp的支持

This commit is contained in:
slexce 2026-03-08 19:53:07 +08:00
parent 7c7e270851
commit 3def80b047
2 changed files with 41 additions and 12 deletions

View file

@ -1,3 +1,5 @@
from dataclasses import dataclass
from pydantic import BaseModel, Field
@ -21,6 +23,8 @@ class MCPServerConfig(BaseModel):
env: dict[str, str] | None = Field({}, description="stdio模式下MCP命令环境变量")
url: str | None = Field(None, description="sse模式下MCP服务器地址")
headers: dict[str, str] | None = Field({}, description="sse模式下http请求头用于认证或其他设置")
transport_type: str | None = Field(None, description="请求类型 sse、stdio 或 streamablehttp")
# 额外字段
friendly_name: str | None = Field(None, description="MCP服务器友好名称")
@ -55,3 +59,9 @@ class ScopedConfig(BaseModel):
class Config(BaseModel):
llmchat: ScopedConfig
@dataclass
class transportType:
sse = "sse"
stdio = "stdio"
streamablehttp = "streamablehttp"

View file

@ -4,9 +4,13 @@ from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
try:
from mcp.client.streamable_http import streamable_http_client as streamablehttp_client
except:
from mcp.client.streamable_http import streamablehttp_client
from nonebot import logger
from .config import MCPServerConfig
from .config import MCPServerConfig, transportType
from .onebottools import OneBotTools
@ -85,18 +89,33 @@ class MCPClient:
self.exit_stack = AsyncExitStack()
async def __aenter__(self):
if config.url:
transport = await self.exit_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
else:
raise ValueError("Server config must have either url or command")
if config.transport_type is None:
if config.url:
config.transport_type = transportType.sse
elif config.command:
config.transport_type = transportType.stdio
else:
raise ValueError("Server config must have either url or command")
match config.transport_type:
case transportType.sse:
transport = await self.exit_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
read, write = transport
case transportType.stdio:
transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
read, write = transport
case transportType.streamablehttp:
transport = await self.exit_stack.enter_async_context(
streamablehttp_client(url=config.url, headers=config.headers)
)
read, write, session_callback = transport
case _:
raise ValueError("Unknown transport type")
read, write = transport
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.session.initialize()
return self.session