mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2026-05-12 19:42:50 +00:00
Merge 7e87981167 into 7c7e270851
This commit is contained in:
commit
8f10807184
6 changed files with 1712 additions and 26 deletions
19
README.md
19
README.md
|
|
@ -68,6 +68,17 @@ _✨ 支持多API预设、MCP协议、内置工具、联网搜索、视觉模型
|
||||||
- 可动态修改群组专属系统提示词(`/修改设定`)
|
- 可动态修改群组专属系统提示词(`/修改设定`)
|
||||||
- 支持自定义默认提示词
|
- 支持自定义默认提示词
|
||||||
|
|
||||||
|
1. **子模型调用能力**
|
||||||
|
- 主模型可以调用其他子模型完成特定任务(如生成图片、语音、视频)
|
||||||
|
- 支持配置可调用的子模型列表(`call_model_list`)
|
||||||
|
- 调用失败时自动切换备选模型
|
||||||
|
- 子模型如果支持MCP,可以继续调用MCP工具
|
||||||
|
|
||||||
|
1. **定时任务功能**
|
||||||
|
- 支持创建各类定时提醒任务(一次性、每日、每周、每年、间隔)
|
||||||
|
- 任务触发时AI自动生成友好的提醒消息
|
||||||
|
- 任务触发时可调用MCP工具获取最新信息(如天气)
|
||||||
|
|
||||||
## 💿 安装
|
## 💿 安装
|
||||||
|
|
||||||
<details open>
|
<details open>
|
||||||
|
|
@ -143,7 +154,7 @@ _✨ 支持多API预设、MCP协议、内置工具、联网搜索、视觉模型
|
||||||
| ob__recall_message | 撤回指定消息 | 机器人需要管理员权限或为消息发送者 |
|
| ob__recall_message | 撤回指定消息 | 机器人需要管理员权限或为消息发送者 |
|
||||||
|
|
||||||
|
|
||||||
### MCP服务器配置
|
### API预设配置
|
||||||
|
|
||||||
其中LLMCHAT__API_PRESETS为一个列表,每项配置有以下的配置项
|
其中LLMCHAT__API_PRESETS为一个列表,每项配置有以下的配置项
|
||||||
| 配置项 | 必填 | 默认值 | 说明 |
|
| 配置项 | 必填 | 默认值 | 说明 |
|
||||||
|
|
@ -157,6 +168,12 @@ _✨ 支持多API预设、MCP协议、内置工具、联网搜索、视觉模型
|
||||||
| proxy | 否 | 无 | 请求API时使用的HTTP代理 |
|
| proxy | 否 | 无 | 请求API时使用的HTTP代理 |
|
||||||
| support_mcp | 否 | False | 是否支持MCP协议 |
|
| support_mcp | 否 | False | 是否支持MCP协议 |
|
||||||
| support_image | 否 | False | 是否支持图片输入 |
|
| support_image | 否 | False | 是否支持图片输入 |
|
||||||
|
| support_to_image | 否 | False | 是否支持生成图片(作为子模型被调用时) |
|
||||||
|
| support_to_voice | 否 | False | 是否支持生成语音(作为子模型被调用时) |
|
||||||
|
| support_to_video | 否 | False | 是否支持生成视频(作为子模型被调用时) |
|
||||||
|
| call_model_list | 否 | None | 可调用的子模型名称列表,用于扩展主模型能力 |
|
||||||
|
|
||||||
|
### MCP服务器配置
|
||||||
|
|
||||||
|
|
||||||
LLMCHAT__MCP_SERVERS同样为一个dict,key为服务器名称,value配置的格式基本兼容 Claude.app 的配置格式,具体支持如下
|
LLMCHAT__MCP_SERVERS同样为一个dict,key为服务器名称,value配置的格式基本兼容 Claude.app 的配置格式,具体支持如下
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from openai import AsyncOpenAI
|
||||||
|
|
||||||
from .config import Config, PresetConfig
|
from .config import Config, PresetConfig
|
||||||
from .mcpclient import MCPClient
|
from .mcpclient import MCPClient
|
||||||
|
from .scheduler import SchedulerManager
|
||||||
|
|
||||||
require("nonebot_plugin_localstore")
|
require("nonebot_plugin_localstore")
|
||||||
import nonebot_plugin_localstore as store
|
import nonebot_plugin_localstore as store
|
||||||
|
|
@ -359,7 +360,7 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
logger.debug(f"从队列获取消息 用户:{context_id} 消息ID:{event.message_id}")
|
logger.debug(f"从队列获取消息 用户:{context_id} 消息ID:{event.message_id}")
|
||||||
group_id = None
|
group_id = None
|
||||||
past_events_snapshot = []
|
past_events_snapshot = []
|
||||||
mcp_client = MCPClient.get_instance(plugin_config.mcp_servers)
|
mcp_client = MCPClient.get_instance(plugin_config.mcp_servers, plugin_config)
|
||||||
try:
|
try:
|
||||||
# 构建系统提示,分成多行以满足行长限制
|
# 构建系统提示,分成多行以满足行长限制
|
||||||
chat_type = "群聊" if is_group else "私聊"
|
chat_type = "群聊" if is_group else "私聊"
|
||||||
|
|
@ -417,6 +418,9 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
|
|
||||||
content: list[ChatCompletionContentPartParam] = []
|
content: list[ChatCompletionContentPartParam] = []
|
||||||
|
|
||||||
|
# 收集用户消息中的图片(用于传递给子模型作为参考)
|
||||||
|
user_message_images: list[str] = []
|
||||||
|
|
||||||
# 将机器人错过的消息推送给LLM
|
# 将机器人错过的消息推送给LLM
|
||||||
past_events_snapshot = list(state.past_events)
|
past_events_snapshot = list(state.past_events)
|
||||||
state.past_events.clear()
|
state.past_events.clear()
|
||||||
|
|
@ -425,11 +429,19 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
content.append({"type": "text", "text": text_content})
|
content.append({"type": "text", "text": text_content})
|
||||||
|
|
||||||
# 将消息中的图片转成 base64
|
# 将消息中的图片转成 base64
|
||||||
|
base64_images = await process_images(ev)
|
||||||
|
|
||||||
|
# 收集图片用于子模型调用
|
||||||
|
user_message_images.extend(base64_images)
|
||||||
|
|
||||||
|
# 如果主模型支持图片输入,也传递给主模型
|
||||||
if preset.support_image:
|
if preset.support_image:
|
||||||
base64_images = await process_images(ev)
|
|
||||||
for base64_image in base64_images:
|
for base64_image in base64_images:
|
||||||
content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
|
content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
|
||||||
|
|
||||||
|
if user_message_images:
|
||||||
|
logger.info(f"用户消息中包含 {len(user_message_images)} 张图片,将用于子模型调用")
|
||||||
|
|
||||||
new_messages: list[ChatCompletionMessageParam] = [
|
new_messages: list[ChatCompletionMessageParam] = [
|
||||||
{"role": "user", "content": content}
|
{"role": "user", "content": content}
|
||||||
]
|
]
|
||||||
|
|
@ -446,9 +458,14 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
}
|
}
|
||||||
|
|
||||||
if preset.support_mcp:
|
if preset.support_mcp:
|
||||||
available_tools = await mcp_client.get_available_tools(is_group)
|
available_tools = await mcp_client.get_available_tools(is_group, preset)
|
||||||
client_config["tools"] = available_tools
|
client_config["tools"] = available_tools
|
||||||
|
|
||||||
|
# 用于存储子模型生成的多媒体内容
|
||||||
|
submodel_images: list[str] = []
|
||||||
|
submodel_voices: list[str] = []
|
||||||
|
submodel_videos: list[str] = []
|
||||||
|
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create(
|
||||||
**client_config,
|
**client_config,
|
||||||
messages=messages + new_messages,
|
messages=messages + new_messages,
|
||||||
|
|
@ -478,20 +495,58 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
# 发送工具调用提示
|
# 发送工具调用提示
|
||||||
await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}"))
|
await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}"))
|
||||||
|
|
||||||
|
# 对于子模型调用,传递用户消息中的图片作为参考
|
||||||
|
images_for_submodel = user_message_images if tool_name.startswith("submodel__") else None
|
||||||
|
|
||||||
if is_group:
|
if is_group:
|
||||||
result = await mcp_client.call_tool(
|
result = await mcp_client.call_tool(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_args,
|
tool_args,
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
bot_id=str(event.self_id)
|
bot_id=str(event.self_id),
|
||||||
|
user_id=event.user_id,
|
||||||
|
is_group=True,
|
||||||
|
current_preset=preset,
|
||||||
|
user_images=images_for_submodel
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await mcp_client.call_tool(
|
result = await mcp_client.call_tool(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_args,
|
tool_args,
|
||||||
bot_id=str(event.self_id)
|
bot_id=str(event.self_id),
|
||||||
|
user_id=event.user_id,
|
||||||
|
is_group=False,
|
||||||
|
current_preset=preset,
|
||||||
|
user_images=images_for_submodel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 处理子模型返回的结构化结果
|
||||||
|
if isinstance(result, dict) and tool_name.startswith("submodel__"):
|
||||||
|
if result.get("success"):
|
||||||
|
# 收集多媒体内容
|
||||||
|
if result.get("images"):
|
||||||
|
submodel_images.extend(result["images"])
|
||||||
|
logger.info(f"子模型生成了 {len(result['images'])} 张图片")
|
||||||
|
if result.get("audio"):
|
||||||
|
submodel_voices.append(result["audio"])
|
||||||
|
logger.info("子模型生成了语音")
|
||||||
|
if result.get("video"):
|
||||||
|
submodel_videos.append(result["video"])
|
||||||
|
logger.info("子模型生成了视频")
|
||||||
|
# 构建给主模型的结果消息
|
||||||
|
result_msg = f"成功使用模型 {result.get('model_used', '未知')} 生成内容。"
|
||||||
|
if result.get("content"):
|
||||||
|
result_msg += f"\n子模型回复:{result['content']}"
|
||||||
|
if result.get("images"):
|
||||||
|
result_msg += f"\n已生成 {len(result['images'])} 张图片,将在你回复后发送给用户。"
|
||||||
|
if result.get("audio"):
|
||||||
|
result_msg += "\n已生成语音,将在你回复后发送给用户。"
|
||||||
|
if result.get("video"):
|
||||||
|
result_msg += "\n已生成视频,将在你回复后发送给用户。"
|
||||||
|
result = result_msg
|
||||||
|
else:
|
||||||
|
result = f"生成失败:{result.get('error', '未知错误')}"
|
||||||
|
|
||||||
new_messages.append({
|
new_messages.append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
|
|
@ -552,6 +607,7 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
assert reply is not None
|
assert reply is not None
|
||||||
await send_split_messages(handler, reply)
|
await send_split_messages(handler, reply)
|
||||||
|
|
||||||
|
# 发送主模型直接生成的图片
|
||||||
if reply_images:
|
if reply_images:
|
||||||
logger.debug(f"API响应 图片数:{len(reply_images)}")
|
logger.debug(f"API响应 图片数:{len(reply_images)}")
|
||||||
for i, image in enumerate(reply_images, start=1):
|
for i, image in enumerate(reply_images, start=1):
|
||||||
|
|
@ -560,6 +616,50 @@ async def process_messages(context_id: int, is_group: bool = True):
|
||||||
image_msg = MessageSegment.image(base64.b64decode(image_base64))
|
image_msg = MessageSegment.image(base64.b64decode(image_base64))
|
||||||
await handler.send(image_msg)
|
await handler.send(image_msg)
|
||||||
|
|
||||||
|
# 发送子模型生成的图片
|
||||||
|
if submodel_images:
|
||||||
|
logger.info(f"发送子模型生成的 {len(submodel_images)} 张图片")
|
||||||
|
for i, img_base64 in enumerate(submodel_images, start=1):
|
||||||
|
try:
|
||||||
|
logger.debug(f"正在发送子模型图片 {i}/{len(submodel_images)}")
|
||||||
|
# 处理可能的 data URL 前缀
|
||||||
|
if img_base64.startswith("data:"):
|
||||||
|
img_base64 = img_base64.split(",", 1)[-1] if "," in img_base64 else img_base64
|
||||||
|
image_msg = MessageSegment.image(base64.b64decode(img_base64))
|
||||||
|
await handler.send(image_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送子模型图片失败: {e}")
|
||||||
|
|
||||||
|
# 发送子模型生成的语音
|
||||||
|
if submodel_voices:
|
||||||
|
logger.info(f"发送子模型生成的 {len(submodel_voices)} 条语音")
|
||||||
|
for i, voice_data in enumerate(submodel_voices, start=1):
|
||||||
|
try:
|
||||||
|
logger.debug(f"正在发送子模型语音 {i}/{len(submodel_voices)}")
|
||||||
|
if voice_data.startswith("data:"):
|
||||||
|
voice_data = voice_data.split(",", 1)[-1] if "," in voice_data else voice_data
|
||||||
|
voice_msg = MessageSegment.record(base64.b64decode(voice_data))
|
||||||
|
await handler.send(voice_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送子模型语音失败: {e}")
|
||||||
|
|
||||||
|
# 发送子模型生成的视频
|
||||||
|
if submodel_videos:
|
||||||
|
logger.info(f"发送子模型生成的 {len(submodel_videos)} 个视频")
|
||||||
|
for i, video_data in enumerate(submodel_videos, start=1):
|
||||||
|
try:
|
||||||
|
logger.debug(f"正在发送子模型视频 {i}/{len(submodel_videos)}")
|
||||||
|
# 视频可能是 URL 或 base64
|
||||||
|
if video_data.startswith("http"):
|
||||||
|
video_msg = MessageSegment.video(video_data)
|
||||||
|
else:
|
||||||
|
if video_data.startswith("data:"):
|
||||||
|
video_data = video_data.split(",", 1)[-1] if "," in video_data else video_data
|
||||||
|
video_msg = MessageSegment.video(base64.b64decode(video_data))
|
||||||
|
await handler.send(video_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送子模型视频失败: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.opt(exception=e).error(f"API请求失败 {'群号' if is_group else '用户'}:{context_id}")
|
logger.opt(exception=e).error(f"API请求失败 {'群号' if is_group else '用户'}:{context_id}")
|
||||||
# 如果在处理过程中出现异常,恢复未处理的消息到state中
|
# 如果在处理过程中出现异常,恢复未处理的消息到state中
|
||||||
|
|
@ -856,6 +956,9 @@ async def load_state():
|
||||||
async def init_plugin():
|
async def init_plugin():
|
||||||
logger.info("插件启动初始化")
|
logger.info("插件启动初始化")
|
||||||
await load_state()
|
await load_state()
|
||||||
|
# 加载定时任务
|
||||||
|
scheduler_manager = SchedulerManager.get_instance()
|
||||||
|
await scheduler_manager.load_tasks()
|
||||||
# 每5分钟保存状态
|
# 每5分钟保存状态
|
||||||
scheduler.add_job(save_state, "interval", minutes=5)
|
scheduler.add_job(save_state, "interval", minutes=5)
|
||||||
|
|
||||||
|
|
@ -864,5 +967,8 @@ async def init_plugin():
|
||||||
async def cleanup_plugin():
|
async def cleanup_plugin():
|
||||||
logger.info("插件关闭清理")
|
logger.info("插件关闭清理")
|
||||||
await save_state()
|
await save_state()
|
||||||
|
# 保存定时任务
|
||||||
|
scheduler_manager = SchedulerManager.get_instance()
|
||||||
|
await scheduler_manager.save_tasks()
|
||||||
# 销毁MCPClient单例
|
# 销毁MCPClient单例
|
||||||
await MCPClient.destroy_instance()
|
await MCPClient.destroy_instance()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,6 +16,14 @@ class PresetConfig(BaseModel):
|
||||||
support_mcp: bool = Field(False, description="是否支持MCP")
|
support_mcp: bool = Field(False, description="是否支持MCP")
|
||||||
support_image: bool = Field(False, description="是否支持图片输入")
|
support_image: bool = Field(False, description="是否支持图片输入")
|
||||||
|
|
||||||
|
# 子模型能力标记
|
||||||
|
support_to_image: bool = Field(False, description="是否支持生成图片")
|
||||||
|
support_to_voice: bool = Field(False, description="是否支持生成语音")
|
||||||
|
support_to_video: bool = Field(False, description="是否支持生成视频")
|
||||||
|
|
||||||
|
# 可调用的子模型列表
|
||||||
|
call_model_list: list[str] | None = Field(None, description="可调用的子模型名称列表")
|
||||||
|
|
||||||
class MCPServerConfig(BaseModel):
|
class MCPServerConfig(BaseModel):
|
||||||
"""MCP服务器配置"""
|
"""MCP服务器配置"""
|
||||||
command: str | None = Field(None, description="stdio模式下MCP命令")
|
command: str | None = Field(None, description="stdio模式下MCP命令")
|
||||||
|
|
@ -21,6 +31,8 @@ class MCPServerConfig(BaseModel):
|
||||||
env: dict[str, str] | None = Field({}, description="stdio模式下MCP命令环境变量")
|
env: dict[str, str] | None = Field({}, description="stdio模式下MCP命令环境变量")
|
||||||
url: str | None = Field(None, description="sse模式下MCP服务器地址")
|
url: str | None = Field(None, description="sse模式下MCP服务器地址")
|
||||||
headers: dict[str, str] | None = Field({}, description="sse模式下http请求头,用于认证或其他设置")
|
headers: dict[str, str] | None = Field({}, description="sse模式下http请求头,用于认证或其他设置")
|
||||||
|
transport_type: str | None = Field(None, description="请求类型 sse、stdio 或 streamablehttp")
|
||||||
|
|
||||||
|
|
||||||
# 额外字段
|
# 额外字段
|
||||||
friendly_name: str | None = Field(None, description="MCP服务器友好名称")
|
friendly_name: str | None = Field(None, description="MCP服务器友好名称")
|
||||||
|
|
@ -51,7 +63,18 @@ class ScopedConfig(BaseModel):
|
||||||
)
|
)
|
||||||
enable_private_chat: bool = Field(False, description="是否启用私聊功能")
|
enable_private_chat: bool = Field(False, description="是否启用私聊功能")
|
||||||
private_chat_preset: str = Field("off", description="私聊默认使用的预设名称")
|
private_chat_preset: str = Field("off", description="私聊默认使用的预设名称")
|
||||||
|
scheduler_max_retry: int = Field(5, description="定时任务AI调用最大重试次数")
|
||||||
|
scheduler_default_reminder: str = Field(
|
||||||
|
"您设置的提醒时间到了:{description}",
|
||||||
|
description="AI调用失败时的默认提醒模板"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
llmchat: ScopedConfig
|
llmchat: ScopedConfig
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class transportType:
|
||||||
|
sse = "sse"
|
||||||
|
stdio = "stdio"
|
||||||
|
streamablehttp = "streamablehttp"
|
||||||
|
|
@ -4,22 +4,28 @@ from contextlib import AsyncExitStack
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
|
try:
|
||||||
|
from mcp.client.streamable_http import streamable_http_client as streamablehttp_client
|
||||||
|
except:
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from nonebot import logger
|
from nonebot import logger
|
||||||
|
|
||||||
from .config import MCPServerConfig
|
from .config import MCPServerConfig, PresetConfig, ScopedConfig, transportType
|
||||||
from .onebottools import OneBotTools
|
from .onebottools import OneBotTools
|
||||||
|
from .scheduler import SchedulerManager
|
||||||
|
from .submodel_caller import SubModelCaller
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
|
||||||
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
|
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, server_config: dict[str, MCPServerConfig] | None = None):
|
def __init__(self, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -28,6 +34,7 @@ class MCPClient:
|
||||||
|
|
||||||
logger.info(f"正在初始化MCPClient单例,共有{len(server_config)}个服务器配置")
|
logger.info(f"正在初始化MCPClient单例,共有{len(server_config)}个服务器配置")
|
||||||
self.server_config = server_config
|
self.server_config = server_config
|
||||||
|
self.plugin_config = plugin_config
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
# 添加工具列表缓存
|
# 添加工具列表缓存
|
||||||
|
|
@ -35,16 +42,20 @@ class MCPClient:
|
||||||
self._cache_initialized = False
|
self._cache_initialized = False
|
||||||
# 初始化OneBot工具
|
# 初始化OneBot工具
|
||||||
self.onebot_tools = OneBotTools()
|
self.onebot_tools = OneBotTools()
|
||||||
|
# 初始化定时任务管理器
|
||||||
|
self.scheduler_manager = SchedulerManager.get_instance()
|
||||||
|
# 初始化子模型调用器(如果有 plugin_config)
|
||||||
|
self.submodel_caller = SubModelCaller.get_instance(plugin_config) if plugin_config else None
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.debug("MCPClient单例初始化成功")
|
logger.debug("MCPClient单例初始化成功")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None):
|
def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||||||
"""获取MCPClient实例"""
|
"""获取MCPClient实例"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
if server_config is None:
|
if server_config is None:
|
||||||
raise ValueError("server_config must be provided for first initialization")
|
raise ValueError("server_config must be provided for first initialization")
|
||||||
cls._instance = cls(server_config)
|
cls._instance = cls(server_config, plugin_config)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -85,18 +96,33 @@ class MCPClient:
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
if config.url:
|
if config.transport_type is None:
|
||||||
transport = await self.exit_stack.enter_async_context(
|
if config.url:
|
||||||
sse_client(url=config.url, headers=config.headers)
|
config.transport_type = transportType.sse
|
||||||
)
|
elif config.command:
|
||||||
elif config.command:
|
config.transport_type = transportType.stdio
|
||||||
transport = await self.exit_stack.enter_async_context(
|
else:
|
||||||
stdio_client(StdioServerParameters(**config.model_dump()))
|
raise ValueError("Server config must have either url or command")
|
||||||
)
|
|
||||||
else:
|
match config.transport_type:
|
||||||
raise ValueError("Server config must have either url or command")
|
case transportType.sse:
|
||||||
|
transport = await self.exit_stack.enter_async_context(
|
||||||
|
sse_client(url=config.url, headers=config.headers)
|
||||||
|
)
|
||||||
|
read, write = transport
|
||||||
|
case transportType.stdio:
|
||||||
|
transport = await self.exit_stack.enter_async_context(
|
||||||
|
stdio_client(StdioServerParameters(**config.model_dump()))
|
||||||
|
)
|
||||||
|
read, write = transport
|
||||||
|
case transportType.streamablehttp:
|
||||||
|
transport = await self.exit_stack.enter_async_context(
|
||||||
|
streamablehttp_client(url=config.url, headers=config.headers)
|
||||||
|
)
|
||||||
|
read, write, session_callback = transport
|
||||||
|
case _:
|
||||||
|
raise ValueError("Server config must have either url or command")
|
||||||
|
|
||||||
read, write = transport
|
|
||||||
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||||
await self.session.initialize()
|
await self.session.initialize()
|
||||||
return self.session
|
return self.session
|
||||||
|
|
@ -138,18 +164,52 @@ class MCPClient:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_available_tools(self, is_group: bool):
|
async def get_available_tools(self, is_group: bool, current_preset: PresetConfig | None = None):
|
||||||
"""获取可用工具列表,使用缓存机制"""
|
"""获取可用工具列表,使用缓存机制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_group: 是否群聊场景
|
||||||
|
current_preset: 当前使用的预设配置(用于获取子模型工具)
|
||||||
|
"""
|
||||||
await self.init_tools_cache()
|
await self.init_tools_cache()
|
||||||
available_tools = self._tools_cache.copy() if self._tools_cache else []
|
available_tools = self._tools_cache.copy() if self._tools_cache else []
|
||||||
if is_group:
|
if is_group:
|
||||||
# 群聊场景,包含OneBot工具和MCP工具
|
# 群聊场景,包含OneBot工具和MCP工具
|
||||||
available_tools.extend(self.onebot_tools.get_available_tools())
|
available_tools.extend(self.onebot_tools.get_available_tools())
|
||||||
|
# 添加定时任务工具(群聊和私聊都可用)
|
||||||
|
available_tools.extend(self.scheduler_manager.get_available_tools())
|
||||||
|
# 添加子模型调用工具(根据当前预设的 call_model_list 动态生成)
|
||||||
|
if self.submodel_caller and current_preset:
|
||||||
|
submodel_tools = self.submodel_caller.get_available_tools(current_preset)
|
||||||
|
available_tools.extend(submodel_tools)
|
||||||
|
if submodel_tools:
|
||||||
|
logger.debug(f"添加了 {len(submodel_tools)} 个子模型调用工具")
|
||||||
logger.debug(f"获取可用工具列表,共{len(available_tools)}个工具")
|
logger.debug(f"获取可用工具列表,共{len(available_tools)}个工具")
|
||||||
return available_tools
|
return available_tools
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
|
async def call_tool(
|
||||||
"""按需连接调用工具,调用后立即断开"""
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict,
|
||||||
|
group_id: int | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
user_id: int | None = None,
|
||||||
|
is_group: bool = True,
|
||||||
|
current_preset: PresetConfig | None = None,
|
||||||
|
user_images: list[str] | None = None
|
||||||
|
):
|
||||||
|
"""按需连接调用工具,调用后立即断开
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_args: 工具参数
|
||||||
|
group_id: 群号(群聊时必需)
|
||||||
|
bot_id: 机器人ID
|
||||||
|
user_id: 用户ID
|
||||||
|
is_group: 是否群聊
|
||||||
|
current_preset: 当前使用的预设配置(子模型调用时必需)
|
||||||
|
user_images: 用户消息中的图片列表(base64 编码),用于子模型参考
|
||||||
|
"""
|
||||||
# 检查是否是OneBot内置工具
|
# 检查是否是OneBot内置工具
|
||||||
if tool_name.startswith("ob__"):
|
if tool_name.startswith("ob__"):
|
||||||
if group_id is None or bot_id is None:
|
if group_id is None or bot_id is None:
|
||||||
|
|
@ -157,6 +217,29 @@ class MCPClient:
|
||||||
logger.info(f"调用OneBot工具[{tool_name}]")
|
logger.info(f"调用OneBot工具[{tool_name}]")
|
||||||
return await self.onebot_tools.call_tool(tool_name, tool_args, group_id, bot_id)
|
return await self.onebot_tools.call_tool(tool_name, tool_args, group_id, bot_id)
|
||||||
|
|
||||||
|
# 检查是否是定时任务工具
|
||||||
|
if tool_name.startswith("scheduler__"):
|
||||||
|
context_id = group_id if is_group else user_id
|
||||||
|
if context_id is None or user_id is None:
|
||||||
|
return "定时任务工具需要提供context_id和user_id参数"
|
||||||
|
logger.info(f"调用定时任务工具[{tool_name}]")
|
||||||
|
return await self.scheduler_manager.call_tool(
|
||||||
|
tool_name, tool_args, context_id, is_group, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否是子模型调用工具
|
||||||
|
if tool_name.startswith("submodel__"):
|
||||||
|
if not self.submodel_caller:
|
||||||
|
return "子模型调用器未初始化"
|
||||||
|
if not current_preset:
|
||||||
|
return "子模型调用需要提供 current_preset 参数"
|
||||||
|
logger.info(f"调用子模型工具[{tool_name}],参考图片数: {len(user_images) if user_images else 0}")
|
||||||
|
result = await self.submodel_caller.call_tool(
|
||||||
|
tool_name, tool_args, current_preset, reference_images=user_images
|
||||||
|
)
|
||||||
|
# 返回结构化结果,让上层处理
|
||||||
|
return result
|
||||||
|
|
||||||
# 检查是否是MCP工具
|
# 检查是否是MCP工具
|
||||||
if tool_name.startswith("mcp__"):
|
if tool_name.startswith("mcp__"):
|
||||||
# MCP工具处理:mcp__server_name__tool_name
|
# MCP工具处理:mcp__server_name__tool_name
|
||||||
|
|
@ -186,6 +269,16 @@ class MCPClient:
|
||||||
if tool_name.startswith("ob__"):
|
if tool_name.startswith("ob__"):
|
||||||
return self.onebot_tools.get_friendly_name(tool_name)
|
return self.onebot_tools.get_friendly_name(tool_name)
|
||||||
|
|
||||||
|
# 检查是否是定时任务工具
|
||||||
|
if tool_name.startswith("scheduler__"):
|
||||||
|
return self.scheduler_manager.get_friendly_name(tool_name)
|
||||||
|
|
||||||
|
# 检查是否是子模型调用工具
|
||||||
|
if tool_name.startswith("submodel__"):
|
||||||
|
if self.submodel_caller:
|
||||||
|
return self.submodel_caller.get_friendly_name(tool_name)
|
||||||
|
return tool_name
|
||||||
|
|
||||||
# 检查是否是MCP工具
|
# 检查是否是MCP工具
|
||||||
if tool_name.startswith("mcp__"):
|
if tool_name.startswith("mcp__"):
|
||||||
# MCP工具处理:mcp__server_name__tool_name
|
# MCP工具处理:mcp__server_name__tool_name
|
||||||
|
|
|
||||||
785
nonebot_plugin_llmchat/scheduler.py
Normal file
785
nonebot_plugin_llmchat/scheduler.py
Normal file
|
|
@ -0,0 +1,785 @@
|
||||||
|
"""定时任务管理模块"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import httpx
|
||||||
|
from nonebot import get_bot, get_driver, logger, require
|
||||||
|
from nonebot.adapters.onebot.v11 import Bot, Message
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
require("nonebot_plugin_localstore")
|
||||||
|
import nonebot_plugin_localstore as store
|
||||||
|
|
||||||
|
require("nonebot_plugin_apscheduler")
|
||||||
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleType(str, Enum):
|
||||||
|
"""定时任务类型"""
|
||||||
|
INTERVAL_MINUTES = "interval_minutes" # 每N分钟
|
||||||
|
DAILY = "daily" # 每天指定时间
|
||||||
|
WEEKLY = "weekly" # 每周指定天
|
||||||
|
YEARLY = "yearly" # 每年指定日期
|
||||||
|
ONCE = "once" # 一次性任务
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTask(BaseModel):
|
||||||
|
"""定时任务模型"""
|
||||||
|
task_id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
|
||||||
|
context_id: int # 群号或用户ID
|
||||||
|
is_group: bool # 是否群聊
|
||||||
|
schedule_type: ScheduleType # 任务类型
|
||||||
|
description: str # 任务描述(用于AI生成提醒)
|
||||||
|
creator_id: int # 创建者用户ID
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
# 调度参数
|
||||||
|
interval_minutes: int | None = None # 间隔分钟数
|
||||||
|
hour: int | None = None # 小时 (0-23)
|
||||||
|
minute: int | None = None # 分钟 (0-59)
|
||||||
|
day_of_week: int | None = None # 周几 (0-6, 0=周一)
|
||||||
|
month: int | None = None # 月份 (1-12)
|
||||||
|
day: int | None = None # 日期 (1-31)
|
||||||
|
|
||||||
|
# 一次性任务
|
||||||
|
trigger_time: datetime | None = None # 触发时间
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerTools:
|
||||||
|
"""定时任务工具定义"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "scheduler__create_task",
|
||||||
|
"description": """创建一个定时提醒任务。支持以下类型:
|
||||||
|
- interval_minutes: 每隔N分钟提醒,需提供 interval_minutes (1-10080)
|
||||||
|
- daily: 每天指定时间提醒,需提供 hour (0-23) 和 minute (0-59)
|
||||||
|
- weekly: 每周指定天提醒,需提供 hour, minute, day_of_week (0=周一, 1=周二...6=周日)
|
||||||
|
- yearly: 每年指定日期提醒,需提供 month (1-12), day (1-31), hour, minute
|
||||||
|
- once: 一次性提醒,需提供 minutes_later 表示几分钟后触发 (1-525600)
|
||||||
|
|
||||||
|
重要提示:
|
||||||
|
- description 字段非常重要,它将在任务触发时用于生成提醒信息
|
||||||
|
- 如果任务需要获取实时信息(如天气、新闻等),请在描述中明确说明,触发时AI会调用相应工具获取最新数据
|
||||||
|
- 如果用户没有提供完整信息(如查天气但没说城市),可以先创建任务,然后询问用户,得到答案后用 scheduler__update_task 更新描述""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"schedule_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "任务类型",
|
||||||
|
"enum": ["interval_minutes", "daily", "weekly", "yearly", "once"]
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "任务描述,将用于生成提醒信息"
|
||||||
|
},
|
||||||
|
"interval_minutes": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "间隔分钟数,仅interval_minutes类型需要",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 10080
|
||||||
|
},
|
||||||
|
"hour": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "小时 (0-23)",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 23
|
||||||
|
},
|
||||||
|
"minute": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "分钟 (0-59)",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 59
|
||||||
|
},
|
||||||
|
"day_of_week": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "周几 (0=周一, 1=周二...6=周日)",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 6
|
||||||
|
},
|
||||||
|
"month": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "月份 (1-12)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 12
|
||||||
|
},
|
||||||
|
"day": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "日期 (1-31)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 31
|
||||||
|
},
|
||||||
|
"minutes_later": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "几分钟后触发,仅once类型需要",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 525600
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["schedule_type", "description"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "scheduler__list_tasks",
|
||||||
|
"description": "列出当前聊天的所有定时任务",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "scheduler__delete_task",
|
||||||
|
"description": "删除指定的定时任务",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"task_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要删除的任务ID"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["task_id"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "scheduler__update_task",
|
||||||
|
"description": """更新定时任务的描述。重要提示:
|
||||||
|
- 当用户在后续对话中补充了与已创建定时任务相关的信息时(如地点、具体要求、人名等),你应该主动调用此工具更新任务描述
|
||||||
|
- 例如:用户创建了"提醒我明天天气"的任务,后来告诉你他在"北京",你应该更新描述为"提醒我北京明天的天气"
|
||||||
|
- 任务描述应包含执行任务所需的所有关键信息,因为触发时AI会根据描述来生成提醒或执行操作
|
||||||
|
- 如果不确定最近创建的任务ID,可以先调用 scheduler__list_tasks 查看""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"task_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要更新的任务ID"
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "新的任务描述,应包含执行任务所需的所有关键信息"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["task_id", "description"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_available_tools(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取可用的工具列表"""
|
||||||
|
return self.tools
|
||||||
|
|
||||||
|
def get_friendly_name(self, tool_name: str) -> str:
|
||||||
|
"""获取工具的友好名称"""
|
||||||
|
friendly_names = {
|
||||||
|
"scheduler__create_task": "定时任务 - 创建任务",
|
||||||
|
"scheduler__list_tasks": "定时任务 - 列出任务",
|
||||||
|
"scheduler__delete_task": "定时任务 - 删除任务",
|
||||||
|
"scheduler__update_task": "定时任务 - 更新任务",
|
||||||
|
}
|
||||||
|
return friendly_names.get(tool_name, tool_name)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerManager:
|
||||||
|
"""定时任务管理器"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.tasks: dict[str, ScheduledTask] = {}
|
||||||
|
self.tools = SchedulerTools()
|
||||||
|
self.data_file = store.get_plugin_data_file("llmchat_scheduler_tasks.json")
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("SchedulerManager 初始化完成")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "SchedulerManager":
|
||||||
|
"""获取单例实例"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
async def load_tasks(self):
|
||||||
|
"""从文件加载任务"""
|
||||||
|
logger.info(f"从文件加载定时任务: {self.data_file}")
|
||||||
|
if not os.path.exists(self.data_file):
|
||||||
|
logger.debug("定时任务文件不存在,跳过加载")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(self.data_file, encoding="utf8") as f:
|
||||||
|
data = json.loads(await f.read())
|
||||||
|
for task_id, task_data in data.items():
|
||||||
|
# 转换字符串为datetime
|
||||||
|
if task_data.get("created_at"):
|
||||||
|
task_data["created_at"] = datetime.fromisoformat(task_data["created_at"])
|
||||||
|
if task_data.get("trigger_time"):
|
||||||
|
task_data["trigger_time"] = datetime.fromisoformat(task_data["trigger_time"])
|
||||||
|
self.tasks[task_id] = ScheduledTask(**task_data)
|
||||||
|
|
||||||
|
logger.info(f"成功加载 {len(self.tasks)} 个定时任务")
|
||||||
|
# 注册所有任务到APScheduler
|
||||||
|
await self.register_all_jobs()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载定时任务失败: {e}")
|
||||||
|
|
||||||
|
async def save_tasks(self):
|
||||||
|
"""保存任务到文件"""
|
||||||
|
logger.info(f"保存定时任务到文件: {self.data_file}")
|
||||||
|
try:
|
||||||
|
data = {}
|
||||||
|
for task_id, task in self.tasks.items():
|
||||||
|
task_dict = task.model_dump()
|
||||||
|
# 转换datetime为字符串
|
||||||
|
if task_dict.get("created_at"):
|
||||||
|
task_dict["created_at"] = task_dict["created_at"].isoformat()
|
||||||
|
if task_dict.get("trigger_time"):
|
||||||
|
task_dict["trigger_time"] = task_dict["trigger_time"].isoformat()
|
||||||
|
data[task_id] = task_dict
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
|
||||||
|
async with aiofiles.open(self.data_file, "w", encoding="utf8") as f:
|
||||||
|
await f.write(json.dumps(data, ensure_ascii=False, indent=2))
|
||||||
|
logger.info(f"成功保存 {len(self.tasks)} 个定时任务")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存定时任务失败: {e}")
|
||||||
|
|
||||||
|
def _validate_task_params(self, schedule_type: ScheduleType, **kwargs) -> str | None:
|
||||||
|
"""校验任务参数,返回错误信息或None"""
|
||||||
|
if schedule_type == ScheduleType.INTERVAL_MINUTES:
|
||||||
|
interval = kwargs.get("interval_minutes")
|
||||||
|
if interval is None:
|
||||||
|
return "interval_minutes类型需要提供 interval_minutes 参数"
|
||||||
|
if not 1 <= interval <= 10080:
|
||||||
|
return "interval_minutes 必须在 1-10080 之间"
|
||||||
|
|
||||||
|
elif schedule_type == ScheduleType.DAILY:
|
||||||
|
hour = kwargs.get("hour")
|
||||||
|
minute = kwargs.get("minute")
|
||||||
|
if hour is None or minute is None:
|
||||||
|
return "daily类型需要提供 hour 和 minute 参数"
|
||||||
|
if not 0 <= hour <= 23:
|
||||||
|
return "hour 必须在 0-23 之间"
|
||||||
|
if not 0 <= minute <= 59:
|
||||||
|
return "minute 必须在 0-59 之间"
|
||||||
|
|
||||||
|
elif schedule_type == ScheduleType.WEEKLY:
|
||||||
|
hour = kwargs.get("hour")
|
||||||
|
minute = kwargs.get("minute")
|
||||||
|
day_of_week = kwargs.get("day_of_week")
|
||||||
|
if hour is None or minute is None or day_of_week is None:
|
||||||
|
return "weekly类型需要提供 hour, minute 和 day_of_week 参数"
|
||||||
|
if not 0 <= hour <= 23:
|
||||||
|
return "hour 必须在 0-23 之间"
|
||||||
|
if not 0 <= minute <= 59:
|
||||||
|
return "minute 必须在 0-59 之间"
|
||||||
|
if not 0 <= day_of_week <= 6:
|
||||||
|
return "day_of_week 必须在 0-6 之间 (0=周一)"
|
||||||
|
|
||||||
|
elif schedule_type == ScheduleType.YEARLY:
|
||||||
|
hour = kwargs.get("hour")
|
||||||
|
minute = kwargs.get("minute")
|
||||||
|
month = kwargs.get("month")
|
||||||
|
day = kwargs.get("day")
|
||||||
|
if hour is None or minute is None or month is None or day is None:
|
||||||
|
return "yearly类型需要提供 hour, minute, month 和 day 参数"
|
||||||
|
if not 0 <= hour <= 23:
|
||||||
|
return "hour 必须在 0-23 之间"
|
||||||
|
if not 0 <= minute <= 59:
|
||||||
|
return "minute 必须在 0-59 之间"
|
||||||
|
if not 1 <= month <= 12:
|
||||||
|
return "month 必须在 1-12 之间"
|
||||||
|
if not 1 <= day <= 31:
|
||||||
|
return "day 必须在 1-31 之间"
|
||||||
|
|
||||||
|
elif schedule_type == ScheduleType.ONCE:
|
||||||
|
minutes_later = kwargs.get("minutes_later")
|
||||||
|
if minutes_later is None:
|
||||||
|
return "once类型需要提供 minutes_later 参数"
|
||||||
|
if not 1 <= minutes_later <= 525600:
|
||||||
|
return "minutes_later 必须在 1-525600 之间"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_task(
|
||||||
|
self,
|
||||||
|
context_id: int,
|
||||||
|
is_group: bool,
|
||||||
|
creator_id: int,
|
||||||
|
schedule_type: str,
|
||||||
|
description: str,
|
||||||
|
**kwargs
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""创建定时任务"""
|
||||||
|
try:
|
||||||
|
stype = ScheduleType(schedule_type)
|
||||||
|
except ValueError:
|
||||||
|
return False, f"无效的任务类型: {schedule_type}"
|
||||||
|
|
||||||
|
# 参数校验
|
||||||
|
error = self._validate_task_params(stype, **kwargs)
|
||||||
|
if error:
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
# 计算一次性任务的触发时间
|
||||||
|
trigger_time = None
|
||||||
|
if stype == ScheduleType.ONCE:
|
||||||
|
minutes_later = kwargs.get("minutes_later", 0)
|
||||||
|
trigger_time = datetime.now() + timedelta(minutes=minutes_later)
|
||||||
|
|
||||||
|
# 创建任务
|
||||||
|
task = ScheduledTask(
|
||||||
|
context_id=context_id,
|
||||||
|
is_group=is_group,
|
||||||
|
schedule_type=stype,
|
||||||
|
description=description,
|
||||||
|
creator_id=creator_id,
|
||||||
|
interval_minutes=kwargs.get("interval_minutes"),
|
||||||
|
hour=kwargs.get("hour"),
|
||||||
|
minute=kwargs.get("minute"),
|
||||||
|
day_of_week=kwargs.get("day_of_week"),
|
||||||
|
month=kwargs.get("month"),
|
||||||
|
day=kwargs.get("day"),
|
||||||
|
trigger_time=trigger_time
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tasks[task.task_id] = task
|
||||||
|
|
||||||
|
# 注册到APScheduler
|
||||||
|
self._register_job(task)
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
await self.save_tasks()
|
||||||
|
|
||||||
|
logger.info(f"创建定时任务成功: {task.task_id} - {description}")
|
||||||
|
return True, f"创建成功!任务ID: {task.task_id}"
|
||||||
|
|
||||||
|
async def delete_task(self, task_id: str, context_id: int, is_group: bool) -> tuple[bool, str]:
|
||||||
|
"""删除定时任务"""
|
||||||
|
if task_id not in self.tasks:
|
||||||
|
return False, f"任务不存在: {task_id}"
|
||||||
|
|
||||||
|
task = self.tasks[task_id]
|
||||||
|
|
||||||
|
# 检查权限:只能删除同一聊天的任务
|
||||||
|
if task.context_id != context_id or task.is_group != is_group:
|
||||||
|
return False, "无法删除其他聊天的任务"
|
||||||
|
|
||||||
|
# 从APScheduler移除
|
||||||
|
job_id = f"scheduler_{task_id}"
|
||||||
|
if scheduler.get_job(job_id):
|
||||||
|
scheduler.remove_job(job_id)
|
||||||
|
|
||||||
|
# 删除任务
|
||||||
|
del self.tasks[task_id]
|
||||||
|
await self.save_tasks()
|
||||||
|
|
||||||
|
logger.info(f"删除定时任务: {task_id}")
|
||||||
|
return True, f"任务 {task_id} 已删除"
|
||||||
|
|
||||||
|
async def update_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
context_id: int,
|
||||||
|
is_group: bool,
|
||||||
|
description: str
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""更新定时任务"""
|
||||||
|
if task_id not in self.tasks:
|
||||||
|
return False, f"任务不存在: {task_id}"
|
||||||
|
|
||||||
|
task = self.tasks[task_id]
|
||||||
|
|
||||||
|
# 检查权限
|
||||||
|
if task.context_id != context_id or task.is_group != is_group:
|
||||||
|
return False, "无法更新其他聊天的任务"
|
||||||
|
|
||||||
|
task.description = description
|
||||||
|
await self.save_tasks()
|
||||||
|
|
||||||
|
logger.info(f"更新定时任务: {task_id}")
|
||||||
|
return True, f"任务 {task_id} 已更新"
|
||||||
|
|
||||||
|
def list_tasks(self, context_id: int, is_group: bool) -> list[dict]:
|
||||||
|
"""列出指定聊天的所有任务"""
|
||||||
|
result = []
|
||||||
|
for task in self.tasks.values():
|
||||||
|
if task.context_id == context_id and task.is_group == is_group:
|
||||||
|
task_info = {
|
||||||
|
"task_id": task.task_id,
|
||||||
|
"description": task.description,
|
||||||
|
"schedule_type": task.schedule_type.value,
|
||||||
|
"created_at": task.created_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加具体时间信息
|
||||||
|
if task.schedule_type == ScheduleType.INTERVAL_MINUTES:
|
||||||
|
task_info["schedule"] = f"每 {task.interval_minutes} 分钟"
|
||||||
|
elif task.schedule_type == ScheduleType.DAILY:
|
||||||
|
task_info["schedule"] = f"每天 {task.hour:02d}:{task.minute:02d}"
|
||||||
|
elif task.schedule_type == ScheduleType.WEEKLY:
|
||||||
|
weekdays = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
|
||||||
|
task_info["schedule"] = f"每{weekdays[task.day_of_week]} {task.hour:02d}:{task.minute:02d}"
|
||||||
|
elif task.schedule_type == ScheduleType.YEARLY:
|
||||||
|
task_info["schedule"] = f"每年 {task.month}月{task.day}日 {task.hour:02d}:{task.minute:02d}"
|
||||||
|
elif task.schedule_type == ScheduleType.ONCE:
|
||||||
|
if task.trigger_time:
|
||||||
|
task_info["schedule"] = f"一次性: {task.trigger_time.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
result.append(task_info)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _register_job(self, task: ScheduledTask):
|
||||||
|
"""注册单个任务到APScheduler"""
|
||||||
|
job_id = f"scheduler_{task.task_id}"
|
||||||
|
|
||||||
|
# 移除已存在的同ID任务
|
||||||
|
if scheduler.get_job(job_id):
|
||||||
|
scheduler.remove_job(job_id)
|
||||||
|
|
||||||
|
if task.schedule_type == ScheduleType.INTERVAL_MINUTES:
|
||||||
|
scheduler.add_job(
|
||||||
|
self._trigger_task,
|
||||||
|
"interval",
|
||||||
|
minutes=task.interval_minutes,
|
||||||
|
id=job_id,
|
||||||
|
args=[task.task_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif task.schedule_type == ScheduleType.DAILY:
|
||||||
|
scheduler.add_job(
|
||||||
|
self._trigger_task,
|
||||||
|
"cron",
|
||||||
|
hour=task.hour,
|
||||||
|
minute=task.minute,
|
||||||
|
id=job_id,
|
||||||
|
args=[task.task_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif task.schedule_type == ScheduleType.WEEKLY:
|
||||||
|
# APScheduler的day_of_week: 0=周一...6=周日
|
||||||
|
scheduler.add_job(
|
||||||
|
self._trigger_task,
|
||||||
|
"cron",
|
||||||
|
day_of_week=task.day_of_week,
|
||||||
|
hour=task.hour,
|
||||||
|
minute=task.minute,
|
||||||
|
id=job_id,
|
||||||
|
args=[task.task_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif task.schedule_type == ScheduleType.YEARLY:
|
||||||
|
scheduler.add_job(
|
||||||
|
self._trigger_task,
|
||||||
|
"cron",
|
||||||
|
month=task.month,
|
||||||
|
day=task.day,
|
||||||
|
hour=task.hour,
|
||||||
|
minute=task.minute,
|
||||||
|
id=job_id,
|
||||||
|
args=[task.task_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif task.schedule_type == ScheduleType.ONCE:
|
||||||
|
if task.trigger_time and task.trigger_time > datetime.now():
|
||||||
|
scheduler.add_job(
|
||||||
|
self._trigger_task,
|
||||||
|
"date",
|
||||||
|
run_date=task.trigger_time,
|
||||||
|
id=job_id,
|
||||||
|
args=[task.task_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"注册定时任务到APScheduler: {job_id}")
|
||||||
|
|
||||||
|
async def register_all_jobs(self):
|
||||||
|
"""注册所有任务到APScheduler"""
|
||||||
|
logger.info(f"注册 {len(self.tasks)} 个任务到APScheduler")
|
||||||
|
for task in self.tasks.values():
|
||||||
|
self._register_job(task)
|
||||||
|
|
||||||
|
async def _trigger_task(self, task_id: str):
|
||||||
|
"""任务触发处理"""
|
||||||
|
if task_id not in self.tasks:
|
||||||
|
logger.warning(f"触发的任务不存在: {task_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
task = self.tasks[task_id]
|
||||||
|
logger.info(f"定时任务触发: {task_id} - {task.description}")
|
||||||
|
|
||||||
|
# 导入配置(避免循环导入)
|
||||||
|
from .config import ScopedConfig
|
||||||
|
from nonebot import get_plugin_config
|
||||||
|
from .config import Config
|
||||||
|
plugin_config = get_plugin_config(Config).llmchat
|
||||||
|
|
||||||
|
# 获取Bot
|
||||||
|
try:
|
||||||
|
bots = list(get_driver().bots.values())
|
||||||
|
if not bots:
|
||||||
|
logger.error("没有可用的Bot")
|
||||||
|
return
|
||||||
|
bot: Bot = bots[0] # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取Bot失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 尝试调用AI生成提醒信息
|
||||||
|
reminder_message = await self._generate_ai_reminder(task, plugin_config)
|
||||||
|
|
||||||
|
# 发送消息
|
||||||
|
try:
|
||||||
|
if task.is_group:
|
||||||
|
await bot.send_group_msg(group_id=task.context_id, message=Message(reminder_message))
|
||||||
|
else:
|
||||||
|
await bot.send_private_msg(user_id=task.context_id, message=Message(reminder_message))
|
||||||
|
logger.info(f"定时任务提醒发送成功: {task_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送提醒消息失败: {e}")
|
||||||
|
|
||||||
|
# 一次性任务触发后删除
|
||||||
|
if task.schedule_type == ScheduleType.ONCE:
|
||||||
|
logger.info(f"删除一次性任务: {task_id}")
|
||||||
|
del self.tasks[task_id]
|
||||||
|
await self.save_tasks()
|
||||||
|
|
||||||
|
async def _generate_ai_reminder(self, task: ScheduledTask, plugin_config) -> str:
|
||||||
|
"""调用AI生成提醒信息,支持调用MCP工具获取实时信息"""
|
||||||
|
max_retry = plugin_config.scheduler_max_retry
|
||||||
|
default_reminder = plugin_config.scheduler_default_reminder.format(description=task.description)
|
||||||
|
|
||||||
|
# 获取预设配置
|
||||||
|
preset = None
|
||||||
|
if task.is_group:
|
||||||
|
from . import group_states
|
||||||
|
state = group_states.get(task.context_id)
|
||||||
|
if state and state.preset_name != "off":
|
||||||
|
for p in plugin_config.api_presets:
|
||||||
|
if p.name == state.preset_name:
|
||||||
|
preset = p
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
from . import private_chat_states
|
||||||
|
state = private_chat_states.get(task.context_id)
|
||||||
|
if state and state.preset_name != "off":
|
||||||
|
for p in plugin_config.api_presets:
|
||||||
|
if p.name == state.preset_name:
|
||||||
|
preset = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if not preset:
|
||||||
|
# 没有配置预设,使用默认提醒
|
||||||
|
logger.debug("没有可用的API预设,使用默认提醒")
|
||||||
|
return default_reminder
|
||||||
|
|
||||||
|
# 构建AI请求
|
||||||
|
system_prompt = f"""你是一个友好的提醒助手。用户设置了一个定时提醒任务,现在任务触发了。
|
||||||
|
请根据任务描述生成一条简短、友好的提醒消息。
|
||||||
|
要求:
|
||||||
|
- 如果任务描述中涉及需要获取实时信息的内容(如天气、新闻、股票等),你应该先使用相应的工具获取最新信息,然后基于获取到的信息生成提醒
|
||||||
|
- 消息要简洁,不要太长
|
||||||
|
- 语气要友好、亲切
|
||||||
|
- 可以适当使用语气词或颜文字
|
||||||
|
- 不要有多余的解释,直接发送提醒内容
|
||||||
|
- 如果使用了工具获取信息,请将关键信息整合到提醒消息中"""
|
||||||
|
|
||||||
|
user_prompt = f"任务描述:{task.description}"
|
||||||
|
|
||||||
|
# 初始化OpenAI客户端
|
||||||
|
if preset.proxy:
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
base_url=preset.api_base,
|
||||||
|
api_key=preset.api_key,
|
||||||
|
timeout=plugin_config.request_timeout,
|
||||||
|
http_client=httpx.AsyncClient(proxy=preset.proxy),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
base_url=preset.api_base,
|
||||||
|
api_key=preset.api_key,
|
||||||
|
timeout=plugin_config.request_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取可用工具(如果预设支持MCP)
|
||||||
|
available_tools = None
|
||||||
|
mcp_client = None
|
||||||
|
if preset.support_mcp:
|
||||||
|
try:
|
||||||
|
from .mcpclient import MCPClient
|
||||||
|
mcp_client = MCPClient.get_instance(plugin_config.mcp_servers)
|
||||||
|
# 获取MCP工具,但不包含OneBot工具和定时任务工具(避免在提醒时创建新任务)
|
||||||
|
await mcp_client.init_tools_cache()
|
||||||
|
available_tools = mcp_client._tools_cache.copy() if mcp_client._tools_cache else []
|
||||||
|
logger.debug(f"定时任务触发时可用工具数: {len(available_tools)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取MCP工具列表失败: {e}")
|
||||||
|
available_tools = None
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_prompt}
|
||||||
|
]
|
||||||
|
|
||||||
|
for attempt in range(max_retry):
|
||||||
|
try:
|
||||||
|
# 构建请求参数
|
||||||
|
request_params = {
|
||||||
|
"model": preset.model_name,
|
||||||
|
"max_tokens": 512, # 增加token限制以支持工具调用
|
||||||
|
"temperature": 0.7,
|
||||||
|
"messages": messages
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有可用工具,添加到请求中
|
||||||
|
if available_tools:
|
||||||
|
request_params["tools"] = available_tools
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(**request_params)
|
||||||
|
message = response.choices[0].message
|
||||||
|
|
||||||
|
# 处理工具调用
|
||||||
|
while available_tools and mcp_client and message and message.tool_calls:
|
||||||
|
logger.info(f"定时任务触发时AI调用工具: {[tc.function.name for tc in message.tool_calls]}")
|
||||||
|
|
||||||
|
# 添加assistant消息
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [tool_call.model_dump() for tool_call in message.tool_calls]
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理每个工具调用
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
tool_name = tool_call.function.name
|
||||||
|
tool_args = json.loads(tool_call.function.arguments)
|
||||||
|
|
||||||
|
logger.debug(f"调用工具: {tool_name}, 参数: {tool_args}")
|
||||||
|
|
||||||
|
# 调用MCP工具(使用简化的参数,因为这里不需要群操作)
|
||||||
|
try:
|
||||||
|
result = await mcp_client.call_tool(
|
||||||
|
tool_name,
|
||||||
|
tool_args,
|
||||||
|
group_id=task.context_id if task.is_group else None,
|
||||||
|
bot_id=None,
|
||||||
|
user_id=task.creator_id,
|
||||||
|
is_group=task.is_group
|
||||||
|
)
|
||||||
|
result_str = str(result) if result else "工具调用成功但无返回结果"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工具调用失败: {e}")
|
||||||
|
result_str = f"工具调用失败: {e}"
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": result_str
|
||||||
|
})
|
||||||
|
|
||||||
|
# 再次调用AI处理工具结果
|
||||||
|
response = await client.chat.completions.create(**request_params)
|
||||||
|
message = response.choices[0].message
|
||||||
|
|
||||||
|
content = message.content
|
||||||
|
if content:
|
||||||
|
logger.debug(f"AI生成提醒成功: {content[:50]}...")
|
||||||
|
return content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"AI生成提醒失败 (尝试 {attempt + 1}/{max_retry}): {e}")
|
||||||
|
if attempt < max_retry - 1:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# 重试失败,返回默认提醒
|
||||||
|
logger.warning(f"AI生成提醒全部失败,使用默认提醒")
|
||||||
|
return default_reminder
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict[str, Any],
|
||||||
|
context_id: int,
|
||||||
|
is_group: bool,
|
||||||
|
creator_id: int
|
||||||
|
) -> str:
|
||||||
|
"""调用定时任务工具"""
|
||||||
|
if tool_name == "scheduler__create_task":
|
||||||
|
success, message = await self.create_task(
|
||||||
|
context_id=context_id,
|
||||||
|
is_group=is_group,
|
||||||
|
creator_id=creator_id,
|
||||||
|
schedule_type=tool_args.get("schedule_type", ""),
|
||||||
|
description=tool_args.get("description", ""),
|
||||||
|
interval_minutes=tool_args.get("interval_minutes"),
|
||||||
|
hour=tool_args.get("hour"),
|
||||||
|
minute=tool_args.get("minute"),
|
||||||
|
day_of_week=tool_args.get("day_of_week"),
|
||||||
|
month=tool_args.get("month"),
|
||||||
|
day=tool_args.get("day"),
|
||||||
|
minutes_later=tool_args.get("minutes_later")
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
elif tool_name == "scheduler__list_tasks":
|
||||||
|
tasks = self.list_tasks(context_id, is_group)
|
||||||
|
if not tasks:
|
||||||
|
return "当前没有定时任务"
|
||||||
|
return json.dumps(tasks, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
elif tool_name == "scheduler__delete_task":
|
||||||
|
success, message = await self.delete_task(
|
||||||
|
task_id=tool_args.get("task_id", ""),
|
||||||
|
context_id=context_id,
|
||||||
|
is_group=is_group
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
elif tool_name == "scheduler__update_task":
|
||||||
|
success, message = await self.update_task(
|
||||||
|
task_id=tool_args.get("task_id", ""),
|
||||||
|
context_id=context_id,
|
||||||
|
is_group=is_group,
|
||||||
|
description=tool_args.get("description", "")
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
return f"未知的工具: {tool_name}"
|
||||||
|
|
||||||
|
def get_available_tools(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取可用工具列表"""
|
||||||
|
return self.tools.get_available_tools()
|
||||||
|
|
||||||
|
def get_friendly_name(self, tool_name: str) -> str:
|
||||||
|
"""获取工具友好名称"""
|
||||||
|
return self.tools.get_friendly_name(tool_name)
|
||||||
662
nonebot_plugin_llmchat/submodel_caller.py
Normal file
662
nonebot_plugin_llmchat/submodel_caller.py
Normal file
|
|
@ -0,0 +1,662 @@
|
||||||
|
"""子模型调用模块
|
||||||
|
|
||||||
|
允许主模型通过 function tool 调用其他模型来完成特定任务(如生成图片、语音、视频)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from nonebot import logger
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from .config import PresetConfig, ScopedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class SubModelCaller:
|
||||||
|
"""子模型调用管理器"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls, plugin_config: ScopedConfig | None = None):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, plugin_config: ScopedConfig | None = None):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
if plugin_config is None:
|
||||||
|
raise ValueError("plugin_config must be provided for first initialization")
|
||||||
|
|
||||||
|
self.plugin_config = plugin_config
|
||||||
|
self._preset_map: dict[str, PresetConfig] = {
|
||||||
|
p.name: p for p in plugin_config.api_presets
|
||||||
|
}
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("SubModelCaller 初始化完成")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls, plugin_config: ScopedConfig | None = None) -> "SubModelCaller":
|
||||||
|
"""获取单例实例"""
|
||||||
|
if cls._instance is None:
|
||||||
|
if plugin_config is None:
|
||||||
|
raise ValueError("plugin_config must be provided for first initialization")
|
||||||
|
cls._instance = cls(plugin_config)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _get_callable_presets(self, current_preset: PresetConfig) -> list[PresetConfig]:
|
||||||
|
"""获取当前预设可调用的子模型预设列表"""
|
||||||
|
if not current_preset.call_model_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
callable_presets = []
|
||||||
|
for name in current_preset.call_model_list:
|
||||||
|
if name in self._preset_map:
|
||||||
|
callable_presets.append(self._preset_map[name])
|
||||||
|
else:
|
||||||
|
logger.warning(f"call_model_list 中的模型 '{name}' 不存在于 api_presets 中")
|
||||||
|
|
||||||
|
return callable_presets
|
||||||
|
|
||||||
|
def _get_presets_with_capability(
|
||||||
|
self,
|
||||||
|
current_preset: PresetConfig,
|
||||||
|
capability: str
|
||||||
|
) -> list[PresetConfig]:
|
||||||
|
"""获取具有特定能力的可调用子模型列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_preset: 当前主模型预设
|
||||||
|
capability: 能力名称,如 'support_to_image'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
具有该能力的子模型预设列表(按 call_model_list 顺序)
|
||||||
|
"""
|
||||||
|
callable_presets = self._get_callable_presets(current_preset)
|
||||||
|
return [p for p in callable_presets if getattr(p, capability, False)]
|
||||||
|
|
||||||
|
def get_available_tools(self, current_preset: PresetConfig) -> list[dict[str, Any]]:
|
||||||
|
"""根据当前预设的 call_model_list 动态生成可用的子模型调用工具
|
||||||
|
|
||||||
|
只有当 call_model_list 中存在具有相应能力的模型时,才会生成对应的工具。
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
|
||||||
|
# 检查是否有可调用的图片生成模型
|
||||||
|
image_models = self._get_presets_with_capability(current_preset, "support_to_image")
|
||||||
|
if image_models:
|
||||||
|
model_names = [m.name for m in image_models]
|
||||||
|
tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "submodel__generate_image",
|
||||||
|
"description": f"""调用子模型生成图片。可用的图片生成模型:{', '.join(model_names)}。
|
||||||
|
使用说明:
|
||||||
|
- 当用户要求生成图片时使用此工具
|
||||||
|
- prompt 应该是详细的图片描述,用英文效果更好
|
||||||
|
- 如果用户消息中包含图片(发送或引用),系统会自动将这些图片作为参考传递给子模型,无需在 prompt 中描述
|
||||||
|
- 系统会自动选择最优的模型,如果失败会自动切换备选模型
|
||||||
|
- 返回结果包含 base64 编码的图片数据""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "图片生成提示词,描述要生成的图片内容或对参考图片的修改要求"
|
||||||
|
},
|
||||||
|
"preferred_model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": f"可选:指定使用的模型名称,可选值:{', '.join(model_names)}",
|
||||||
|
"enum": model_names
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["prompt"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 检查是否有可调用的语音生成模型
|
||||||
|
voice_models = self._get_presets_with_capability(current_preset, "support_to_voice")
|
||||||
|
if voice_models:
|
||||||
|
model_names = [m.name for m in voice_models]
|
||||||
|
tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "submodel__generate_voice",
|
||||||
|
"description": f"""调用子模型生成语音。可用的语音生成模型:{', '.join(model_names)}。
|
||||||
|
使用说明:
|
||||||
|
- 当用户要求生成语音或朗读文本时使用此工具
|
||||||
|
- text 是要转换为语音的文本内容
|
||||||
|
- 返回结果包含 base64 编码的音频数据""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要转换为语音的文本内容"
|
||||||
|
},
|
||||||
|
"preferred_model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": f"可选:指定使用的模型名称,可选值:{', '.join(model_names)}",
|
||||||
|
"enum": model_names
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 检查是否有可调用的视频生成模型
|
||||||
|
video_models = self._get_presets_with_capability(current_preset, "support_to_video")
|
||||||
|
if video_models:
|
||||||
|
model_names = [m.name for m in video_models]
|
||||||
|
tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "submodel__generate_video",
|
||||||
|
"description": f"""调用子模型生成视频。可用的视频生成模型:{', '.join(model_names)}。
|
||||||
|
使用说明:
|
||||||
|
- 当用户要求生成视频时使用此工具
|
||||||
|
- prompt 是视频内容描述
|
||||||
|
- 返回结果包含视频数据或URL""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "视频生成提示词,描述要生成的视频内容"
|
||||||
|
},
|
||||||
|
"preferred_model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": f"可选:指定使用的模型名称,可选值:{', '.join(model_names)}",
|
||||||
|
"enum": model_names
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["prompt"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
async def _call_model_api(
|
||||||
|
self,
|
||||||
|
preset: PresetConfig,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""调用模型 API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preset: 模型预设配置
|
||||||
|
messages: 消息列表
|
||||||
|
tools: 可选的工具列表(如果模型支持 MCP)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含响应内容的字典
|
||||||
|
"""
|
||||||
|
# 初始化 OpenAI 客户端
|
||||||
|
if preset.proxy:
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
base_url=preset.api_base,
|
||||||
|
api_key=preset.api_key,
|
||||||
|
timeout=self.plugin_config.request_timeout,
|
||||||
|
http_client=httpx.AsyncClient(proxy=preset.proxy),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
base_url=preset.api_base,
|
||||||
|
api_key=preset.api_key,
|
||||||
|
timeout=self.plugin_config.request_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建请求参数
|
||||||
|
request_params = {
|
||||||
|
"model": preset.model_name,
|
||||||
|
"max_tokens": preset.max_tokens,
|
||||||
|
"temperature": preset.temperature,
|
||||||
|
"messages": messages
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果模型支持 MCP 并且提供了工具,添加到请求中
|
||||||
|
if preset.support_mcp and tools:
|
||||||
|
request_params["tools"] = tools
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(**request_params)
|
||||||
|
message = response.choices[0].message
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"content": message.content,
|
||||||
|
"tool_calls": message.tool_calls,
|
||||||
|
"images": getattr(message, "images", None),
|
||||||
|
"audio": getattr(message, "audio", None),
|
||||||
|
"video": getattr(message, "video", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _call_with_mcp_support(
|
||||||
|
self,
|
||||||
|
preset: PresetConfig,
|
||||||
|
initial_messages: list[dict],
|
||||||
|
mcp_tools: list[dict] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""调用模型并处理可能的 MCP 工具调用
|
||||||
|
|
||||||
|
如果模型支持 MCP,会处理工具调用循环直到得到最终响应。
|
||||||
|
"""
|
||||||
|
messages = initial_messages.copy()
|
||||||
|
tools = mcp_tools if preset.support_mcp else None
|
||||||
|
|
||||||
|
# 最多进行 5 轮工具调用
|
||||||
|
max_tool_rounds = 5
|
||||||
|
|
||||||
|
for _ in range(max_tool_rounds):
|
||||||
|
result = await self._call_model_api(preset, messages, tools)
|
||||||
|
|
||||||
|
# 如果没有工具调用,直接返回结果
|
||||||
|
if not result["tool_calls"]:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 处理工具调用
|
||||||
|
logger.info(f"子模型 {preset.name} 请求调用工具: {[tc.function.name for tc in result['tool_calls']]}")
|
||||||
|
|
||||||
|
# 添加 assistant 消息
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [tc.model_dump() for tc in result["tool_calls"]]
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理每个工具调用
|
||||||
|
for tool_call in result["tool_calls"]:
|
||||||
|
tool_name = tool_call.function.name
|
||||||
|
tool_args = json.loads(tool_call.function.arguments)
|
||||||
|
|
||||||
|
# 调用 MCP 工具
|
||||||
|
try:
|
||||||
|
from .mcpclient import MCPClient
|
||||||
|
mcp_client = MCPClient.get_instance(self.plugin_config.mcp_servers)
|
||||||
|
tool_result = await mcp_client.call_tool(
|
||||||
|
tool_name,
|
||||||
|
tool_args,
|
||||||
|
group_id=None,
|
||||||
|
bot_id=None,
|
||||||
|
user_id=None,
|
||||||
|
is_group=False
|
||||||
|
)
|
||||||
|
result_str = str(tool_result) if tool_result else "工具调用成功"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"子模型 MCP 工具调用失败: {e}")
|
||||||
|
result_str = f"工具调用失败: {e}"
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": result_str
|
||||||
|
})
|
||||||
|
|
||||||
|
# 超过最大轮数,返回最后的结果
|
||||||
|
logger.warning(f"子模型 {preset.name} 工具调用超过 {max_tool_rounds} 轮")
|
||||||
|
return await self._call_model_api(preset, messages, None)
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
current_preset: PresetConfig,
|
||||||
|
prompt: str,
|
||||||
|
preferred_model: str | None = None,
|
||||||
|
reference_images: list[str] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""生成图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_preset: 当前主模型预设
|
||||||
|
prompt: 图片生成提示词
|
||||||
|
preferred_model: 可选的指定模型名称
|
||||||
|
reference_images: 可选的参考图片列表(base64 编码)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含生成结果的字典:
|
||||||
|
- success: bool
|
||||||
|
- images: list[str] (base64 编码的图片)
|
||||||
|
- content: str (模型的文本回复)
|
||||||
|
- error: str (如果失败)
|
||||||
|
- model_used: str (实际使用的模型名称)
|
||||||
|
"""
|
||||||
|
image_models = self._get_presets_with_capability(current_preset, "support_to_image")
|
||||||
|
|
||||||
|
if not image_models:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "没有可用的图片生成模型",
|
||||||
|
"images": [],
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果指定了模型,调整顺序
|
||||||
|
if preferred_model:
|
||||||
|
image_models = sorted(
|
||||||
|
image_models,
|
||||||
|
key=lambda p: 0 if p.name == preferred_model else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取 MCP 工具(如果需要)
|
||||||
|
mcp_tools = None
|
||||||
|
try:
|
||||||
|
from .mcpclient import MCPClient
|
||||||
|
mcp_client = MCPClient.get_instance(self.plugin_config.mcp_servers)
|
||||||
|
await mcp_client.init_tools_cache()
|
||||||
|
mcp_tools = mcp_client._tools_cache.copy() if mcp_client._tools_cache else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"获取 MCP 工具失败: {e}")
|
||||||
|
|
||||||
|
# 构建用户消息内容
|
||||||
|
user_content: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
# 添加文本提示
|
||||||
|
user_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
|
# 如果有参考图片,添加到消息中
|
||||||
|
if reference_images:
|
||||||
|
logger.info(f"子模型调用包含 {len(reference_images)} 张参考图片")
|
||||||
|
for img_base64 in reference_images:
|
||||||
|
# 确保格式正确
|
||||||
|
if not img_base64.startswith("data:"):
|
||||||
|
img_base64 = f"data:image/jpeg;base64,{img_base64}"
|
||||||
|
user_content.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": img_base64}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建消息
|
||||||
|
system_prompt = "你是一个图片生成助手。请根据用户的描述生成图片。直接生成图片,不需要额外解释。"
|
||||||
|
if reference_images:
|
||||||
|
system_prompt += "\n用户提供了参考图片,请根据参考图片和用户的描述来生成或修改图片。"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_content if reference_images else prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
for preset in image_models:
|
||||||
|
logger.info(f"尝试使用模型 {preset.name} 生成图片")
|
||||||
|
try:
|
||||||
|
result = await self._call_with_mcp_support(preset, messages, mcp_tools)
|
||||||
|
|
||||||
|
# 检查是否有图片返回
|
||||||
|
images = result.get("images")
|
||||||
|
if images:
|
||||||
|
# 提取 base64 图片数据
|
||||||
|
image_list = []
|
||||||
|
for img in images:
|
||||||
|
if isinstance(img, dict) and "image_url" in img:
|
||||||
|
url = img["image_url"].get("url", "")
|
||||||
|
# 移除 data URL 前缀
|
||||||
|
if url.startswith("data:"):
|
||||||
|
# 格式: data:image/png;base64,xxxxx
|
||||||
|
base64_data = url.split(",", 1)[-1] if "," in url else url
|
||||||
|
else:
|
||||||
|
base64_data = url
|
||||||
|
image_list.append(base64_data)
|
||||||
|
elif isinstance(img, str):
|
||||||
|
image_list.append(img)
|
||||||
|
|
||||||
|
if image_list:
|
||||||
|
logger.info(f"模型 {preset.name} 成功生成 {len(image_list)} 张图片")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"images": image_list,
|
||||||
|
"content": result.get("content", ""),
|
||||||
|
"model_used": preset.name
|
||||||
|
}
|
||||||
|
|
||||||
|
# 没有图片但有内容,可能是模型回复了文本
|
||||||
|
if result.get("content"):
|
||||||
|
logger.warning(f"模型 {preset.name} 返回了文本但没有图片")
|
||||||
|
errors.append(f"{preset.name}: 模型未生成图片")
|
||||||
|
else:
|
||||||
|
errors.append(f"{preset.name}: 模型无响应")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型 {preset.name} 调用失败: {e}")
|
||||||
|
errors.append(f"{preset.name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 所有模型都失败了
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"所有模型都无法生成图片。详情:{'; '.join(errors)}",
|
||||||
|
"images": [],
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
async def generate_voice(
|
||||||
|
self,
|
||||||
|
current_preset: PresetConfig,
|
||||||
|
text: str,
|
||||||
|
preferred_model: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""生成语音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_preset: 当前主模型预设
|
||||||
|
text: 要转换为语音的文本
|
||||||
|
preferred_model: 可选的指定模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含生成结果的字典
|
||||||
|
"""
|
||||||
|
voice_models = self._get_presets_with_capability(current_preset, "support_to_voice")
|
||||||
|
|
||||||
|
if not voice_models:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "没有可用的语音生成模型",
|
||||||
|
"audio": None,
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if preferred_model:
|
||||||
|
voice_models = sorted(
|
||||||
|
voice_models,
|
||||||
|
key=lambda p: 0 if p.name == preferred_model else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个语音生成助手。请将用户提供的文本转换为语音。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"请将以下文本转换为语音:\n{text}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
for preset in voice_models:
|
||||||
|
logger.info(f"尝试使用模型 {preset.name} 生成语音")
|
||||||
|
try:
|
||||||
|
result = await self._call_with_mcp_support(preset, messages, None)
|
||||||
|
|
||||||
|
audio = result.get("audio")
|
||||||
|
if audio:
|
||||||
|
logger.info(f"模型 {preset.name} 成功生成语音")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"audio": audio,
|
||||||
|
"content": result.get("content", ""),
|
||||||
|
"model_used": preset.name
|
||||||
|
}
|
||||||
|
|
||||||
|
errors.append(f"{preset.name}: 模型未生成语音")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型 {preset.name} 调用失败: {e}")
|
||||||
|
errors.append(f"{preset.name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"所有模型都无法生成语音。详情:{'; '.join(errors)}",
|
||||||
|
"audio": None,
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
async def generate_video(
|
||||||
|
self,
|
||||||
|
current_preset: PresetConfig,
|
||||||
|
prompt: str,
|
||||||
|
preferred_model: str | None = None,
|
||||||
|
reference_images: list[str] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""生成视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_preset: 当前主模型预设
|
||||||
|
prompt: 视频生成提示词
|
||||||
|
preferred_model: 可选的指定模型名称
|
||||||
|
reference_images: 可选的参考图片列表(base64 编码)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含生成结果的字典
|
||||||
|
"""
|
||||||
|
video_models = self._get_presets_with_capability(current_preset, "support_to_video")
|
||||||
|
|
||||||
|
if not video_models:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "没有可用的视频生成模型",
|
||||||
|
"video": None,
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if preferred_model:
|
||||||
|
video_models = sorted(
|
||||||
|
video_models,
|
||||||
|
key=lambda p: 0 if p.name == preferred_model else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建用户消息内容
|
||||||
|
user_content: list[dict[str, Any]] = []
|
||||||
|
user_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
|
# 如果有参考图片,添加到消息中
|
||||||
|
if reference_images:
|
||||||
|
logger.info(f"视频生成包含 {len(reference_images)} 张参考图片")
|
||||||
|
for img_base64 in reference_images:
|
||||||
|
if not img_base64.startswith("data:"):
|
||||||
|
img_base64 = f"data:image/jpeg;base64,{img_base64}"
|
||||||
|
user_content.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": img_base64}
|
||||||
|
})
|
||||||
|
|
||||||
|
system_prompt = "你是一个视频生成助手。请根据用户的描述生成视频。"
|
||||||
|
if reference_images:
|
||||||
|
system_prompt += "\n用户提供了参考图片,请根据参考图片和用户的描述来生成视频。"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_content if reference_images else prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
for preset in video_models:
|
||||||
|
logger.info(f"尝试使用模型 {preset.name} 生成视频")
|
||||||
|
try:
|
||||||
|
result = await self._call_with_mcp_support(preset, messages, None)
|
||||||
|
|
||||||
|
video = result.get("video")
|
||||||
|
if video:
|
||||||
|
logger.info(f"模型 {preset.name} 成功生成视频")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"video": video,
|
||||||
|
"content": result.get("content", ""),
|
||||||
|
"model_used": preset.name
|
||||||
|
}
|
||||||
|
|
||||||
|
errors.append(f"{preset.name}: 模型未生成视频")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型 {preset.name} 调用失败: {e}")
|
||||||
|
errors.append(f"{preset.name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"所有模型都无法生成视频。详情:{'; '.join(errors)}",
|
||||||
|
"video": None,
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict[str, Any],
|
||||||
|
current_preset: PresetConfig,
|
||||||
|
reference_images: list[str] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""工具调用入口
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
tool_args: 工具参数
|
||||||
|
current_preset: 当前主模型预设
|
||||||
|
reference_images: 可选的参考图片列表(base64 编码),来自用户消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具调用结果
|
||||||
|
"""
|
||||||
|
if tool_name == "submodel__generate_image":
|
||||||
|
return await self.generate_image(
|
||||||
|
current_preset=current_preset,
|
||||||
|
prompt=tool_args.get("prompt", ""),
|
||||||
|
preferred_model=tool_args.get("preferred_model"),
|
||||||
|
reference_images=reference_images
|
||||||
|
)
|
||||||
|
elif tool_name == "submodel__generate_voice":
|
||||||
|
return await self.generate_voice(
|
||||||
|
current_preset=current_preset,
|
||||||
|
text=tool_args.get("text", ""),
|
||||||
|
preferred_model=tool_args.get("preferred_model")
|
||||||
|
)
|
||||||
|
elif tool_name == "submodel__generate_video":
|
||||||
|
return await self.generate_video(
|
||||||
|
current_preset=current_preset,
|
||||||
|
prompt=tool_args.get("prompt", ""),
|
||||||
|
preferred_model=tool_args.get("preferred_model"),
|
||||||
|
reference_images=reference_images
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"未知的子模型工具: {tool_name}"
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_friendly_name(self, tool_name: str) -> str:
|
||||||
|
"""获取工具的友好名称"""
|
||||||
|
friendly_names = {
|
||||||
|
"submodel__generate_image": "子模型 - 生成图片",
|
||||||
|
"submodel__generate_voice": "子模型 - 生成语音",
|
||||||
|
"submodel__generate_video": "子模型 - 生成视频",
|
||||||
|
}
|
||||||
|
return friendly_names.get(tool_name, tool_name)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue