diff --git a/README.md b/README.md index 198d121..f945db3 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ LLMCHAT__MCP_SERVERS同样为一个dict,key为服务器名称,value配置的 | arg | 否 | [] | stdio服务器MCP命令参数 | | env | 否 | {} | stdio服务器环境变量 | | url | sse服务器必填 | 无 | sse服务器地址 | +| headers | 否 | {} | sse模式下http请求头,用于认证或其他设置 | 以下为在 Claude.app 的MCP服务器配置基础上增加的字段 | 配置项 | 必填 | 默认值 | 说明 | @@ -179,7 +180,10 @@ LLMCHAT__MCP_SERVERS同样为一个dict,key为服务器名称,value配置的 "AISearch": { "friendly_name": "百度搜索", "additional_prompt": "遇到你不知道的问题或者时效性比较强的问题时,可以使用AISearch搜索,在使用AISearch时不要使用其他AI模型。", - "url": "http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key=Bearer+" + "url": "http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key=Bearer+", + "headers": { + "Authorization": "" + } }, "fetch": { "friendly_name": "网页浏览", diff --git a/nonebot_plugin_llmchat/__init__.py b/nonebot_plugin_llmchat/__init__.py index e89e485..79290a1 100755 --- a/nonebot_plugin_llmchat/__init__.py +++ b/nonebot_plugin_llmchat/__init__.py @@ -278,6 +278,8 @@ async def process_messages(group_id: int): while not state.queue.empty(): event = await state.queue.get() logger.debug(f"从队列获取消息 群号:{group_id} 消息ID:{event.message_id}") + past_events_snapshot = [] + mcp_client = MCPClient(plugin_config.mcp_servers) try: systemPrompt = f""" 我想要你帮我在群聊中闲聊,大家一般叫你{"、".join(list(driver.config.nickname))},我将会在后面的信息中告诉你每条群聊信息的发送者和发送时间,你可以直接称呼发送者为他对应的昵称。 @@ -320,6 +322,7 @@ async def process_messages(group_id: int): # 将机器人错过的消息推送给LLM past_events_snapshot = list(state.past_events) + state.past_events.clear() for ev in past_events_snapshot: text_content = format_message(ev) content.append({"type": "text", "text": text_content}) @@ -345,7 +348,6 @@ async def process_messages(group_id: int): "timeout": 60, } - mcp_client = MCPClient(plugin_config.mcp_servers) if preset.support_mcp: await mcp_client.connect_to_servers() available_tools = await mcp_client.get_available_tools() @@ -397,8 +399,6 @@ async def process_messages(group_id: int): message = response.choices[0].message - await mcp_client.cleanup() - reply, matched_reasoning_content = pop_reasoning_content( response.choices[0].message.content ) @@ -423,7 +423,6 @@ async def process_messages(group_id: int): # 请求成功后再保存历史记录,保证user和assistant穿插,防止R1模型报错 for message in new_messages: state.history.append(message) - state.past_events.clear() if state.output_reasoning_content and reasoning_content: try: @@ -450,11 +449,13 @@ async def process_messages(group_id: int): except Exception as e: logger.opt(exception=e).error(f"API请求失败 群号:{group_id}") + # 如果在处理过程中出现异常,恢复未处理的消息到state中 + state.past_events.extendleft(reversed(past_events_snapshot)) await handler.send(Message(f"服务暂时不可用,请稍后再试\n{e!s}")) finally: + state.processing = False state.queue.task_done() - - state.processing = False + await mcp_client.cleanup() # 预设切换命令 diff --git a/nonebot_plugin_llmchat/config.py b/nonebot_plugin_llmchat/config.py index d658875..ed88dd2 100755 --- a/nonebot_plugin_llmchat/config.py +++ b/nonebot_plugin_llmchat/config.py @@ -20,6 +20,7 @@ class MCPServerConfig(BaseModel): args: list[str] | None = Field([], description="stdio模式下MCP命令参数") env: dict[str, str] | None = Field({}, description="stdio模式下MCP命令环境变量") url: str | None = Field(None, description="sse模式下MCP服务器地址") + headers: dict[str, str] | None = Field({}, description="sse模式下http请求头,用于认证或其他设置") # 额外字段 friendly_name: str | None = Field(None, description="MCP服务器友好名称") diff --git a/nonebot_plugin_llmchat/mcpclient.py b/nonebot_plugin_llmchat/mcpclient.py index 55e1b44..c3f3224 100644 --- a/nonebot_plugin_llmchat/mcpclient.py +++ b/nonebot_plugin_llmchat/mcpclient.py @@ -22,7 +22,7 @@ class MCPClient: for server_name, config in self.server_config.items(): logger.debug(f"正在连接服务器[{server_name}]") if config.url: - sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url)) + sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers)) read, write = sse_transport self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write)) await self.sessions[server_name].initialize() @@ -74,6 +74,7 @@ class MCPClient: return response.content def get_friendly_name(self, tool_name: str): + logger.debug(tool_name) server_name, real_tool_name = tool_name.split("___") return (self.server_config[server_name].friendly_name or server_name) + " - " + real_tool_name