添加MCP服务器全局工作目录配置

This commit is contained in:
FuQuan233 2026-06-29 10:23:21 +08:00
parent 38af060cb2
commit b6af4ec334
4 changed files with 254 additions and 224 deletions

View file

@ -18,12 +18,20 @@ class MCPClient:
_SESSION_TTL_SECONDS = 600
_SESSION_CLEANUP_INTERVAL_SECONDS = 60
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
def __new__(
cls,
server_config: dict[str, MCPServerConfig] | None = None,
default_command_cwd: str | None = None,
):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, server_config: dict[str, MCPServerConfig] | None = None):
def __init__(
self,
server_config: dict[str, MCPServerConfig] | None = None,
default_command_cwd: str | None = None,
):
if self._initialized:
return
@ -32,6 +40,7 @@ class MCPClient:
logger.info(f"正在初始化MCPClient单例共有{len(server_config)}个服务器配置")
self.server_config = server_config
self.default_command_cwd = default_command_cwd
self.sessions = {}
self.exit_stack = AsyncExitStack()
self._session_exit_stacks: dict[str, AsyncExitStack] = {}
@ -47,12 +56,16 @@ class MCPClient:
logger.debug("MCPClient单例初始化成功")
@classmethod
def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None):
def get_instance(
cls,
server_config: dict[str, MCPServerConfig] | None = None,
default_command_cwd: str | None = None,
):
"""获取MCPClient实例"""
if cls._instance is None:
if server_config is None:
raise ValueError("server_config must be provided for first initialization")
cls._instance = cls(server_config)
cls._instance = cls(server_config, default_command_cwd)
return cls._instance
@classmethod
@ -79,8 +92,15 @@ class MCPClient:
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
stdio_params: dict[str, Any] = {
"command": config.command,
"args": config.args or [],
"env": config.env or {},
}
if self.default_command_cwd:
stdio_params["cwd"] = self.default_command_cwd
transport = await session_stack.enter_async_context(
cast(Any, stdio_client(StdioServerParameters(**config.model_dump())))
cast(Any, stdio_client(StdioServerParameters(**stdio_params)))
)
else:
raise ValueError("Server config must have either url or command")