mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2026-02-05 11:38:05 +00:00
💾 尝试修改为数据库存储
This commit is contained in:
parent
a6290ca7bf
commit
59b4f3c2a3
4 changed files with 610 additions and 72 deletions
|
|
@ -38,12 +38,19 @@ import nonebot_plugin_localstore as store
|
|||
require("nonebot_plugin_apscheduler")
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
require("nonebot_plugin_tortoise_orm")
|
||||
from nonebot_plugin_tortoise_orm import init_orm_plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
|
||||
from .db_manager import DatabaseManager
|
||||
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
|
||||
from .migration import backup_json_files, migrate_from_json_to_db
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="llmchat",
|
||||
description="支持多API预设、MCP协议、联网搜索、视觉模型、Nano Banana(生图模型)的AI群聊插件",
|
||||
|
|
@ -625,8 +632,13 @@ reset_handler = on_command(
|
|||
async def handle_reset(event: GroupMessageEvent, args: Message = CommandArg()):
|
||||
group_id = event.group_id
|
||||
|
||||
# 清空内存状态
|
||||
group_states[group_id].past_events.clear()
|
||||
group_states[group_id].history.clear()
|
||||
|
||||
# 清空数据库记录
|
||||
await DatabaseManager.clear_group_history(group_id)
|
||||
|
||||
await reset_handler.finish("记忆已清空")
|
||||
|
||||
|
||||
|
|
@ -744,8 +756,13 @@ async def handle_private_reset(event: PrivateMessageEvent, args: Message = Comma
|
|||
|
||||
user_id = event.user_id
|
||||
|
||||
# 清空内存状态
|
||||
private_chat_states[user_id].past_events.clear()
|
||||
private_chat_states[user_id].history.clear()
|
||||
|
||||
# 清空数据库记录
|
||||
await DatabaseManager.clear_private_history(user_id)
|
||||
|
||||
await private_reset_handler.finish("记忆已清空")
|
||||
|
||||
|
||||
|
|
@ -775,96 +792,93 @@ async def handle_private_think(event: PrivateMessageEvent, args: Message = Comma
|
|||
|
||||
# region 持久化与定时任务
|
||||
|
||||
# 获取插件数据目录
|
||||
data_dir = store.get_plugin_data_dir()
|
||||
# 获取插件数据文件
|
||||
data_file = store.get_plugin_data_file("llmchat_state.json")
|
||||
private_data_file = store.get_plugin_data_file("llmchat_private_state.json")
|
||||
|
||||
|
||||
async def save_state():
|
||||
"""保存群组状态到文件"""
|
||||
logger.info(f"开始保存群组状态到文件:{data_file}")
|
||||
data = {
|
||||
gid: {
|
||||
"preset": state.preset_name,
|
||||
"history": list(state.history),
|
||||
"last_active": state.last_active,
|
||||
"group_prompt": state.group_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
"random_trigger_prob": state.random_trigger_prob,
|
||||
}
|
||||
for gid, state in group_states.items()
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(data_file), exist_ok=True)
|
||||
async with aiofiles.open(data_file, "w", encoding="utf8") as f:
|
||||
await f.write(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
"""保存所有群组和私聊状态到数据库"""
|
||||
logger.info("开始保存所有状态到数据库")
|
||||
|
||||
# 保存群组状态
|
||||
for gid, state in group_states.items():
|
||||
await DatabaseManager.save_group_state(
|
||||
group_id=gid,
|
||||
preset_name=state.preset_name,
|
||||
history=state.history,
|
||||
group_prompt=state.group_prompt,
|
||||
output_reasoning_content=state.output_reasoning_content,
|
||||
random_trigger_prob=state.random_trigger_prob,
|
||||
)
|
||||
|
||||
# 保存私聊状态
|
||||
if plugin_config.enable_private_chat:
|
||||
logger.info(f"开始保存私聊状态到文件:{private_data_file}")
|
||||
private_data = {
|
||||
uid: {
|
||||
"preset": state.preset_name,
|
||||
"history": list(state.history),
|
||||
"last_active": state.last_active,
|
||||
"group_prompt": state.group_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
}
|
||||
for uid, state in private_chat_states.items()
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(private_data_file), exist_ok=True)
|
||||
async with aiofiles.open(private_data_file, "w", encoding="utf8") as f:
|
||||
await f.write(json.dumps(private_data, ensure_ascii=False))
|
||||
for uid, state in private_chat_states.items():
|
||||
await DatabaseManager.save_private_state(
|
||||
user_id=uid,
|
||||
preset_name=state.preset_name,
|
||||
history=state.history,
|
||||
user_prompt=state.group_prompt, # 注意:这里应该是 group_prompt 但是在 PrivateChatState 中叫 group_prompt
|
||||
output_reasoning_content=state.output_reasoning_content,
|
||||
)
|
||||
|
||||
logger.info("所有状态保存完成")
|
||||
|
||||
|
||||
async def load_state():
|
||||
"""从文件加载群组状态"""
|
||||
logger.info(f"从文件加载群组状态:{data_file}")
|
||||
if not os.path.exists(data_file):
|
||||
return
|
||||
|
||||
async with aiofiles.open(data_file, encoding="utf8") as f:
|
||||
data = json.loads(await f.read())
|
||||
for gid, state_data in data.items():
|
||||
state = GroupState()
|
||||
state.preset_name = state_data["preset"]
|
||||
state.history = deque(
|
||||
state_data["history"], maxlen=plugin_config.history_size * 2
|
||||
)
|
||||
state.last_active = state_data["last_active"]
|
||||
state.group_prompt = state_data["group_prompt"]
|
||||
state.output_reasoning_content = state_data["output_reasoning_content"]
|
||||
state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob)
|
||||
group_states[int(gid)] = state
|
||||
|
||||
"""从数据库加载所有状态"""
|
||||
logger.info("从数据库加载所有状态")
|
||||
|
||||
history_maxlen = plugin_config.history_size * 2
|
||||
|
||||
# 加载群组状态
|
||||
group_data = await DatabaseManager.load_all_group_states(history_maxlen)
|
||||
for gid, state_data in group_data.items():
|
||||
state = GroupState()
|
||||
state.preset_name = state_data["preset_name"]
|
||||
state.history = state_data["history"]
|
||||
state.group_prompt = state_data["group_prompt"]
|
||||
state.output_reasoning_content = state_data["output_reasoning_content"]
|
||||
state.random_trigger_prob = state_data["random_trigger_prob"]
|
||||
state.last_active = state_data["last_active"]
|
||||
group_states[gid] = state
|
||||
|
||||
# 加载私聊状态
|
||||
if plugin_config.enable_private_chat:
|
||||
logger.info(f"从文件加载私聊状态:{private_data_file}")
|
||||
if os.path.exists(private_data_file):
|
||||
async with aiofiles.open(private_data_file, encoding="utf8") as f:
|
||||
private_data = json.loads(await f.read())
|
||||
for uid, state_data in private_data.items():
|
||||
state = PrivateChatState()
|
||||
state.preset_name = state_data["preset"]
|
||||
state.history = deque(
|
||||
state_data["history"], maxlen=plugin_config.history_size * 2
|
||||
)
|
||||
state.last_active = state_data["last_active"]
|
||||
state.group_prompt = state_data["group_prompt"]
|
||||
state.output_reasoning_content = state_data["output_reasoning_content"]
|
||||
private_chat_states[int(uid)] = state
|
||||
private_data = await DatabaseManager.load_all_private_states(history_maxlen)
|
||||
for uid, state_data in private_data.items():
|
||||
state = PrivateChatState()
|
||||
state.preset_name = state_data["preset_name"]
|
||||
state.history = state_data["history"]
|
||||
state.group_prompt = state_data["user_prompt"] # 注意:从 user_prompt 恢复到 group_prompt
|
||||
state.output_reasoning_content = state_data["output_reasoning_content"]
|
||||
state.last_active = state_data["last_active"]
|
||||
private_chat_states[uid] = state
|
||||
|
||||
logger.info(f"已加载 {len(group_states)} 个群组和 {len(private_chat_states)} 个私聊的状态")
|
||||
|
||||
|
||||
# 注册生命周期事件
|
||||
@driver.on_startup
|
||||
async def init_plugin():
|
||||
logger.info("插件启动初始化")
|
||||
# 初始化 Tortoise ORM 的模型
|
||||
await init_orm_plugin()
|
||||
# 创建表(如果不存在)
|
||||
try:
|
||||
# Tortoise ORM 会自动创建表,这里只是尝试检查
|
||||
logger.info("数据库表初始化中...")
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化数据库时出现警告: {e}")
|
||||
|
||||
# 执行迁移(从 JSON 到数据库)
|
||||
try:
|
||||
await migrate_from_json_to_db(plugin_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"执行迁移时出现错误: {e}")
|
||||
|
||||
await load_state()
|
||||
# 每5分钟保存状态
|
||||
scheduler.add_job(save_state, "interval", minutes=5)
|
||||
|
||||
logger.info("插件初始化完成")
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
|
|
|
|||
270
nonebot_plugin_llmchat/db_manager.py
Normal file
270
nonebot_plugin_llmchat/db_manager.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
"""
|
||||
数据库操作层
|
||||
处理聊天历史和状态的持久化
|
||||
"""
|
||||
import json
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from nonebot import logger
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
|
||||
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理器"""
|
||||
|
||||
@staticmethod
|
||||
async def save_group_state(
|
||||
group_id: int,
|
||||
preset_name: str,
|
||||
history: deque,
|
||||
group_prompt: Optional[str],
|
||||
output_reasoning_content: bool,
|
||||
random_trigger_prob: float,
|
||||
):
|
||||
"""保存群组状态和历史到数据库"""
|
||||
try:
|
||||
# 保存或更新群组状态
|
||||
state, _ = await GroupChatState.get_or_create(
|
||||
group_id=group_id,
|
||||
defaults={
|
||||
"preset_name": preset_name,
|
||||
"group_prompt": group_prompt,
|
||||
"output_reasoning_content": output_reasoning_content,
|
||||
"random_trigger_prob": random_trigger_prob,
|
||||
},
|
||||
)
|
||||
if _: # 如果是新创建的
|
||||
logger.debug(f"创建群组状态记录: {group_id}")
|
||||
else:
|
||||
# 更新现有记录
|
||||
state.preset_name = preset_name
|
||||
state.group_prompt = group_prompt
|
||||
state.output_reasoning_content = output_reasoning_content
|
||||
state.random_trigger_prob = random_trigger_prob
|
||||
await state.save()
|
||||
logger.debug(f"更新群组状态记录: {group_id}")
|
||||
|
||||
# 保存历史快照
|
||||
messages_list = list(history)
|
||||
history_record, _ = await ChatHistory.get_or_create(
|
||||
group_id=group_id,
|
||||
is_private=False,
|
||||
defaults={"messages_json": ChatHistory.serialize_messages(messages_list)},
|
||||
)
|
||||
if not _:
|
||||
history_record.messages_json = ChatHistory.serialize_messages(messages_list)
|
||||
await history_record.save()
|
||||
|
||||
logger.debug(f"已保存群组 {group_id} 的历史记录({len(messages_list)} 条消息)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存群组状态失败 群号: {group_id}, 错误: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def save_private_state(
|
||||
user_id: int,
|
||||
preset_name: str,
|
||||
history: deque,
|
||||
user_prompt: Optional[str],
|
||||
output_reasoning_content: bool,
|
||||
):
|
||||
"""保存私聊状态和历史到数据库"""
|
||||
try:
|
||||
# 保存或更新私聊状态
|
||||
state, _ = await PrivateChatState.get_or_create(
|
||||
user_id=user_id,
|
||||
defaults={
|
||||
"preset_name": preset_name,
|
||||
"user_prompt": user_prompt,
|
||||
"output_reasoning_content": output_reasoning_content,
|
||||
},
|
||||
)
|
||||
if _: # 如果是新创建的
|
||||
logger.debug(f"创建私聊状态记录: {user_id}")
|
||||
else:
|
||||
# 更新现有记录
|
||||
state.preset_name = preset_name
|
||||
state.user_prompt = user_prompt
|
||||
state.output_reasoning_content = output_reasoning_content
|
||||
await state.save()
|
||||
logger.debug(f"更新私聊状态记录: {user_id}")
|
||||
|
||||
# 保存历史快照
|
||||
messages_list = list(history)
|
||||
history_record, _ = await ChatHistory.get_or_create(
|
||||
user_id=user_id,
|
||||
is_private=True,
|
||||
defaults={"messages_json": ChatHistory.serialize_messages(messages_list)},
|
||||
)
|
||||
if not _:
|
||||
history_record.messages_json = ChatHistory.serialize_messages(messages_list)
|
||||
await history_record.save()
|
||||
|
||||
logger.debug(f"已保存用户 {user_id} 的历史记录({len(messages_list)} 条消息)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存私聊状态失败 用户: {user_id}, 错误: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def load_group_state(group_id: int, history_maxlen: int) -> dict:
|
||||
"""从数据库加载群组状态"""
|
||||
try:
|
||||
state = await GroupChatState.get_or_none(group_id=group_id)
|
||||
if not state:
|
||||
logger.debug(f"未找到群组 {group_id} 的状态记录,返回默认值")
|
||||
return None
|
||||
|
||||
# 加载历史
|
||||
history_record = await ChatHistory.get_or_none(
|
||||
group_id=group_id, is_private=False
|
||||
)
|
||||
history = deque(
|
||||
ChatHistory.deserialize_messages(history_record.messages_json)
|
||||
if history_record
|
||||
else [],
|
||||
maxlen=history_maxlen,
|
||||
)
|
||||
|
||||
logger.debug(f"已加载群组 {group_id} 的状态({len(history)} 条历史)")
|
||||
|
||||
return {
|
||||
"preset_name": state.preset_name,
|
||||
"history": history,
|
||||
"group_prompt": state.group_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
"random_trigger_prob": state.random_trigger_prob,
|
||||
"last_active": state.last_active.timestamp(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载群组状态失败 群号: {group_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def load_private_state(user_id: int, history_maxlen: int) -> dict:
|
||||
"""从数据库加载私聊状态"""
|
||||
try:
|
||||
state = await PrivateChatState.get_or_none(user_id=user_id)
|
||||
if not state:
|
||||
logger.debug(f"未找到用户 {user_id} 的状态记录,返回默认值")
|
||||
return None
|
||||
|
||||
# 加载历史
|
||||
history_record = await ChatHistory.get_or_none(
|
||||
user_id=user_id, is_private=True
|
||||
)
|
||||
history = deque(
|
||||
ChatHistory.deserialize_messages(history_record.messages_json)
|
||||
if history_record
|
||||
else [],
|
||||
maxlen=history_maxlen,
|
||||
)
|
||||
|
||||
logger.debug(f"已加载用户 {user_id} 的状态({len(history)} 条历史)")
|
||||
|
||||
return {
|
||||
"preset_name": state.preset_name,
|
||||
"history": history,
|
||||
"user_prompt": state.user_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
"last_active": state.last_active.timestamp(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载私聊状态失败 用户: {user_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def load_all_group_states(history_maxlen: int) -> dict:
|
||||
"""加载所有群组状态"""
|
||||
try:
|
||||
states = await GroupChatState.all()
|
||||
result = {}
|
||||
|
||||
for state in states:
|
||||
history_record = await ChatHistory.get_or_none(
|
||||
group_id=state.group_id, is_private=False
|
||||
)
|
||||
history = deque(
|
||||
ChatHistory.deserialize_messages(history_record.messages_json)
|
||||
if history_record
|
||||
else [],
|
||||
maxlen=history_maxlen,
|
||||
)
|
||||
|
||||
result[state.group_id] = {
|
||||
"preset_name": state.preset_name,
|
||||
"history": history,
|
||||
"group_prompt": state.group_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
"random_trigger_prob": state.random_trigger_prob,
|
||||
"last_active": state.last_active.timestamp(),
|
||||
}
|
||||
|
||||
logger.info(f"已加载 {len(result)} 个群组的状态")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载所有群组状态失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def load_all_private_states(history_maxlen: int) -> dict:
|
||||
"""加载所有私聊状态"""
|
||||
try:
|
||||
states = await PrivateChatState.all()
|
||||
result = {}
|
||||
|
||||
for state in states:
|
||||
history_record = await ChatHistory.get_or_none(
|
||||
user_id=state.user_id, is_private=True
|
||||
)
|
||||
history = deque(
|
||||
ChatHistory.deserialize_messages(history_record.messages_json)
|
||||
if history_record
|
||||
else [],
|
||||
maxlen=history_maxlen,
|
||||
)
|
||||
|
||||
result[state.user_id] = {
|
||||
"preset_name": state.preset_name,
|
||||
"history": history,
|
||||
"user_prompt": state.user_prompt,
|
||||
"output_reasoning_content": state.output_reasoning_content,
|
||||
"last_active": state.last_active.timestamp(),
|
||||
}
|
||||
|
||||
logger.info(f"已加载 {len(result)} 个用户的私聊状态")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载所有私聊状态失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def clear_group_history(group_id: int):
|
||||
"""清空群组历史"""
|
||||
try:
|
||||
await ChatHistory.filter(group_id=group_id, is_private=False).delete()
|
||||
state = await GroupChatState.get_or_none(group_id=group_id)
|
||||
if state:
|
||||
await state.delete()
|
||||
logger.info(f"已清空群组 {group_id} 的历史记录")
|
||||
except Exception as e:
|
||||
logger.error(f"清空群组历史失败 群号: {group_id}, 错误: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def clear_private_history(user_id: int):
|
||||
"""清空私聊历史"""
|
||||
try:
|
||||
await ChatHistory.filter(user_id=user_id, is_private=True).delete()
|
||||
state = await PrivateChatState.get_or_none(user_id=user_id)
|
||||
if state:
|
||||
await state.delete()
|
||||
logger.info(f"已清空用户 {user_id} 的历史记录")
|
||||
except Exception as e:
|
||||
logger.error(f"清空私聊历史失败 用户: {user_id}, 错误: {e}")
|
||||
152
nonebot_plugin_llmchat/migration.py
Normal file
152
nonebot_plugin_llmchat/migration.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
"""
|
||||
数据迁移脚本
|
||||
将聊天数据从 JSON 文件迁移到数据库
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
|
||||
from nonebot import logger
|
||||
|
||||
from .config import Config, get_plugin_config
|
||||
from .db_manager import DatabaseManager
|
||||
from .models import ChatHistory, GroupChatState, PrivateChatState
|
||||
|
||||
# 获取插件数据目录
|
||||
try:
|
||||
import nonebot_plugin_localstore as store
|
||||
data_dir = store.get_plugin_data_dir()
|
||||
data_file = store.get_plugin_data_file("llmchat_state.json")
|
||||
private_data_file = store.get_plugin_data_file("llmchat_private_state.json")
|
||||
except ImportError:
|
||||
logger.warning("无法找到 nonebot_plugin_localstore,迁移可能失败")
|
||||
data_dir = None
|
||||
data_file = None
|
||||
private_data_file = None
|
||||
|
||||
|
||||
async def migrate_from_json_to_db(plugin_config: Config):
|
||||
"""从 JSON 文件迁移到数据库"""
|
||||
logger.info("开始从 JSON 文件迁移到数据库")
|
||||
|
||||
if not data_file or not os.path.exists(data_file):
|
||||
logger.info("未找到群组状态 JSON 文件,跳过迁移")
|
||||
return
|
||||
|
||||
try:
|
||||
# 迁移群组状态
|
||||
logger.info(f"正在迁移群组状态数据: {data_file}")
|
||||
with open(data_file, "r", encoding="utf8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
migrated_groups = 0
|
||||
for gid_str, state_data in data.items():
|
||||
try:
|
||||
gid = int(gid_str)
|
||||
# 检查是否已存在
|
||||
existing = await GroupChatState.get_or_none(group_id=gid)
|
||||
if existing:
|
||||
logger.debug(f"群组 {gid} 已存在于数据库,跳过迁移")
|
||||
continue
|
||||
|
||||
# 创建新的状态记录
|
||||
await GroupChatState.create(
|
||||
group_id=gid,
|
||||
preset_name=state_data.get("preset", "off"),
|
||||
group_prompt=state_data.get("group_prompt"),
|
||||
output_reasoning_content=state_data.get("output_reasoning_content", False),
|
||||
random_trigger_prob=state_data.get("random_trigger_prob", 0.05),
|
||||
last_active=datetime.fromtimestamp(state_data.get("last_active", datetime.now().timestamp())),
|
||||
)
|
||||
|
||||
# 创建历史记录
|
||||
messages = state_data.get("history", [])
|
||||
if messages:
|
||||
await ChatHistory.create(
|
||||
group_id=gid,
|
||||
is_private=False,
|
||||
messages_json=ChatHistory.serialize_messages(messages),
|
||||
)
|
||||
|
||||
migrated_groups += 1
|
||||
logger.debug(f"已迁移群组 {gid}({len(messages)} 条消息)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移群组 {gid_str} 失败: {e}")
|
||||
|
||||
logger.info(f"成功迁移 {migrated_groups} 个群组的状态")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移群组状态失败: {e}")
|
||||
|
||||
# 迁移私聊状态
|
||||
if plugin_config.llmchat.enable_private_chat and private_data_file and os.path.exists(private_data_file):
|
||||
try:
|
||||
logger.info(f"正在迁移私聊状态数据: {private_data_file}")
|
||||
with open(private_data_file, "r", encoding="utf8") as f:
|
||||
private_data = json.load(f)
|
||||
|
||||
migrated_users = 0
|
||||
for uid_str, state_data in private_data.items():
|
||||
try:
|
||||
uid = int(uid_str)
|
||||
# 检查是否已存在
|
||||
existing = await PrivateChatState.get_or_none(user_id=uid)
|
||||
if existing:
|
||||
logger.debug(f"用户 {uid} 已存在于数据库,跳过迁移")
|
||||
continue
|
||||
|
||||
# 创建新的状态记录
|
||||
await PrivateChatState.create(
|
||||
user_id=uid,
|
||||
preset_name=state_data.get("preset", "off"),
|
||||
user_prompt=state_data.get("group_prompt"), # JSON 中存的是 group_prompt
|
||||
output_reasoning_content=state_data.get("output_reasoning_content", False),
|
||||
last_active=datetime.fromtimestamp(state_data.get("last_active", datetime.now().timestamp())),
|
||||
)
|
||||
|
||||
# 创建历史记录
|
||||
messages = state_data.get("history", [])
|
||||
if messages:
|
||||
await ChatHistory.create(
|
||||
user_id=uid,
|
||||
is_private=True,
|
||||
messages_json=ChatHistory.serialize_messages(messages),
|
||||
)
|
||||
|
||||
migrated_users += 1
|
||||
logger.debug(f"已迁移用户 {uid}({len(messages)} 条消息)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移用户 {uid_str} 失败: {e}")
|
||||
|
||||
logger.info(f"成功迁移 {migrated_users} 个用户的私聊状态")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移私聊状态失败: {e}")
|
||||
|
||||
logger.info("JSON 迁移完成")
|
||||
|
||||
|
||||
async def backup_json_files():
|
||||
"""备份旧的 JSON 文件"""
|
||||
if not data_file:
|
||||
return
|
||||
|
||||
if os.path.exists(data_file):
|
||||
backup_file = f"{data_file}.backup"
|
||||
try:
|
||||
os.rename(data_file, backup_file)
|
||||
logger.info(f"已备份群组状态文件: {backup_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"备份文件失败: {e}")
|
||||
|
||||
if private_data_file and os.path.exists(private_data_file):
|
||||
backup_file = f"{private_data_file}.backup"
|
||||
try:
|
||||
os.rename(private_data_file, backup_file)
|
||||
logger.info(f"已备份私聊状态文件: {backup_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"备份文件失败: {e}")
|
||||
102
nonebot_plugin_llmchat/models.py
Normal file
102
nonebot_plugin_llmchat/models.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Tortoise ORM 模型定义
|
||||
用于存储聊天历史和群组/私聊状态
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from tortoise import fields
|
||||
from tortoise.models import Model
|
||||
|
||||
|
||||
class GroupChatState(Model):
|
||||
"""群组聊天状态"""
|
||||
|
||||
id = fields.IntField(pk=True)
|
||||
group_id = fields.BigIntField(unique=True, description="群号")
|
||||
preset_name = fields.CharField(max_length=50, description="当前使用的 API 预设名")
|
||||
group_prompt = fields.TextField(null=True, description="群组自定义提示词")
|
||||
output_reasoning_content = fields.BooleanField(default=False, description="是否输出推理内容")
|
||||
random_trigger_prob = fields.FloatField(default=0.05, description="随机触发概率")
|
||||
last_active = fields.DatetimeField(auto_now=True, description="最后活跃时间")
|
||||
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
|
||||
|
||||
class Meta:
|
||||
table = "llmchat_group_state"
|
||||
description = "群组聊天状态表"
|
||||
|
||||
|
||||
class PrivateChatState(Model):
|
||||
"""私聊状态"""
|
||||
|
||||
id = fields.IntField(pk=True)
|
||||
user_id = fields.BigIntField(unique=True, description="用户 QQ")
|
||||
preset_name = fields.CharField(max_length=50, description="当前使用的 API 预设名")
|
||||
user_prompt = fields.TextField(null=True, description="用户自定义提示词")
|
||||
output_reasoning_content = fields.BooleanField(default=False, description="是否输出推理内容")
|
||||
last_active = fields.DatetimeField(auto_now=True, description="最后活跃时间")
|
||||
created_at = fields.DatetimeField(auto_now_add=True, description="创建时间")
|
||||
|
||||
class Meta:
|
||||
table = "llmchat_private_state"
|
||||
description = "私聊状态表"
|
||||
|
||||
|
||||
class ChatMessage(Model):
|
||||
"""聊天消息历史"""
|
||||
|
||||
id = fields.IntField(pk=True)
|
||||
group_id = fields.BigIntField(null=True, description="群号(私聊时为 NULL)")
|
||||
user_id = fields.BigIntField(null=True, description="用户 QQ(私聊时有值)")
|
||||
is_private = fields.BooleanField(default=False, description="是否为私聊")
|
||||
role = fields.CharField(
|
||||
max_length=20,
|
||||
description="消息角色: user/assistant/system/tool",
|
||||
)
|
||||
content = fields.TextField(description="消息内容(JSON 序列化)")
|
||||
created_at = fields.DatetimeField(auto_now_add=True, description="消息时间")
|
||||
|
||||
class Meta:
|
||||
table = "llmchat_message"
|
||||
description = "聊天消息历史表"
|
||||
indexes = [
|
||||
("group_id", "is_private", "created_at"), # 复合索引用于快速查询
|
||||
("user_id", "is_private", "created_at"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def serialize_content(content) -> str:
|
||||
"""将内容序列化为 JSON 字符串"""
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_content(content_str: str):
|
||||
"""从 JSON 字符串反序列化内容"""
|
||||
return json.loads(content_str)
|
||||
|
||||
|
||||
class ChatHistory(Model):
|
||||
"""聊天历史快照(用于快速加载)"""
|
||||
|
||||
id = fields.IntField(pk=True)
|
||||
group_id = fields.BigIntField(null=True, unique=True, description="群号(私聊时为 NULL)")
|
||||
user_id = fields.BigIntField(null=True, unique=True, description="用户 QQ(私聊时有值)")
|
||||
is_private = fields.BooleanField(default=False, description="是否为私聊")
|
||||
# 存储最近 history_size*2 条消息的 JSON 数组
|
||||
messages_json = fields.TextField(description="消息历史(JSON 数组)")
|
||||
last_update = fields.DatetimeField(auto_now=True, description="最后更新时间")
|
||||
|
||||
class Meta:
|
||||
table = "llmchat_history"
|
||||
description = "聊天历史快照表(用于快速加载)"
|
||||
|
||||
@staticmethod
|
||||
def serialize_messages(messages_list) -> str:
|
||||
"""将消息列表序列化为 JSON 字符串"""
|
||||
return json.dumps(messages_list, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_messages(messages_json: str):
|
||||
"""从 JSON 字符串反序列化消息列表"""
|
||||
return json.loads(messages_json)
|
||||
Loading…
Add table
Add a link
Reference in a new issue