💾 尝试修改为数据库存储

This commit is contained in:
KawakazeNotFound 2025-11-07 15:35:46 +08:00
parent a6290ca7bf
commit 59b4f3c2a3
4 changed files with 610 additions and 72 deletions

View file

@ -38,12 +38,19 @@ import nonebot_plugin_localstore as store
require("nonebot_plugin_apscheduler")
from nonebot_plugin_apscheduler import scheduler
require("nonebot_plugin_tortoise_orm")
from nonebot_plugin_tortoise_orm import init_orm_plugin
if TYPE_CHECKING:
from openai.types.chat import (
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
)
from .db_manager import DatabaseManager
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
from .migration import backup_json_files, migrate_from_json_to_db
__plugin_meta__ = PluginMetadata(
name="llmchat",
description="支持多API预设、MCP协议、联网搜索、视觉模型、Nano Banana生图模型的AI群聊插件",
@ -625,8 +632,13 @@ reset_handler = on_command(
async def handle_reset(event: GroupMessageEvent, args: Message = CommandArg()):
group_id = event.group_id
# 清空内存状态
group_states[group_id].past_events.clear()
group_states[group_id].history.clear()
# 清空数据库记录
await DatabaseManager.clear_group_history(group_id)
await reset_handler.finish("记忆已清空")
@ -744,8 +756,13 @@ async def handle_private_reset(event: PrivateMessageEvent, args: Message = Comma
user_id = event.user_id
# 清空内存状态
private_chat_states[user_id].past_events.clear()
private_chat_states[user_id].history.clear()
# 清空数据库记录
await DatabaseManager.clear_private_history(user_id)
await private_reset_handler.finish("记忆已清空")
@ -775,96 +792,93 @@ async def handle_private_think(event: PrivateMessageEvent, args: Message = Comma
# region 持久化与定时任务
# 获取插件数据目录
data_dir = store.get_plugin_data_dir()
# 获取插件数据文件
data_file = store.get_plugin_data_file("llmchat_state.json")
private_data_file = store.get_plugin_data_file("llmchat_private_state.json")
async def save_state():
"""保存群组状态到文件"""
logger.info(f"开始保存群组状态到文件:{data_file}")
data = {
gid: {
"preset": state.preset_name,
"history": list(state.history),
"last_active": state.last_active,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
"random_trigger_prob": state.random_trigger_prob,
}
for gid, state in group_states.items()
}
os.makedirs(os.path.dirname(data_file), exist_ok=True)
async with aiofiles.open(data_file, "w", encoding="utf8") as f:
await f.write(json.dumps(data, ensure_ascii=False))
"""保存所有群组和私聊状态到数据库"""
logger.info("开始保存所有状态到数据库")
# 保存群组状态
for gid, state in group_states.items():
await DatabaseManager.save_group_state(
group_id=gid,
preset_name=state.preset_name,
history=state.history,
group_prompt=state.group_prompt,
output_reasoning_content=state.output_reasoning_content,
random_trigger_prob=state.random_trigger_prob,
)
# 保存私聊状态
if plugin_config.enable_private_chat:
logger.info(f"开始保存私聊状态到文件:{private_data_file}")
private_data = {
uid: {
"preset": state.preset_name,
"history": list(state.history),
"last_active": state.last_active,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
}
for uid, state in private_chat_states.items()
}
os.makedirs(os.path.dirname(private_data_file), exist_ok=True)
async with aiofiles.open(private_data_file, "w", encoding="utf8") as f:
await f.write(json.dumps(private_data, ensure_ascii=False))
for uid, state in private_chat_states.items():
await DatabaseManager.save_private_state(
user_id=uid,
preset_name=state.preset_name,
history=state.history,
user_prompt=state.group_prompt, # 注意:这里应该是 group_prompt 但是在 PrivateChatState 中叫 group_prompt
output_reasoning_content=state.output_reasoning_content,
)
logger.info("所有状态保存完成")
async def load_state():
"""从文件加载群组状态"""
logger.info(f"从文件加载群组状态:{data_file}")
if not os.path.exists(data_file):
return
async with aiofiles.open(data_file, encoding="utf8") as f:
data = json.loads(await f.read())
for gid, state_data in data.items():
state = GroupState()
state.preset_name = state_data["preset"]
state.history = deque(
state_data["history"], maxlen=plugin_config.history_size * 2
)
state.last_active = state_data["last_active"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob)
group_states[int(gid)] = state
"""从数据库加载所有状态"""
logger.info("从数据库加载所有状态")
history_maxlen = plugin_config.history_size * 2
# 加载群组状态
group_data = await DatabaseManager.load_all_group_states(history_maxlen)
for gid, state_data in group_data.items():
state = GroupState()
state.preset_name = state_data["preset_name"]
state.history = state_data["history"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
state.random_trigger_prob = state_data["random_trigger_prob"]
state.last_active = state_data["last_active"]
group_states[gid] = state
# 加载私聊状态
if plugin_config.enable_private_chat:
logger.info(f"从文件加载私聊状态:{private_data_file}")
if os.path.exists(private_data_file):
async with aiofiles.open(private_data_file, encoding="utf8") as f:
private_data = json.loads(await f.read())
for uid, state_data in private_data.items():
state = PrivateChatState()
state.preset_name = state_data["preset"]
state.history = deque(
state_data["history"], maxlen=plugin_config.history_size * 2
)
state.last_active = state_data["last_active"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
private_chat_states[int(uid)] = state
private_data = await DatabaseManager.load_all_private_states(history_maxlen)
for uid, state_data in private_data.items():
state = PrivateChatState()
state.preset_name = state_data["preset_name"]
state.history = state_data["history"]
state.group_prompt = state_data["user_prompt"] # 注意:从 user_prompt 恢复到 group_prompt
state.output_reasoning_content = state_data["output_reasoning_content"]
state.last_active = state_data["last_active"]
private_chat_states[uid] = state
logger.info(f"已加载 {len(group_states)} 个群组和 {len(private_chat_states)} 个私聊的状态")
# 注册生命周期事件
@driver.on_startup
async def init_plugin():
logger.info("插件启动初始化")
# 初始化 Tortoise ORM 的模型
await init_orm_plugin()
# 创建表(如果不存在)
try:
# Tortoise ORM 会自动创建表,这里只是尝试检查
logger.info("数据库表初始化中...")
except Exception as e:
logger.warning(f"初始化数据库时出现警告: {e}")
# 执行迁移(从 JSON 到数据库)
try:
await migrate_from_json_to_db(plugin_config)
except Exception as e:
logger.warning(f"执行迁移时出现错误: {e}")
await load_state()
# 每5分钟保存状态
scheduler.add_job(save_state, "interval", minutes=5)
logger.info("插件初始化完成")
@driver.on_shutdown