158 lines
6.7 KiB
Python
158 lines
6.7 KiB
Python
|
|
"""field_matcher.py 测试 — OCR 字段 → KB 字段匹配, embedding + LLM。"""
|
||
|
|
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import patch, MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
|
|
||
|
|
from backend.field_matcher import (
|
||
|
|
_cosine_similarity, match_ocr_to_kb, _match_via_llm,
|
||
|
|
format_field_mapping_context,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# ── 余弦相似度 ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestCosineSimilarity:
    """Checks `_cosine_similarity` on vector pairs whose cosine is known.

    Results are compared with `pytest.approx` instead of `==` so the tests pin
    the mathematical value without depending on the exact floating-point
    rounding of whatever arithmetic the implementation uses.
    """

    def test_identical_vectors(self):
        # A vector compared with itself: cosine 1.
        assert _cosine_similarity([1, 0, 0], [1, 0, 0]) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        # Zero dot product: cosine 0 (approx falls back to its absolute
        # tolerance when the expected value is 0).
        assert _cosine_similarity([1, 0, 0], [0, 1, 0]) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        # Same direction with flipped sign: cosine -1.
        assert _cosine_similarity([1, 0], [-1, 0]) == pytest.approx(-1.0)

    def test_normalized_vectors_range(self):
        sim = _cosine_similarity([0.6, 0.8], [0.8, 0.6])
        assert -1.0 <= sim <= 1.0
        # Both inputs have unit length, so under the standard cosine
        # definition the result is just the dot product:
        # 0.6 * 0.8 + 0.8 * 0.6 = 0.96.
        assert sim == pytest.approx(0.96)
|
||
|
|
|
||
|
|
|
||
|
|
# ── LLM 匹配 ────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestMatchViaLlm:
    """Tests for `_match_via_llm` driven by a mocked LLM client.

    The mock's `invoke` either returns an object whose `content` attribute
    holds the reply text, or raises, depending on the scenario under test.
    """

    def test_returns_json_mapping(self):
        # A JSON object reply is expected back as a plain dict.
        mock_llm = MagicMock()
        mock_llm.invoke.return_value.content = (
            '{"工单号": "billNo", "客户": "customerName"}'
        )
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户名称", "type": "String"},
        ]

        result = _match_via_llm(["工单号", "客户"], kb_fields, mock_llm)

        assert result == {"工单号": "billNo", "客户": "customerName"}

    def test_includes_candidates_hint_when_provided(self):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value.content = '{"工单号": "billNo"}'
        candidates = {"工单号": [("billNo", 0.85), ("orderId", 0.62)]}

        result = _match_via_llm(
            ["工单号"],
            [{"name": "billNo", "description": "工单号", "type": "String"}],
            mock_llm,
            candidates=candidates,
        )

        # First positional argument of the recorded `invoke` call, i.e. the
        # prompt text handed to the LLM.
        prompt = mock_llm.invoke.call_args[0][0]
        assert "候选" in prompt
        assert "billNo" in prompt
        # Previously computed but never checked: passing candidates must not
        # change how the reply is turned into the returned mapping.
        assert result == {"工单号": "billNo"}

    def test_llm_error_returns_empty_dict(self):
        # An exception raised by the LLM call is expected to yield {}.
        mock_llm = MagicMock()
        mock_llm.invoke.side_effect = RuntimeError("LLM crash")

        result = _match_via_llm(
            ["x"], [{"name": "y", "description": "", "type": "String"}], mock_llm
        )

        assert result == {}

    def test_llm_returns_invalid_json_returns_empty(self):
        # A reply that is not JSON is expected to yield {}.
        mock_llm = MagicMock()
        mock_llm.invoke.return_value.content = "not json at all"

        result = _match_via_llm(
            ["x"], [{"name": "y", "description": "", "type": "String"}], mock_llm
        )

        assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
# ── 完整匹配流程 ────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestMatchOcrToKb:
    """Tests for `match_ocr_to_kb` with `backend.field_matcher._embed` faked."""

    @pytest.fixture(autouse=True)
    def mock_embed(self):
        """Patch `_embed` for every test in this class with a keyword-based fake.

        The fake returns one of three orthogonal unit vectors depending on
        which keyword the text contains, and a zero vector when none match.
        Yields the patched mock so a test can inspect it if it asks for the
        fixture by name.
        """
        # Rules are checked in order; the first rule with any keyword found
        # in the text decides the vector.
        keyword_vectors = (
            (("billNo", "工单"), (1.0, 0.0, 0.0)),
            (("customerName", "客户"), (0.0, 1.0, 0.0)),
            (("amount", "金额"), (0.0, 0.0, 1.0)),
        )

        def _fake_embed(text):
            for keywords, vector in keyword_vectors:
                if any(keyword in text for keyword in keywords):
                    # Fresh list per call so callers can never share state.
                    return list(vector)
            return [0.0, 0.0, 0.0]

        # Bound to a distinct name so it does not shadow the fixture itself.
        with patch("backend.field_matcher._embed") as patched_embed:
            patched_embed.side_effect = _fake_embed
            yield patched_embed

    def test_matches_without_llm(self):
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户名称", "type": "String"},
            {"name": "amount", "description": "金额", "type": "BigDecimal"},
        ]

        mapping = match_ocr_to_kb(
            ["工单号", "客户名称", "金额"], kb_fields, llm=None)

        assert mapping["工单号"] == "billNo"
        assert mapping["客户名称"] == "customerName"
        assert mapping["金额"] == "amount"

    def test_empty_inputs_return_empty(self):
        # Either side being empty must produce an empty mapping.
        assert match_ocr_to_kb([], [], llm=None) == {}
        assert match_ocr_to_kb(["x"], [], llm=None) == {}
        assert match_ocr_to_kb(
            [], [{"name": "y", "description": "", "type": "String"}], llm=None
        ) == {}

    def test_low_similarity_not_matched(self):
        # Neither text contains a keyword, so the fake embeds both as zero
        # vectors.
        kb_fields = [{"name": "far", "description": "不相关字段", "type": "String"}]

        mapping = match_ocr_to_kb(["无关"], kb_fields, llm=None)

        # The former `or mapping == {}` alternative was redundant: an empty
        # mapping already satisfies this membership check.
        assert "无关" not in mapping

    def test_uses_llm_when_provided(self):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value.content = (
            '{"工单号": "billNo", "客户名称": "customerName"}'
        )
        kb_fields = [
            {"name": "billNo", "description": "工单号", "type": "String"},
            {"name": "customerName", "description": "客户", "type": "String"},
        ]

        mapping = match_ocr_to_kb(["工单号", "客户名称"], kb_fields, llm=mock_llm)

        assert mapping["工单号"] == "billNo"
        # Both requested fields are checked; the fake embedding and the LLM
        # reply agree on this pair, so the result does not depend on which
        # of the two sources wins.
        assert mapping["客户名称"] == "customerName"

    def test_embedding_failure_falls_back_to_llm(self):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value.content = '{"工单号": "billNo"}'
        kb_fields = [{"name": "billNo", "description": "工单号", "type": "String"}]

        # Re-patch `_embed` on top of the autouse fixture so that every
        # embedding call raises for the duration of this call only.
        with patch("backend.field_matcher._embed", side_effect=RuntimeError("model error")):
            mapping = match_ocr_to_kb(["工单号"], kb_fields, llm=mock_llm)

        assert mapping["工单号"] == "billNo"
|
||
|
|
|
||
|
|
|
||
|
|
# ── 格式化上下文 ────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestFormatFieldMappingContext:
    """Tests for the text produced by `format_field_mapping_context`."""

    def test_formats_mapping_as_table(self):
        # The rendered text must carry the section marker, a `$P{...}`
        # placeholder for each mapped name, and each source label.
        rendered = format_field_mapping_context(
            {"工单号": "billNo", "客户": "customerName"}
        )
        expected_fragments = (
            "[字段映射",
            "$P{billNo}",
            "$P{customerName}",
            "工单号",
            "客户",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_empty_mapping_returns_empty_string(self):
        # An empty mapping renders as an empty string.
        rendered = format_field_mapping_context({})
        assert rendered == ""

    def test_single_entry(self):
        single = {"发票号码": "invoiceNo"}
        assert "$P{invoiceNo}" in format_field_mapping_context(single)
|