74 lines
2.8 KiB
Python
74 lines
2.8 KiB
Python
import json
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Union
|
||
from qwen_agent.tools.base import BaseTool, register_tool
|
||
|
||
# 从环境变量读取,如果读不到则默认为当前目录下的 memory.json
|
||
# 使用 .resolve() 自动处理相对路径转绝对路径的逻辑
|
||
MEMORY_FILE = Path(os.getenv('MEMORY_FILE_PATH', './memory.json')).resolve()
|
||
|
||
def _load_memory() -> list:
|
||
"""内部函数:安全加载记忆并强制转换为列表格式"""
|
||
if not MEMORY_FILE.exists():
|
||
return []
|
||
try:
|
||
content = MEMORY_FILE.read_text(encoding='utf-8').strip()
|
||
if not content:
|
||
return []
|
||
data = json.loads(content)
|
||
# 核心修复:如果读到的是字典或其他格式,强制转为列表
|
||
if isinstance(data, list):
|
||
return data
|
||
return []
|
||
except Exception:
|
||
return []
|
||
|
||
def _save_memory(memories: list):
|
||
"""内部函数:安全保存"""
|
||
try:
|
||
MEMORY_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||
MEMORY_FILE.write_text(json.dumps(memories, ensure_ascii=False, indent=2), encoding='utf-8')
|
||
except Exception as e:
|
||
print(f"写入记忆文件失败: {e}")
|
||
|
||
@register_tool('manage_memory', allow_overwrite=True)
|
||
class MemoryTool(BaseTool):
|
||
description = '长期记忆管理工具。支持 add (添加), list (查看), delete (删除索引)。'
|
||
parameters = {
|
||
'type': 'object',
|
||
'properties': {
|
||
'operation': {'type': 'string', 'description': '操作类型: add|list|delete'},
|
||
'content': {'type': 'string', 'description': '记忆内容(仅add模式)'},
|
||
'index': {'type': 'integer', 'description': '索引号(仅delete模式)'}
|
||
},
|
||
'required': ['operation'],
|
||
}
|
||
|
||
def call(self, params: Union[str, dict], **kwargs) -> str:
|
||
params = self._verify_json_format_args(params)
|
||
op = params['operation'].lower()
|
||
memories = _load_memory()
|
||
|
||
if op == 'add':
|
||
content = params.get('content', '').strip()
|
||
if not content:
|
||
return "错误:内容不能为空。"
|
||
memories.append(content)
|
||
_save_memory(memories)
|
||
return f"✅ 成功存入:『{content}』"
|
||
|
||
elif op == 'list':
|
||
if not memories:
|
||
return "目前没有任何长期记忆。"
|
||
return "记忆列表:\n" + "\n".join([f"[{i}] {m}" for i, m in enumerate(memories)])
|
||
|
||
elif op == 'delete':
|
||
idx = params.get('index')
|
||
if idx is None or not (0 <= idx < len(memories)):
|
||
return f"错误:索引 {idx} 无效。"
|
||
removed = memories.pop(idx)
|
||
_save_memory(memories)
|
||
return f"🗑️ 已删除:『{removed}』"
|
||
|
||
return f"不支持的操作: {op}" |