mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2025-09-04 10:20:45 +00:00
✨ support Model Context Protocol (MCP)
This commit is contained in:
parent
dfe3b5308c
commit
eb1038e09e
3 changed files with 160 additions and 12 deletions
|
@ -28,6 +28,7 @@ from nonebot.rule import Rule
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from .config import Config, PresetConfig
|
||||
from .mcpclient import MCPClient
|
||||
|
||||
require("nonebot_plugin_localstore")
|
||||
import nonebot_plugin_localstore as store
|
||||
|
@ -241,6 +242,11 @@ async def process_messages(group_id: int):
|
|||
- 如果你需要思考的话,你应该思考尽量少,以节省时间。
|
||||
下面是关于你性格的设定,如果设定中提到让你扮演某个人,或者设定中有提到名字,则优先使用设定中的名字。
|
||||
{state.group_prompt or plugin_config.default_prompt}
|
||||
"""
|
||||
if preset.support_mcp:
|
||||
systemPrompt += f"""
|
||||
你也可以使用一些工具,下面是关于这些工具的额外说明:
|
||||
{"\n".join([mcp_config.addtional_prompt for mcp_config in plugin_config.mcp_servers.values()])}
|
||||
"""
|
||||
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
|
@ -256,20 +262,74 @@ async def process_messages(group_id: int):
|
|||
# 将机器人错过的消息推送给LLM
|
||||
content = ",".join([format_message(ev) for ev in state.past_events])
|
||||
|
||||
new_messages = [{"role": "user", "content": content}]
|
||||
|
||||
logger.debug(
|
||||
f"发送API请求 模型:{preset.model_name} 历史消息数:{len(messages)}"
|
||||
)
|
||||
mcp_client = MCPClient(plugin_config.mcp_servers)
|
||||
await mcp_client.connect_to_servers()
|
||||
|
||||
available_tools = await mcp_client.get_available_tools()
|
||||
|
||||
client_config = {
|
||||
"model": preset.model_name,
|
||||
"max_tokens": preset.max_tokens,
|
||||
"temperature": preset.temperature,
|
||||
"timeout": 60,
|
||||
}
|
||||
|
||||
if preset.support_mcp:
|
||||
client_config["tools"] = available_tools
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=preset.model_name,
|
||||
messages=[*messages, {"role": "user", "content": content}],
|
||||
max_tokens=preset.max_tokens,
|
||||
temperature=preset.temperature,
|
||||
timeout=60,
|
||||
**client_config,
|
||||
messages=messages + new_messages,
|
||||
)
|
||||
|
||||
if response.usage is not None:
|
||||
logger.debug(f"收到API响应 使用token数:{response.usage.total_tokens}")
|
||||
|
||||
final_message = []
|
||||
message = response.choices[0].message
|
||||
if message.content:
|
||||
final_message.append(message.content)
|
||||
|
||||
# 处理响应并处理工具调用
|
||||
while preset.support_mcp and message.tool_calls:
|
||||
new_messages.append({
|
||||
"role": "assistant",
|
||||
"tool_calls": [tool_call.dict() for tool_call in message.tool_calls]
|
||||
})
|
||||
# 处理每个工具调用
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# 发送工具调用提示
|
||||
await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}"))
|
||||
|
||||
# 执行工具调用
|
||||
result = await mcp_client.call_tool(tool_name, tool_args)
|
||||
|
||||
new_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": str(result.content)
|
||||
})
|
||||
|
||||
# 将工具调用的结果交给 LLM
|
||||
response = await client.chat.completions.create(
|
||||
**client_config,
|
||||
messages=messages + new_messages,
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
if message.content:
|
||||
final_message.append(message.content)
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
reply, matched_reasoning_content = pop_reasoning_content(
|
||||
response.choices[0].message.content
|
||||
)
|
||||
|
@ -279,9 +339,8 @@ async def process_messages(group_id: int):
|
|||
)
|
||||
|
||||
# 请求成功后再保存历史记录,保证user和assistant穿插,防止R1模型报错
|
||||
state.history.append({"role": "user", "content": content})
|
||||
# 添加助手回复到历史
|
||||
state.history.append({"role": "assistant", "content": reply})
|
||||
for message in new_messages:
|
||||
state.history.append(message)
|
||||
state.past_events.clear()
|
||||
|
||||
if state.output_reasoning_content and reasoning_content:
|
||||
|
@ -467,10 +526,7 @@ async def load_state():
|
|||
state.last_active = state_data["last_active"]
|
||||
state.group_prompt = state_data["group_prompt"]
|
||||
state.output_reasoning_content = state_data["output_reasoning_content"]
|
||||
state.random_trigger_prob = (
|
||||
state_data.get("random_trigger_prob")
|
||||
or plugin_config.random_trigger_prob
|
||||
)
|
||||
state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob)
|
||||
group_states[int(gid)] = state
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
@ -11,7 +13,18 @@ class PresetConfig(BaseModel):
|
|||
max_tokens: int = Field(2048, description="最大响应token数")
|
||||
temperature: float = Field(0.7, description="生成温度(0-2]")
|
||||
proxy: str = Field("", description="HTTP代理服务器")
|
||||
support_mcp: bool = Field(False, description="是否支持MCP")
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""MCP服务器配置"""
|
||||
command: Optional[str] = Field(None, description="stdio模式下MCP命令")
|
||||
args: Optional[list[str]] = Field([], description="stdio模式下MCP命令参数")
|
||||
env: Optional[dict[str, str]] = Field({}, description="stdio模式下MCP命令环境变量")
|
||||
url: Optional[str] = Field(None, description="sse模式下MCP服务器地址")
|
||||
|
||||
# 额外字段
|
||||
friendly_name: str = Field("", description="MCP服务器友好名称")
|
||||
addtional_prompt: str = Field("", description="额外提示词")
|
||||
|
||||
class ScopedConfig(BaseModel):
|
||||
"""LLM Chat Plugin配置"""
|
||||
|
@ -30,6 +43,7 @@ class ScopedConfig(BaseModel):
|
|||
"你的回答应该尽量简洁、幽默、可以使用一些语气词、颜文字。你应该拒绝回答任何政治相关的问题。",
|
||||
description="默认提示词",
|
||||
)
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field({}, description="MCP服务器配置")
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
|
|
78
nonebot_plugin_llmchat/mcpclient.py
Normal file
78
nonebot_plugin_llmchat/mcpclient.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from nonebot import logger
|
||||
|
||||
from .config import MCPServerConfig
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self, server_config: dict[str, MCPServerConfig]):
|
||||
logger.info(f"正在初始化MCPClient,共有{len(server_config)}个服务器配置")
|
||||
self.server_config = server_config
|
||||
self.sessions = {}
|
||||
self.exit_stack = AsyncExitStack()
|
||||
logger.debug("MCPClient初始化成功")
|
||||
|
||||
async def connect_to_servers(self):
|
||||
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
|
||||
for server_name, config in self.server_config.items():
|
||||
logger.debug(f"正在连接服务器[{server_name}]")
|
||||
if config.url:
|
||||
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url))
|
||||
read, write = sse_transport
|
||||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
await self.sessions[server_name].initialize()
|
||||
elif config.command:
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
stdio_client(StdioServerParameters(**config.model_dump()))
|
||||
)
|
||||
read, write = stdio_transport
|
||||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
await self.sessions[server_name].initialize()
|
||||
else:
|
||||
raise ValueError("Server config must have either url or command")
|
||||
|
||||
logger.info(f"已成功连接到MCP服务器[{server_name}]")
|
||||
|
||||
async def get_available_tools(self):
|
||||
logger.info(f"正在从{len(self.sessions)}个已连接的服务器获取可用工具")
|
||||
available_tools = []
|
||||
|
||||
for server_name, session in self.sessions.items():
|
||||
logger.debug(f"正在列出服务器[{server_name}]中的工具")
|
||||
response = await session.list_tools()
|
||||
tools = response.tools
|
||||
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
|
||||
|
||||
available_tools.extend(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{server_name}___{tool.name}",
|
||||
"description": tool.description,
|
||||
"parameters": tool.inputSchema,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
)
|
||||
return available_tools
|
||||
|
||||
async def call_tool(self, tool_name: str, tool_args: dict):
|
||||
server_name, real_tool_name = tool_name.split("___")
|
||||
logger.info(f"正在服务器[{server_name}]上调用工具[{real_tool_name}]")
|
||||
session = self.sessions[server_name]
|
||||
response = await session.call_tool(real_tool_name, tool_args)
|
||||
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
|
||||
return response
|
||||
|
||||
def get_friendly_name(self, tool_name: str):
|
||||
server_name, real_tool_name = tool_name.split("___")
|
||||
return self.server_config[server_name].friendly_name + " - " + real_tool_name
|
||||
|
||||
async def cleanup(self):
|
||||
logger.debug("正在清理MCPClient资源")
|
||||
await self.exit_stack.aclose()
|
||||
logger.debug("MCPClient资源清理完成")
|
Loading…
Add table
Add a link
Reference in a new issue