137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
|
|
"""OCR 字段 → KB 字段匹配模块。
|
|||
|
|
|
|||
|
|
两阶段匹配:
|
|||
|
|
1. Embedding 粗筛(相似度 top-3)
|
|||
|
|
2. LLM 精确确认
|
|||
|
|
|
|||
|
|
返回映射: {"工单号": "billNo", "客户名称": "customerName", ...}
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
from backend.logger import get_logger
|
|||
|
|
|
|||
|
|
# Load variables from a .env file (if present) into the environment at import time.
load_dotenv()
|
|||
|
|
|
|||
|
|
# Module-level logger, obtained from the project's logging helper.
_match_log = get_logger("field_matcher")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _embed(text: str) -> list:
    """Encode ``text`` with the shared searcher's model and return a plain list.

    The encoder is asked for normalized embeddings with its progress bar
    disabled; the result is converted with ``tolist()`` before being returned.
    """
    # Imported at call time rather than at module import.
    from backend.rag_adapter import _get_searcher

    engine = _get_searcher()
    model_not_ready = engine._model is None
    if model_not_ready:
        # The attribute is read only for its effect; presumably ``model`` is a
        # lazy property that populates ``_model`` -- confirm in rag_adapter.
        _ = engine.model
    vector = engine.model.encode(
        text,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return vector.tolist()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _cosine_similarity(a: list, b: list) -> float:
    """Return the dot product of ``a`` and ``b``.

    For unit-length vectors the dot product equals the cosine similarity, so
    no normalization is performed here. Components are paired positionally and
    any trailing components of the longer vector are ignored.
    """
    # ``sum`` is kept (rather than a manual accumulator) so the float
    # summation behaviour stays exactly that of the built-in.
    pairwise_products = map(lambda left, right: left * right, a, b)
    return sum(pairwise_products)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def match_ocr_to_kb(ocr_fields: list[str], kb_fields: list[dict],
                    llm=None) -> dict[str, str]:
    """Map OCR-extracted field names onto knowledge-base (KB) field names.

    Matching runs in two stages:

    1. Embedding pre-filter: each OCR name is compared against every KB field
       and the three most similar KB fields are kept as candidates.
    2. Confirmation: when ``llm`` is given, the candidates are handed to it as
       a hint and its answer is returned. Without an LLM, the best candidate
       is accepted only when its similarity is greater than 0.5.

    If computing the embeddings fails, matching is delegated to the LLM alone;
    when no LLM is available either, an empty mapping is returned.

    Args:
        ocr_fields: Field names extracted by OCR.
        kb_fields: KB field definitions such as
            ``{"name": "billNo", "description": "工单号", ...}``. ``name`` is
            required. ``description`` is optional; a missing or empty
            description falls back to ``name`` for the embedding.
        llm: Optional object exposing ``invoke(prompt)``, used for confirmation.

    Returns:
        Mapping from OCR field name to KB field name, e.g.
        ``{"工单号": "billNo", "客户": "customerName"}``. Empty when either
        input is empty or nothing could be matched.
    """
    if not ocr_fields or not kb_fields:
        return {}

    top_k = 3             # number of KB candidates kept per OCR field
    min_similarity = 0.5  # acceptance threshold when no LLM confirms

    # Stage 1: embedding pre-filter.
    try:
        ocr_embs = {name: _embed(name) for name in ocr_fields}
        kb_embs = {
            # ``or`` (not a .get default) so that an empty or None description
            # does not end up being embedded instead of the field name.
            f["name"]: _embed(f.get("description") or f["name"])
            for f in kb_fields
        }
    except Exception as e:
        _match_log.warning("Embedding 匹配失败,回退到 LLM: %s", e)
        if not llm:
            # Nothing to fall back to; avoid invoking a missing LLM.
            return {}
        return _match_via_llm(ocr_fields, kb_fields, llm)

    candidates = {}
    for ocr_name, ocr_emb in ocr_embs.items():
        scored = [
            (kb_name, _cosine_similarity(ocr_emb, kb_emb))
            for kb_name, kb_emb in kb_embs.items()
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        candidates[ocr_name] = scored[:top_k]

    # Stage 2: LLM confirmation, or a similarity threshold when no LLM is given.
    if llm:
        return dict(_match_via_llm(ocr_fields, kb_fields, llm, candidates))

    return {
        ocr_name: ranked[0][0]
        for ocr_name, ranked in candidates.items()
        if ranked and ranked[0][1] > min_similarity
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _match_via_llm(ocr_fields: list[str], kb_fields: list[dict],
                   llm, candidates: Optional[dict] = None) -> dict[str, str]:
    """Ask the LLM for the OCR -> KB field mapping and return the usable part.

    A prompt listing the OCR names, the KB fields (name, description, type)
    and, when given, the embedding candidates is sent through
    ``llm.invoke``. The first ``{`` through the last ``}`` of the reply is
    parsed as JSON.

    Args:
        ocr_fields: Field names extracted by OCR.
        kb_fields: KB field definitions; each must contain ``name`` and may
            contain ``description`` and ``type``.
        llm: Object exposing ``invoke(prompt)``. The reply's ``content``
            attribute is used when present, otherwise ``str(reply)``.
            ``None`` yields an empty mapping without any call.
        candidates: Optional ``{ocr_name: [(kb_name, similarity), ...]}``
            included in the prompt as a hint.

    Returns:
        ``{ocr_name: kb_name}`` restricted to entries whose value is a string
        naming one of ``kb_fields``. Empty when there is no LLM, the call or
        the parsing fails, or the reply contains no JSON object.
    """
    if llm is None:
        # No model to ask; do not rely on an AttributeError being swallowed.
        return {}

    kb_desc = "\n".join(
        f"- {f['name']}: {f.get('description', '')} ({f.get('type', 'java.lang.String')})"
        for f in kb_fields
    )

    candidates_hint = ""
    if candidates:
        cand_lines = []
        for ocr_name, cands in candidates.items():
            cand_str = ", ".join(f"{n}({s:.2f})" for n, s in cands)
            cand_lines.append(f"  {ocr_name} -> 候选: {cand_str}")
        candidates_hint = (
            "向量相似度候选(仅供参考,请根据语义确认):\n"
            + "\n".join(cand_lines)
        )

    prompt = (
        "请将以下 OCR 识别的字段名匹配到知识库定义的字段。\n\n"
        f"OCR 字段: {json.dumps(ocr_fields, ensure_ascii=False)}\n\n"
        f"知识库字段:\n{kb_desc}\n\n"
        f"{candidates_hint}\n\n"
        "请以 JSON 对象格式输出映射关系,键为 OCR 字段名,值为 KB 字段名:\n"
        '{"工单号": "billNo", "客户名称": "customerName"}'
    )

    try:
        response = llm.invoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        # Tolerate prose around the JSON: take the outermost brace pair.
        start = content.find("{")
        end = content.rfind("}") + 1
        if start < 0 or end <= start:
            return {}
        parsed = json.loads(content[start:end])
    except Exception as e:
        # Best effort: any failure in the call or the parsing yields no mapping.
        _match_log.warning("LLM 字段匹配失败: %s", e)
        return {}

    if not isinstance(parsed, dict):
        return {}

    # Drop null / non-string values and KB names the model made up, so that
    # callers only ever see names that exist in ``kb_fields``.
    known_kb_names = {f["name"] for f in kb_fields}
    return {
        ocr_name: kb_name
        for ocr_name, kb_name in parsed.items()
        if isinstance(kb_name, str) and kb_name in known_kb_names
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def format_field_mapping_context(mapping: dict[str, str]) -> str:
    """Render an OCR -> KB field mapping as a Markdown table for a prompt.

    Returns an empty string when ``mapping`` is empty. Otherwise the output is
    a title line, an instruction line, a two-column table header and one row
    per entry, with the KB name wrapped as ``$P{name}``.
    """
    if not mapping:
        return ""

    header = (
        "[字段映射 — OCR -> KB]",
        "请在 JRXML 中使用以下参数名:",
        "| OCR 字段 | JRXML 参数 |",
        "|---|---|",
    )
    rows = (
        "| {} | $P{{{}}} |".format(ocr_label, kb_param)
        for ocr_label, kb_param in mapping.items()
    )
    return "\n".join([*header, *rows])
|