尝试修改为数据库存储

This commit is contained in:
KawakazeNotFound 2025-11-07 15:35:46 +08:00
parent f2e882f885
commit fa49841011
4 changed files with 610 additions and 72 deletions

View file

@ -38,12 +38,19 @@ import nonebot_plugin_localstore as store
require("nonebot_plugin_apscheduler")
from nonebot_plugin_apscheduler import scheduler
require("nonebot_plugin_tortoise_orm")
from nonebot_plugin_tortoise_orm import init_orm_plugin
if TYPE_CHECKING:
from openai.types.chat import (
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
)
from .db_manager import DatabaseManager
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
from .migration import backup_json_files, migrate_from_json_to_db
__plugin_meta__ = PluginMetadata(
name="llmchat",
description="支持多API预设、MCP协议、联网搜索、视觉模型、Nano Banana生图模型的AI群聊插件",
@ -625,8 +632,13 @@ reset_handler = on_command(
async def handle_reset(event: GroupMessageEvent, args: Message = CommandArg()):
group_id = event.group_id
# 清空内存状态
group_states[group_id].past_events.clear()
group_states[group_id].history.clear()
# 清空数据库记录
await DatabaseManager.clear_group_history(group_id)
await reset_handler.finish("记忆已清空")
@ -744,8 +756,13 @@ async def handle_private_reset(event: PrivateMessageEvent, args: Message = Comma
user_id = event.user_id
# 清空内存状态
private_chat_states[user_id].past_events.clear()
private_chat_states[user_id].history.clear()
# 清空数据库记录
await DatabaseManager.clear_private_history(user_id)
await private_reset_handler.finish("记忆已清空")
@ -775,96 +792,93 @@ async def handle_private_think(event: PrivateMessageEvent, args: Message = Comma
# region 持久化与定时任务
# 获取插件数据目录
data_dir = store.get_plugin_data_dir()
# 获取插件数据文件
data_file = store.get_plugin_data_file("llmchat_state.json")
private_data_file = store.get_plugin_data_file("llmchat_private_state.json")
async def save_state():
"""保存群组状态到文件"""
logger.info(f"开始保存群组状态到文件:{data_file}")
data = {
gid: {
"preset": state.preset_name,
"history": list(state.history),
"last_active": state.last_active,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
"random_trigger_prob": state.random_trigger_prob,
}
for gid, state in group_states.items()
}
os.makedirs(os.path.dirname(data_file), exist_ok=True)
async with aiofiles.open(data_file, "w", encoding="utf8") as f:
await f.write(json.dumps(data, ensure_ascii=False))
"""保存所有群组和私聊状态到数据库"""
logger.info("开始保存所有状态到数据库")
# 保存群组状态
for gid, state in group_states.items():
await DatabaseManager.save_group_state(
group_id=gid,
preset_name=state.preset_name,
history=state.history,
group_prompt=state.group_prompt,
output_reasoning_content=state.output_reasoning_content,
random_trigger_prob=state.random_trigger_prob,
)
# 保存私聊状态
if plugin_config.enable_private_chat:
logger.info(f"开始保存私聊状态到文件:{private_data_file}")
private_data = {
uid: {
"preset": state.preset_name,
"history": list(state.history),
"last_active": state.last_active,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
}
for uid, state in private_chat_states.items()
}
os.makedirs(os.path.dirname(private_data_file), exist_ok=True)
async with aiofiles.open(private_data_file, "w", encoding="utf8") as f:
await f.write(json.dumps(private_data, ensure_ascii=False))
for uid, state in private_chat_states.items():
await DatabaseManager.save_private_state(
user_id=uid,
preset_name=state.preset_name,
history=state.history,
user_prompt=state.group_prompt, # 注意:这里应该是 group_prompt 但是在 PrivateChatState 中叫 group_prompt
output_reasoning_content=state.output_reasoning_content,
)
logger.info("所有状态保存完成")
async def load_state():
"""从文件加载群组状态"""
logger.info(f"从文件加载群组状态:{data_file}")
if not os.path.exists(data_file):
return
async with aiofiles.open(data_file, encoding="utf8") as f:
data = json.loads(await f.read())
for gid, state_data in data.items():
state = GroupState()
state.preset_name = state_data["preset"]
state.history = deque(
state_data["history"], maxlen=plugin_config.history_size * 2
)
state.last_active = state_data["last_active"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob)
group_states[int(gid)] = state
"""从数据库加载所有状态"""
logger.info("从数据库加载所有状态")
history_maxlen = plugin_config.history_size * 2
# 加载群组状态
group_data = await DatabaseManager.load_all_group_states(history_maxlen)
for gid, state_data in group_data.items():
state = GroupState()
state.preset_name = state_data["preset_name"]
state.history = state_data["history"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
state.random_trigger_prob = state_data["random_trigger_prob"]
state.last_active = state_data["last_active"]
group_states[gid] = state
# 加载私聊状态
if plugin_config.enable_private_chat:
logger.info(f"从文件加载私聊状态:{private_data_file}")
if os.path.exists(private_data_file):
async with aiofiles.open(private_data_file, encoding="utf8") as f:
private_data = json.loads(await f.read())
for uid, state_data in private_data.items():
state = PrivateChatState()
state.preset_name = state_data["preset"]
state.history = deque(
state_data["history"], maxlen=plugin_config.history_size * 2
)
state.last_active = state_data["last_active"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
private_chat_states[int(uid)] = state
private_data = await DatabaseManager.load_all_private_states(history_maxlen)
for uid, state_data in private_data.items():
state = PrivateChatState()
state.preset_name = state_data["preset_name"]
state.history = state_data["history"]
state.group_prompt = state_data["user_prompt"] # 注意:从 user_prompt 恢复到 group_prompt
state.output_reasoning_content = state_data["output_reasoning_content"]
state.last_active = state_data["last_active"]
private_chat_states[uid] = state
logger.info(f"已加载 {len(group_states)} 个群组和 {len(private_chat_states)} 个私聊的状态")
# 注册生命周期事件
@driver.on_startup
async def init_plugin():
logger.info("插件启动初始化")
# 初始化 Tortoise ORM 的模型
await init_orm_plugin()
# 创建表(如果不存在)
try:
# Tortoise ORM 会自动创建表,这里只是尝试检查
logger.info("数据库表初始化中...")
except Exception as e:
logger.warning(f"初始化数据库时出现警告: {e}")
# 执行迁移(从 JSON 到数据库)
try:
await migrate_from_json_to_db(plugin_config)
except Exception as e:
logger.warning(f"执行迁移时出现错误: {e}")
await load_state()
# 每5分钟保存状态
scheduler.add_job(save_state, "interval", minutes=5)
logger.info("插件初始化完成")
@driver.on_shutdown

View file

@ -0,0 +1,270 @@
"""
数据库操作层
处理聊天历史和状态的持久化
"""
import json
from collections import deque
from datetime import datetime
from typing import Optional
from nonebot import logger
from tortoise.exceptions import DoesNotExist
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
class DatabaseManager:
"""数据库管理器"""
@staticmethod
async def save_group_state(
group_id: int,
preset_name: str,
history: deque,
group_prompt: Optional[str],
output_reasoning_content: bool,
random_trigger_prob: float,
):
"""保存群组状态和历史到数据库"""
try:
# 保存或更新群组状态
state, _ = await GroupChatState.get_or_create(
group_id=group_id,
defaults={
"preset_name": preset_name,
"group_prompt": group_prompt,
"output_reasoning_content": output_reasoning_content,
"random_trigger_prob": random_trigger_prob,
},
)
if _: # 如果是新创建的
logger.debug(f"创建群组状态记录: {group_id}")
else:
# 更新现有记录
state.preset_name = preset_name
state.group_prompt = group_prompt
state.output_reasoning_content = output_reasoning_content
state.random_trigger_prob = random_trigger_prob
await state.save()
logger.debug(f"更新群组状态记录: {group_id}")
# 保存历史快照
messages_list = list(history)
history_record, _ = await ChatHistory.get_or_create(
group_id=group_id,
is_private=False,
defaults={"messages_json": ChatHistory.serialize_messages(messages_list)},
)
if not _:
history_record.messages_json = ChatHistory.serialize_messages(messages_list)
await history_record.save()
logger.debug(f"已保存群组 {group_id} 的历史记录({len(messages_list)} 条消息)")
except Exception as e:
logger.error(f"保存群组状态失败 群号: {group_id}, 错误: {e}")
@staticmethod
async def save_private_state(
user_id: int,
preset_name: str,
history: deque,
user_prompt: Optional[str],
output_reasoning_content: bool,
):
"""保存私聊状态和历史到数据库"""
try:
# 保存或更新私聊状态
state, _ = await PrivateChatState.get_or_create(
user_id=user_id,
defaults={
"preset_name": preset_name,
"user_prompt": user_prompt,
"output_reasoning_content": output_reasoning_content,
},
)
if _: # 如果是新创建的
logger.debug(f"创建私聊状态记录: {user_id}")
else:
# 更新现有记录
state.preset_name = preset_name
state.user_prompt = user_prompt
state.output_reasoning_content = output_reasoning_content
await state.save()
logger.debug(f"更新私聊状态记录: {user_id}")
# 保存历史快照
messages_list = list(history)
history_record, _ = await ChatHistory.get_or_create(
user_id=user_id,
is_private=True,
defaults={"messages_json": ChatHistory.serialize_messages(messages_list)},
)
if not _:
history_record.messages_json = ChatHistory.serialize_messages(messages_list)
await history_record.save()
logger.debug(f"已保存用户 {user_id} 的历史记录({len(messages_list)} 条消息)")
except Exception as e:
logger.error(f"保存私聊状态失败 用户: {user_id}, 错误: {e}")
@staticmethod
async def load_group_state(group_id: int, history_maxlen: int) -> dict:
"""从数据库加载群组状态"""
try:
state = await GroupChatState.get_or_none(group_id=group_id)
if not state:
logger.debug(f"未找到群组 {group_id} 的状态记录,返回默认值")
return None
# 加载历史
history_record = await ChatHistory.get_or_none(
group_id=group_id, is_private=False
)
history = deque(
ChatHistory.deserialize_messages(history_record.messages_json)
if history_record
else [],
maxlen=history_maxlen,
)
logger.debug(f"已加载群组 {group_id} 的状态({len(history)} 条历史)")
return {
"preset_name": state.preset_name,
"history": history,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
"random_trigger_prob": state.random_trigger_prob,
"last_active": state.last_active.timestamp(),
}
except Exception as e:
logger.error(f"加载群组状态失败 群号: {group_id}, 错误: {e}")
return None
@staticmethod
async def load_private_state(user_id: int, history_maxlen: int) -> dict:
"""从数据库加载私聊状态"""
try:
state = await PrivateChatState.get_or_none(user_id=user_id)
if not state:
logger.debug(f"未找到用户 {user_id} 的状态记录,返回默认值")
return None
# 加载历史
history_record = await ChatHistory.get_or_none(
user_id=user_id, is_private=True
)
history = deque(
ChatHistory.deserialize_messages(history_record.messages_json)
if history_record
else [],
maxlen=history_maxlen,
)
logger.debug(f"已加载用户 {user_id} 的状态({len(history)} 条历史)")
return {
"preset_name": state.preset_name,
"history": history,
"user_prompt": state.user_prompt,
"output_reasoning_content": state.output_reasoning_content,
"last_active": state.last_active.timestamp(),
}
except Exception as e:
logger.error(f"加载私聊状态失败 用户: {user_id}, 错误: {e}")
return None
@staticmethod
async def load_all_group_states(history_maxlen: int) -> dict:
"""加载所有群组状态"""
try:
states = await GroupChatState.all()
result = {}
for state in states:
history_record = await ChatHistory.get_or_none(
group_id=state.group_id, is_private=False
)
history = deque(
ChatHistory.deserialize_messages(history_record.messages_json)
if history_record
else [],
maxlen=history_maxlen,
)
result[state.group_id] = {
"preset_name": state.preset_name,
"history": history,
"group_prompt": state.group_prompt,
"output_reasoning_content": state.output_reasoning_content,
"random_trigger_prob": state.random_trigger_prob,
"last_active": state.last_active.timestamp(),
}
logger.info(f"已加载 {len(result)} 个群组的状态")
return result
except Exception as e:
logger.error(f"加载所有群组状态失败, 错误: {e}")
return {}
@staticmethod
async def load_all_private_states(history_maxlen: int) -> dict:
"""加载所有私聊状态"""
try:
states = await PrivateChatState.all()
result = {}
for state in states:
history_record = await ChatHistory.get_or_none(
user_id=state.user_id, is_private=True
)
history = deque(
ChatHistory.deserialize_messages(history_record.messages_json)
if history_record
else [],
maxlen=history_maxlen,
)
result[state.user_id] = {
"preset_name": state.preset_name,
"history": history,
"user_prompt": state.user_prompt,
"output_reasoning_content": state.output_reasoning_content,
"last_active": state.last_active.timestamp(),
}
logger.info(f"已加载 {len(result)} 个用户的私聊状态")
return result
except Exception as e:
logger.error(f"加载所有私聊状态失败, 错误: {e}")
return {}
@staticmethod
async def clear_group_history(group_id: int):
"""清空群组历史"""
try:
await ChatHistory.filter(group_id=group_id, is_private=False).delete()
state = await GroupChatState.get_or_none(group_id=group_id)
if state:
await state.delete()
logger.info(f"已清空群组 {group_id} 的历史记录")
except Exception as e:
logger.error(f"清空群组历史失败 群号: {group_id}, 错误: {e}")
@staticmethod
async def clear_private_history(user_id: int):
"""清空私聊历史"""
try:
await ChatHistory.filter(user_id=user_id, is_private=True).delete()
state = await PrivateChatState.get_or_none(user_id=user_id)
if state:
await state.delete()
logger.info(f"已清空用户 {user_id} 的历史记录")
except Exception as e:
logger.error(f"清空私聊历史失败 用户: {user_id}, 错误: {e}")

View file

@ -0,0 +1,152 @@
"""
数据迁移脚本
将聊天数据从 JSON 文件迁移到数据库
"""
import asyncio
import json
import os
from collections import deque
from datetime import datetime
from nonebot import logger
from .config import Config, get_plugin_config
from .db_manager import DatabaseManager
from .models import ChatHistory, GroupChatState, PrivateChatState
# 获取插件数据目录
try:
import nonebot_plugin_localstore as store
data_dir = store.get_plugin_data_dir()
data_file = store.get_plugin_data_file("llmchat_state.json")
private_data_file = store.get_plugin_data_file("llmchat_private_state.json")
except ImportError:
logger.warning("无法找到 nonebot_plugin_localstore迁移可能失败")
data_dir = None
data_file = None
private_data_file = None
async def migrate_from_json_to_db(plugin_config: Config):
"""从 JSON 文件迁移到数据库"""
logger.info("开始从 JSON 文件迁移到数据库")
if not data_file or not os.path.exists(data_file):
logger.info("未找到群组状态 JSON 文件,跳过迁移")
return
try:
# 迁移群组状态
logger.info(f"正在迁移群组状态数据: {data_file}")
with open(data_file, "r", encoding="utf8") as f:
data = json.load(f)
migrated_groups = 0
for gid_str, state_data in data.items():
try:
gid = int(gid_str)
# 检查是否已存在
existing = await GroupChatState.get_or_none(group_id=gid)
if existing:
logger.debug(f"群组 {gid} 已存在于数据库,跳过迁移")
continue
# 创建新的状态记录
await GroupChatState.create(
group_id=gid,
preset_name=state_data.get("preset", "off"),
group_prompt=state_data.get("group_prompt"),
output_reasoning_content=state_data.get("output_reasoning_content", False),
random_trigger_prob=state_data.get("random_trigger_prob", 0.05),
last_active=datetime.fromtimestamp(state_data.get("last_active", datetime.now().timestamp())),
)
# 创建历史记录
messages = state_data.get("history", [])
if messages:
await ChatHistory.create(
group_id=gid,
is_private=False,
messages_json=ChatHistory.serialize_messages(messages),
)
migrated_groups += 1
logger.debug(f"已迁移群组 {gid}{len(messages)} 条消息)")
except Exception as e:
logger.error(f"迁移群组 {gid_str} 失败: {e}")
logger.info(f"成功迁移 {migrated_groups} 个群组的状态")
except Exception as e:
logger.error(f"迁移群组状态失败: {e}")
# 迁移私聊状态
if plugin_config.llmchat.enable_private_chat and private_data_file and os.path.exists(private_data_file):
try:
logger.info(f"正在迁移私聊状态数据: {private_data_file}")
with open(private_data_file, "r", encoding="utf8") as f:
private_data = json.load(f)
migrated_users = 0
for uid_str, state_data in private_data.items():
try:
uid = int(uid_str)
# 检查是否已存在
existing = await PrivateChatState.get_or_none(user_id=uid)
if existing:
logger.debug(f"用户 {uid} 已存在于数据库,跳过迁移")
continue
# 创建新的状态记录
await PrivateChatState.create(
user_id=uid,
preset_name=state_data.get("preset", "off"),
user_prompt=state_data.get("group_prompt"), # JSON 中存的是 group_prompt
output_reasoning_content=state_data.get("output_reasoning_content", False),
last_active=datetime.fromtimestamp(state_data.get("last_active", datetime.now().timestamp())),
)
# 创建历史记录
messages = state_data.get("history", [])
if messages:
await ChatHistory.create(
user_id=uid,
is_private=True,
messages_json=ChatHistory.serialize_messages(messages),
)
migrated_users += 1
logger.debug(f"已迁移用户 {uid}{len(messages)} 条消息)")
except Exception as e:
logger.error(f"迁移用户 {uid_str} 失败: {e}")
logger.info(f"成功迁移 {migrated_users} 个用户的私聊状态")
except Exception as e:
logger.error(f"迁移私聊状态失败: {e}")
logger.info("JSON 迁移完成")
async def backup_json_files():
"""备份旧的 JSON 文件"""
if not data_file:
return
if os.path.exists(data_file):
backup_file = f"{data_file}.backup"
try:
os.rename(data_file, backup_file)
logger.info(f"已备份群组状态文件: {backup_file}")
except Exception as e:
logger.warning(f"备份文件失败: {e}")
if private_data_file and os.path.exists(private_data_file):
backup_file = f"{private_data_file}.backup"
try:
os.rename(private_data_file, backup_file)
logger.info(f"已备份私聊状态文件: {backup_file}")
except Exception as e:
logger.warning(f"备份文件失败: {e}")

View file

@ -0,0 +1,102 @@
"""
Tortoise ORM 模型定义
用于存储聊天历史和群组/私聊状态
"""
import json
from datetime import datetime
from typing import Optional
from tortoise import fields
from tortoise.models import Model
class GroupChatState(Model):
"""群组聊天状态"""
id = fields.IntField(pk=True)
group_id = fields.BigIntField(unique=True, description="群号")
preset_name = fields.CharField(max_length=50, description="当前使用的 API 预设名")
group_prompt = fields.TextField(null=True, description="群组自定义提示词")
output_reasoning_content = fields.BooleanField(default=False, description="是否输出推理内容")
random_trigger_prob = fields.FloatField(default=0.05, description="随机触发概率")
last_active = fields.DatetimeField(auto_now=True, description="最后活跃时间")
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
class Meta:
table = "llmchat_group_state"
description = "群组聊天状态表"
class PrivateChatState(Model):
"""私聊状态"""
id = fields.IntField(pk=True)
user_id = fields.BigIntField(unique=True, description="用户 QQ")
preset_name = fields.CharField(max_length=50, description="当前使用的 API 预设名")
user_prompt = fields.TextField(null=True, description="用户自定义提示词")
output_reasoning_content = fields.BooleanField(default=False, description="是否输出推理内容")
last_active = fields.DatetimeField(auto_now=True, description="最后活跃时间")
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
class Meta:
table = "llmchat_private_state"
description = "私聊状态表"
class ChatMessage(Model):
"""聊天消息历史"""
id = fields.IntField(pk=True)
group_id = fields.BigIntField(null=True, description="群号(私聊时为 NULL")
user_id = fields.BigIntField(null=True, description="用户 QQ私聊时有值")
is_private = fields.BooleanField(default=False, description="是否为私聊")
role = fields.CharField(
max_length=20,
description="消息角色: user/assistant/system/tool",
)
content = fields.TextField(description="消息内容JSON 序列化)")
created_at = fields.DatetimeField(auto_now_add=True, description="消息时间")
class Meta:
table = "llmchat_message"
description = "聊天消息历史表"
indexes = [
("group_id", "is_private", "created_at"), # 复合索引用于快速查询
("user_id", "is_private", "created_at"),
]
@staticmethod
def serialize_content(content) -> str:
"""将内容序列化为 JSON 字符串"""
return json.dumps(content, ensure_ascii=False)
@staticmethod
def deserialize_content(content_str: str):
"""从 JSON 字符串反序列化内容"""
return json.loads(content_str)
class ChatHistory(Model):
"""聊天历史快照(用于快速加载)"""
id = fields.IntField(pk=True)
group_id = fields.BigIntField(null=True, unique=True, description="群号(私聊时为 NULL")
user_id = fields.BigIntField(null=True, unique=True, description="用户 QQ私聊时有值")
is_private = fields.BooleanField(default=False, description="是否为私聊")
# 存储最近 history_size*2 条消息的 JSON 数组
messages_json = fields.TextField(description="消息历史JSON 数组)")
last_update = fields.DatetimeField(auto_now=True, description="最后更新时间")
class Meta:
table = "llmchat_history"
description = "聊天历史快照表(用于快速加载)"
@staticmethod
def serialize_messages(messages_list) -> str:
"""将消息列表序列化为 JSON 字符串"""
return json.dumps(messages_list, ensure_ascii=False)
@staticmethod
def deserialize_messages(messages_json: str):
"""从 JSON 字符串反序列化消息列表"""
return json.loads(messages_json)