尝试改用ORM进行存储

This commit is contained in:
KawakazeNotFound 2025-11-07 16:00:57 +08:00
parent 4ab2faef93
commit 153e278fac
3 changed files with 18 additions and 33 deletions

View file

@ -39,7 +39,10 @@ require("nonebot_plugin_apscheduler")
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
require("nonebot_plugin_tortoise_orm") require("nonebot_plugin_tortoise_orm")
from nonebot_plugin_tortoise_orm import init_orm_plugin # 必须在 require 之后导入模型,才能正确注册到 Tortoise ORM
from . import models # noqa: F401
require("nonebot_plugin_tortoise_orm")
if TYPE_CHECKING: if TYPE_CHECKING:
from openai.types.chat import ( from openai.types.chat import (
@ -49,7 +52,6 @@ if TYPE_CHECKING:
from .db_manager import DatabaseManager from .db_manager import DatabaseManager
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
from .migration import backup_json_files, migrate_from_json_to_db
__plugin_meta__ = PluginMetadata( __plugin_meta__ = PluginMetadata(
name="llmchat", name="llmchat",
@ -644,11 +646,13 @@ async def handle_reset(event: GroupMessageEvent | PrivateMessageEvent, args: Mes
if isinstance(event, GroupMessageEvent): if isinstance(event, GroupMessageEvent):
context_id = event.group_id context_id = event.group_id
state = group_states[context_id] state = group_states[context_id]
await DatabaseManager.clear_group_history(context_id)
else: # PrivateMessageEvent else: # PrivateMessageEvent
if not plugin_config.enable_private_chat: if not plugin_config.enable_private_chat:
return return
context_id = event.user_id context_id = event.user_id
state = private_chat_states[context_id] state = private_chat_states[context_id]
await DatabaseManager.clear_private_history(context_id)
state.past_events.clear() state.past_events.clear()
state.history.clear() state.history.clear()
@ -774,25 +778,9 @@ async def load_state():
@driver.on_startup @driver.on_startup
async def init_plugin(): async def init_plugin():
logger.info("插件启动初始化") logger.info("插件启动初始化")
# 初始化 Tortoise ORM 的模型
await init_orm_plugin()
# 创建表(如果不存在)
try:
# Tortoise ORM 会自动创建表,这里只是尝试检查
logger.info("数据库表初始化中...")
except Exception as e:
logger.warning(f"初始化数据库时出现警告: {e}")
# 执行迁移(从 JSON 到数据库)
try:
await migrate_from_json_to_db(plugin_config)
except Exception as e:
logger.warning(f"执行迁移时出现错误: {e}")
await load_state() await load_state()
# 每5分钟保存状态 # 每5分钟保存状态
scheduler.add_job(save_state, "interval", minutes=5) scheduler.add_job(save_state, "interval", minutes=5)
logger.info("插件初始化完成") logger.info("插件初始化完成")

View file

@ -8,7 +8,6 @@ from datetime import datetime
from typing import Optional from typing import Optional
from nonebot import logger from nonebot import logger
from tortoise.exceptions import DoesNotExist
from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState from .models import ChatHistory, ChatMessage, GroupChatState, PrivateChatState
@ -25,7 +24,7 @@ class DatabaseManager:
output_reasoning_content: bool, output_reasoning_content: bool,
random_trigger_prob: float, random_trigger_prob: float,
): ):
"""保存群组状态和历史到数据库""" """保存群组状态"""
try: try:
# 保存或更新群组状态 # 保存或更新群组状态
state, _ = await GroupChatState.get_or_create( state, _ = await GroupChatState.get_or_create(
@ -72,7 +71,7 @@ class DatabaseManager:
user_prompt: Optional[str], user_prompt: Optional[str],
output_reasoning_content: bool, output_reasoning_content: bool,
): ):
"""保存私聊状态和历史到数据库""" """保存私聊状态"""
try: try:
# 保存或更新私聊状态 # 保存或更新私聊状态
state, _ = await PrivateChatState.get_or_create( state, _ = await PrivateChatState.get_or_create(
@ -110,7 +109,7 @@ class DatabaseManager:
logger.error(f"保存私聊状态失败 用户: {user_id}, 错误: {e}") logger.error(f"保存私聊状态失败 用户: {user_id}, 错误: {e}")
@staticmethod @staticmethod
async def load_group_state(group_id: int, history_maxlen: int) -> dict: async def load_group_state(group_id: int, history_maxlen: int) -> Optional[dict]:
"""从数据库加载群组状态""" """从数据库加载群组状态"""
try: try:
state = await GroupChatState.get_or_none(group_id=group_id) state = await GroupChatState.get_or_none(group_id=group_id)
@ -145,7 +144,7 @@ class DatabaseManager:
return None return None
@staticmethod @staticmethod
async def load_private_state(user_id: int, history_maxlen: int) -> dict: async def load_private_state(user_id: int, history_maxlen: int) -> Optional[dict]:
"""从数据库加载私聊状态""" """从数据库加载私聊状态"""
try: try:
state = await PrivateChatState.get_or_none(user_id=user_id) state = await PrivateChatState.get_or_none(user_id=user_id)

View file

@ -3,12 +3,14 @@ Tortoise ORM 模型定义
用于存储聊天历史和群组/私聊状态 用于存储聊天历史和群组/私聊状态
""" """
import json import json
from datetime import datetime
from typing import Optional
from nonebot_plugin_tortoise_orm import add_model
from tortoise import fields from tortoise import fields
from tortoise.models import Model from tortoise.models import Model
# 注册模型到 Tortoise ORM
add_model(__name__)
class GroupChatState(Model): class GroupChatState(Model):
"""群组聊天状态""" """群组聊天状态"""
@ -24,7 +26,7 @@ class GroupChatState(Model):
class Meta: class Meta:
table = "llmchat_group_state" table = "llmchat_group_state"
description = "群组聊天状态表" table_description = "群组聊天状态表"
class PrivateChatState(Model): class PrivateChatState(Model):
@ -40,7 +42,7 @@ class PrivateChatState(Model):
class Meta: class Meta:
table = "llmchat_private_state" table = "llmchat_private_state"
description = "私聊状态表" table_description = "私聊状态表"
class ChatMessage(Model): class ChatMessage(Model):
@ -59,11 +61,7 @@ class ChatMessage(Model):
class Meta: class Meta:
table = "llmchat_message" table = "llmchat_message"
description = "聊天消息历史表" table_description = "聊天消息历史表"
indexes = [
("group_id", "is_private", "created_at"), # 复合索引用于快速查询
("user_id", "is_private", "created_at"),
]
@staticmethod @staticmethod
def serialize_content(content) -> str: def serialize_content(content) -> str:
@ -89,7 +87,7 @@ class ChatHistory(Model):
class Meta: class Meta:
table = "llmchat_history" table = "llmchat_history"
description = "聊天历史快照表(用于快速加载)" table_description = "聊天历史快照表(用于快速加载)"
@staticmethod @staticmethod
def serialize_messages(messages_list) -> str: def serialize_messages(messages_list) -> str: