226 lines
7.1 KiB
Python
226 lines
7.1 KiB
Python
|
|
"""Self-growing error knowledge base: records successfully corrected error cases for future reference.

Principles:
    - Only "new" errors are recorded (deduplicated by a normalised fingerprint
      of the error message).
    - Each record carries the correction: the correction prompt, the tool
      chain, and the JRXML before and after the fix (stored as truncated
      snippets).
    - Records are stored in ChromaDB and can be retrieved and injected into
      generation prompts.

Usage:
    from backend.error_kb import ErrorKB
    kb = ErrorKB()
    kb.record(error_msg, bad_jrxml, good_jrxml, correction_prompt)
    cases = kb.search("字段未声明", k=3)  # query meaning "field not declared"
"""
|
|||
|
|
|
|||
|
|
import hashlib
import json
import logging
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
# Load variables from a local .env file into the process environment
# (python-dotenv) so that the configuration read below, e.g.
# CHROMA_PERSIST_DIR, can be supplied that way.
load_dotenv()
|
|||
|
|
|
|||
|
|
# On-disk location of the ChromaDB store. Override with CHROMA_PERSIST_DIR;
# the default is a path relative to the current working directory.
CHROMA_DIR = Path(os.environ.get("CHROMA_PERSIST_DIR", "./db/chroma"))

# ChromaDB collection that holds the recorded error cases.
COLLECTION_NAME = "jrxml_error_cases"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _make_fingerprint(error_msg: str) -> str:
    """Return a 16-hex-character fingerprint of *error_msg* for deduplication.

    The message is normalised before hashing, so errors that differ only in
    concrete identifiers or literals map to the same fingerprint:

    - lower-case the whole message;
    - replace ``$F{...}`` field references with ``$f{<FIELD>}``;
    - replace single-/double-quoted literals with ``'<VALUE>'`` / ``"<VALUE>"``;
    - replace standalone integers with ``<NUM>``;
    - collapse whitespace runs to a single space and strip the ends.

    NOTE(review): only ``$F{...}`` references are normalised; ``$P{...}`` and
    ``$V{...}`` keep their names. Widening the rules would change the
    fingerprints (which double as ChromaDB ids) of already-stored cases.

    Args:
        error_msg: Raw error message.

    Returns:
        The first 16 hex characters of the MD5 digest of the normalised text.
    """
    text = error_msg.lower()
    # Field references; "$f" is lowercase because the text was lowered above.
    text = re.sub(r'\$f\{[^}]+\}', '$f{<FIELD>}', text)
    # Quoted literals (names, values).
    text = re.sub(r"'[^']*'", "'<VALUE>'", text)
    text = re.sub(r'"[^"]*"', '"<VALUE>"', text)
    # Standalone integers such as line/column numbers. Digits glued to a word
    # (e.g. "field1") have no \b boundary and are kept.
    text = re.sub(r'\b\d+\b', '<NUM>', text)
    text = re.sub(r'\s+', ' ', text).strip()
    # MD5 serves only as a non-cryptographic content hash here. Saying so via
    # usedforsecurity=False keeps it usable on FIPS-restricted builds, where a
    # plain hashlib.md5() call can raise; the digest value is unchanged.
    return hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()[:16]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ErrorKB:
    """Knowledge base of corrected error cases, persisted in ChromaDB.

    The ChromaDB client and collection are created lazily on first use, so
    constructing an ``ErrorKB`` does no I/O and does not import ``chromadb``.
    Read-side operations (:meth:`exists`, :meth:`search`, :meth:`stats`) are
    best-effort: backend failures are logged and mapped to an empty result.
    """

    # Logger for failures that are deliberately swallowed on best-effort paths.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._client = None  # chromadb.PersistentClient, created on first use
        self._collection = None  # ChromaDB collection, created on first use

    @property
    def client(self):
        """Return the ChromaDB ``PersistentClient`` rooted at ``CHROMA_DIR``.

        ``chromadb`` is imported here rather than at module level, so importing
        this module does not require the package.
        """
        if self._client is None:
            import chromadb

            self._client = chromadb.PersistentClient(path=str(CHROMA_DIR))
        return self._client

    @property
    def collection(self):
        """Return the ``COLLECTION_NAME`` collection, creating it if absent.

        Uses ``get_or_create_collection`` instead of "get, and create on any
        exception". The old fallback also fired on unrelated failures (e.g. a
        broken store) and masked the real cause behind a create error.
        """
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(COLLECTION_NAME)
        return self._collection

    def exists(self, error_msg: str) -> bool:
        """Return whether a case with the same fingerprint is already stored.

        Args:
            error_msg: Raw error message; normalised by ``_make_fingerprint``.

        Returns:
            ``True`` if an entry with that fingerprint id exists. Lookup
            failures are logged and reported as ``False``.
        """
        return self._has_fingerprint(_make_fingerprint(error_msg))

    def _has_fingerprint(self, fp: str) -> bool:
        """Return whether an entry with id *fp* exists (``False`` on errors)."""
        try:
            results = self.collection.get(ids=[fp])
            return bool(results and results["ids"])
        except Exception:
            self._logger.warning(
                "ErrorKB: lookup of fingerprint %s failed", fp, exc_info=True
            )
            return False

    def record(
        self,
        error_msg: str,
        bad_jrxml: str,
        good_jrxml: str,
        correction_prompt: str,
        model: str = "",
        retry_count: int = 0,
    ) -> bool:
        """Record a successfully corrected error case.

        The case is written only if its fingerprint is not already stored.

        Args:
            error_msg: The error message that was corrected.
            bad_jrxml: JRXML before the fix (first 2000 chars are stored).
            good_jrxml: JRXML after the fix (first 2000 chars are stored).
            correction_prompt: Prompt used for the fix (first 1500 chars).
            model: Free-text identifier of the model that produced the fix.
            retry_count: Stored as-is in the document and as
                ``retry_count + 1`` in the metadata.

        Returns:
            ``True`` if the case was written, ``False`` if it was a duplicate.

        Raises:
            Whatever ``collection.add`` raises; write failures are not caught.
        """
        # Compute the fingerprint once: it is both the dedup key and the id.
        fp = _make_fingerprint(error_msg)
        if self._has_fingerprint(fp):
            return False

        now = datetime.now(timezone.utc).isoformat()

        # Document body: the structured record, serialised as JSON.
        doc = json.dumps(
            {
                "error": error_msg,
                "bad_jrxml_snippet": bad_jrxml[:2000],
                "good_jrxml_snippet": good_jrxml[:2000],
                "correction_prompt": correction_prompt[:1500],
                "model": model,
                "retry_count": retry_count,
                "recorded_at": now,
                "tools": ["validation_service", "llm_correction"],
            },
            ensure_ascii=False,
        )

        # Metadata: flat scalar fields usable for filtering.
        metadata = {
            "fingerprint": fp,
            "error_keywords": ", ".join(_extract_keywords(error_msg)[:5]),
            "recorded_at": now,
            # Attempt on which the fix succeeded; presumably 1-based, assuming
            # callers pass retry_count=0 for a first-try fix (confirm).
            "retry_success": retry_count + 1,
        }

        # NOTE(review): check-then-add is not atomic; a concurrent writer could
        # insert the same id in between. How ChromaDB treats a duplicate id on
        # add() varies by version (confirm before relying on the return value).
        self.collection.add(ids=[fp], documents=[doc], metadatas=[metadata])
        return True

    def search(self, error_msg: str, k: int = 3) -> list[dict]:
        """Return up to *k* stored cases most similar to *error_msg*.

        The query text is the keywords extracted from *error_msg*. Ranking is
        done by ChromaDB's similarity search.

        Args:
            error_msg: Error message to look up.
            k: Maximum number of cases to return; ``k <= 0`` yields ``[]``.

        Returns:
            A list of dicts with keys ``id``, ``error``, ``fix_snippet``,
            ``prompt``, ``recorded_at`` and ``distance``. The list is empty
            when no keywords can be extracted, nothing is stored, or the query
            fails (failures are logged). Hits whose document is missing or is
            not a JSON object are skipped.
        """
        if k <= 0:
            return []
        keywords = _extract_keywords(error_msg)
        if not keywords:
            return []

        try:
            # Clamp n_results to the number of stored entries. Some ChromaDB
            # versions raise when asked for more results than exist, which made
            # the whole search return [] while the store held fewer than k cases.
            available = self.collection.count()
            if available <= 0:
                return []
            results = self.collection.query(
                query_texts=[" ".join(keywords)],
                n_results=min(k, available),
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._logger.warning("ErrorKB: similarity search failed", exc_info=True)
            return []

        ids = results["ids"]
        if not ids or not ids[0]:
            return []

        output = []
        for doc_id, raw_doc, dist in zip(
            ids[0], results["documents"][0], results["distances"][0]
        ):
            try:
                data = json.loads(raw_doc)
            except (TypeError, json.JSONDecodeError):
                # Missing (None) or non-JSON document: skip this hit.
                continue
            if not isinstance(data, dict):
                continue
            output.append(
                {
                    "id": doc_id,
                    "error": data.get("error", ""),
                    "fix_snippet": data.get("good_jrxml_snippet", ""),
                    "prompt": data.get("correction_prompt", ""),
                    "recorded_at": data.get("recorded_at", ""),
                    "distance": dist,
                }
            )
        return output

    def search_as_context(self, error_msg: str, k: int = 3) -> str:
        """Return similar cases formatted for direct injection into an LLM prompt.

        Each case shows the error (first 200 chars) and the corrected JRXML
        snippet (first 800 chars); cases are joined with ``---`` separators.
        Returns ``""`` when :meth:`search` finds nothing.
        """
        results = self.search(error_msg, k=k)
        if not results:
            return ""

        parts = [
            f"[历史错误案例]\n"
            f"错误: {r['error'][:200]}\n"
            f"修正后 JRXML 片段:\n{r['fix_snippet'][:800]}\n"
            for r in results
        ]
        return "\n---\n".join(parts)

    def stats(self) -> dict:
        """Return ``{"total_cases": <count>, "collection": COLLECTION_NAME}``.

        ``total_cases`` is 0, and the failure is logged, if counting fails.
        """
        try:
            count = self.collection.count()
        except Exception:
            self._logger.warning("ErrorKB: counting cases failed", exc_info=True)
            count = 0
        return {"total_cases": count, "collection": COLLECTION_NAME}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_keywords(error_msg: str) -> list[str]:
    """Extract search keywords from an error message.

    Tokens are collected in this order:

    1. runs of two or more CJK unified ideographs (U+4E00–U+9FFF);
    2. ASCII identifier-like tokens of at least three characters
       (camelCase / snake_case words);
    3. JRXML expression references ``$F{...}``, ``$P{...}`` and ``$V{...}``.

    Duplicates are dropped, keeping first-occurrence order, so a repeated
    word does not crowd out other keywords in the query or the stored
    metadata.

    Args:
        error_msg: Error message; ``None`` or ``""`` yields ``[]``.

    Returns:
        The ordered, de-duplicated keyword list.
    """
    if not error_msg:
        return []
    chinese = re.findall(r'[\u4e00-\u9fff]{2,}', error_msg)
    english = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]{2,}', error_msg)
    # JRXML field / parameter / variable references.
    jrxml_refs = re.findall(r'\$[FPV]\{[^}]*\}', error_msg)
    return list(dict.fromkeys(chinese + english + jrxml_refs))
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level shared instance, created lazily by get_error_kb().
_kb: Optional[ErrorKB] = None


def get_error_kb() -> ErrorKB:
    """Return the shared ErrorKB instance, constructing it on the first call."""
    global _kb
    if _kb is not None:
        return _kb
    _kb = ErrorKB()
    return _kb
|
|||
|
|
|
|||
|
|
|
|||
|
|
def record_error(error_msg: str, bad_jrxml: str, good_jrxml: str,
                 correction_prompt: str, model: str = "", retry_count: int = 0) -> bool:
    """Convenience wrapper: record a corrected error case in the shared knowledge base.

    Delegates to ``record`` on the instance returned by ``get_error_kb()`` and
    passes its result through (``True`` if written, ``False`` if duplicate).
    """
    kb = get_error_kb()
    return kb.record(
        error_msg,
        bad_jrxml,
        good_jrxml,
        correction_prompt,
        model=model,
        retry_count=retry_count,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def search_error_cases(error_msg: str, k: int = 3) -> str:
    """Convenience wrapper: return similar historical error cases as prompt context.

    Delegates to ``search_as_context`` on the instance returned by
    ``get_error_kb()``; the result is an empty string when nothing is found.
    """
    kb = get_error_kb()
    context = kb.search_as_context(error_msg, k=k)
    return context
|