Files
agent_jrxml/backend/field_matcher.py
T

137 lines
4.4 KiB
Python
Raw Normal View History

"""OCR 字段 → KB 字段匹配模块。
两阶段匹配:
1. Embedding 粗筛(相似度 top-3
2. LLM 精确确认
返回映射: {"工单号": "billNo", "客户名称": "customerName", ...}
"""
import json
import os
from typing import Optional
from dotenv import load_dotenv
from backend.logger import get_logger
load_dotenv()
_match_log = get_logger("field_matcher")
def _embed(text: str) -> list:
    """Return the embedding of *text* as a plain Python list.

    The vector comes from the shared searcher's ``model.encode`` with
    ``normalize_embeddings=True``, so callers may treat a dot product of two
    results as their cosine similarity. ``encode``'s return value is converted
    with ``.tolist()`` (presumably a NumPy array — TODO confirm).
    """
    # Local import, kept from the original (presumably to avoid an import
    # cycle or eager model loading at module import — TODO confirm).
    from backend.rag_adapter import _get_searcher

    searcher = _get_searcher()
    if searcher._model is None:
        # Touch the ``model`` attribute once while the private ``_model`` slot
        # is still empty (presumably a lazy-loading property — TODO confirm).
        _warmed = searcher.model  # noqa: F841
    model = searcher.model
    vector = model.encode(
        text,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return vector.tolist()
def _cosine_similarity(a: list, b: list) -> float:
"""余弦相似度(假设向量已归一化,点积即相似度)。"""
return sum(x * y for x, y in zip(a, b))
def match_ocr_to_kb(ocr_fields: list[str], kb_fields: list[dict],
                    llm=None, *, top_k: int = 3,
                    threshold: float = 0.5) -> dict[str, str]:
    """Match OCR-extracted field names to knowledge-base (KB) field definitions.

    Stage 1 embeds every OCR name and every KB field description (falling back
    to the KB field name when the description is missing or empty) and keeps
    the ``top_k`` most similar KB fields per OCR name. Stage 2 asks ``llm`` to
    confirm the mapping, using those candidates as hints. When no LLM is given,
    or the LLM stage yields an empty mapping, the best candidate for each OCR
    name is accepted only if its similarity is strictly greater than
    ``threshold``.

    Args:
        ocr_fields: Field names extracted by OCR (e.g. Chinese labels).
        kb_fields: KB field definitions, e.g.
            ``[{"name": "billNo", "description": "工单号", ...}]``.
        llm: Optional LLM object exposing ``invoke(prompt)``; used for the
            confirmation stage.
        top_k: Number of embedding candidates kept per OCR field (positive).
        threshold: Similarity that the best candidate must exceed to be
            accepted without LLM confirmation.

    Returns:
        Mapping OCR field name -> KB field name, e.g.
        ``{"工单号": "billNo", "客户": "customerName"}``. Empty when either
        input is empty or no match could be established.
    """
    if not ocr_fields or not kb_fields:
        return {}

    # Stage 1: embedding-based coarse screening.
    try:
        ocr_embs = {name: _embed(name) for name in ocr_fields}
        # `or` instead of a .get() default: an empty or None description
        # must fall back to the field name rather than being embedded as-is.
        kb_embs = {
            f["name"]: _embed(f.get("description") or f["name"])
            for f in kb_fields
        }
    except Exception as e:
        if llm is None:
            # Nothing to fall back to; calling _match_via_llm with a None LLM
            # would only fail a second time.
            _match_log.warning("Embedding 匹配失败,且未提供 LLM: %s", e)
            return {}
        _match_log.warning("Embedding 匹配失败,回退到 LLM: %s", e)
        return _match_via_llm(ocr_fields, kb_fields, llm)

    candidates: dict[str, list[tuple[str, float]]] = {}
    for ocr_name, ocr_emb in ocr_embs.items():
        scored = sorted(
            ((kb_name, _cosine_similarity(ocr_emb, kb_emb))
             for kb_name, kb_emb in kb_embs.items()),
            key=lambda item: item[1],
            reverse=True,
        )
        candidates[ocr_name] = scored[:top_k]

    # Stage 2: LLM confirmation, with the embedding candidates as hints.
    if llm is not None:
        confirmed = _match_via_llm(ocr_fields, kb_fields, llm, candidates)
        if confirmed:
            return confirmed
        # _match_via_llm returns {} when the LLM call or its JSON parsing
        # fails; reuse the embedding scores instead of discarding them.
        # Trade-off: a deliberate "nothing matches" answer is also
        # overridden by the threshold rule below.
        _match_log.warning("LLM 未返回有效映射,回退到向量相似度匹配")

    # Threshold rule: accept the best candidate only if it is similar enough.
    return {
        ocr_name: cands[0][0]
        for ocr_name, cands in candidates.items()
        if cands and cands[0][1] > threshold
    }
def _match_via_llm(ocr_fields: list[str], kb_fields: list[dict],
llm, candidates: Optional[dict] = None) -> dict[str, str]:
"""使用 LLM 精确确认字段映射。"""
kb_desc = "\n".join(
f"- {f['name']}: {f.get('description', '')} ({f.get('type', 'java.lang.String')})"
for f in kb_fields
)
candidates_hint = ""
if candidates:
cand_lines = []
for ocr_name, cands in candidates.items():
cand_str = ", ".join(f"{n}({s:.2f})" for n, s in cands)
cand_lines.append(f" {ocr_name} -> 候选: {cand_str}")
candidates_hint = (
"向量相似度候选(仅供参考,请根据语义确认):\n"
+ "\n".join(cand_lines)
)
prompt = (
"请将以下 OCR 识别的字段名匹配到知识库定义的字段。\n\n"
f"OCR 字段: {json.dumps(ocr_fields, ensure_ascii=False)}\n\n"
f"知识库字段:\n{kb_desc}\n\n"
f"{candidates_hint}\n\n"
"请以 JSON 对象格式输出映射关系,键为 OCR 字段名,值为 KB 字段名:\n"
'{"工单号": "billNo", "客户名称": "customerName"}'
)
try:
response = llm.invoke(prompt)
content = response.content if hasattr(response, "content") else str(response)
start = content.find("{")
end = content.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(content[start:end])
except Exception as e:
_match_log.warning("LLM 字段匹配失败: %s", e)
return {}
def format_field_mapping_context(mapping: dict[str, str]) -> str:
    """Render an OCR -> KB field mapping as a Markdown table for a prompt.

    Each KB field name is wrapped as a JRXML parameter reference ``$P{name}``,
    one row per mapping entry in iteration order. An empty mapping yields an
    empty string so callers can omit the section entirely.
    """
    if not mapping:
        return ""
    header = [
        "[字段映射 — OCR -> KB]",
        "请在 JRXML 中使用以下参数名:",
        "| OCR 字段 | JRXML 参数 |",
        "|---|---|",
    ]
    rows = [f"| {ocr} | $P{{{kb}}} |" for ocr, kb in mapping.items()]
    return "\n".join(header + rows)