797 lines
26 KiB
Python
797 lines
26 KiB
Python
|
|
"""OCR 单据字段精确提取器。
|
|||
|
|
|
|||
|
|
两阶段提取流程:
|
|||
|
|
阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout()
|
|||
|
|
获取每个文本元素的精确坐标和内容
|
|||
|
|
阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、
|
|||
|
|
正则模式匹配、表格结构匹配)提取字段值、位置和置信度
|
|||
|
|
|
|||
|
|
用法:
|
|||
|
|
from backend.ocr_extractor import OcrExtractor
|
|||
|
|
|
|||
|
|
extractor = OcrExtractor()
|
|||
|
|
result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"])
|
|||
|
|
for field in result:
|
|||
|
|
print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})")
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any, Optional
|
|||
|
|
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
# Pull settings from a local .env file (if present) into the process environment
# before the configuration values below are read.
load_dotenv()

# GPU switch for the OCR engines: "true" / "1" / "yes" (any case) enable it.
OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() in {"true", "1", "yes"}

# Minimum per-element OCR confidence; elements scoring lower are discarded.
OCR_CONFIDENCE_THRESHOLD = float(
    os.getenv("OCR_CONFIDENCE_THRESHOLD", "0.5")
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class OcrTextElement:
    """One recognized text fragment plus its axis-aligned bounding box."""

    text: str  # recognized text
    x_min: float  # smallest x of the box
    y_min: float  # smallest y of the box
    x_max: float  # largest x of the box
    y_max: float  # largest y of the box
    confidence: float = 1.0  # recognition score; 1.0 when the source gave none

    @staticmethod
    def _midpoint(low: float, high: float) -> float:
        """Return the midpoint of the interval [low, high]."""
        return (low + high) / 2

    @property
    def bbox(self) -> list[float]:
        """Box as ``[x_min, y_min, x_max, y_max]``."""
        return [self.x_min, self.y_min, self.x_max, self.y_max]

    @property
    def width(self) -> float:
        """Horizontal extent of the box."""
        return self.x_max - self.x_min

    @property
    def height(self) -> float:
        """Vertical extent of the box."""
        return self.y_max - self.y_min

    @property
    def center_x(self) -> float:
        """Horizontal center of the box."""
        return self._midpoint(self.x_min, self.x_max)

    @property
    def center_y(self) -> float:
        """Vertical center of the box."""
        return self._midpoint(self.y_min, self.y_max)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ExtractedField:
    """Outcome of extracting one target field from a document."""

    field_name: str  # the requested field (label) name
    field_value: str  # extracted value; "" when nothing was found
    bbox: list[float]  # [x_min, y_min, x_max, y_max] of the value; [] if not found
    confidence: float  # heuristic score assigned by the extraction strategy
    extraction_method: str  # name of the strategy that produced the value
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ExtractionResult:
    """Everything produced by a single extraction run."""

    file_path: str  # processed image path (or a placeholder label)
    image_size: tuple[int, int]  # image size as reported by stage 1; (0, 0) if unknown
    fields: list[ExtractedField] = field(default_factory=list)  # one entry per target field
    all_elements: list[OcrTextElement] = field(default_factory=list)  # every OCR element kept
    errors: list[str] = field(default_factory=list)  # human-readable problems encountered
    ocr_available: bool = False  # whether an OCR engine could be used

    def to_dict(self) -> dict:
        """Serialize the result into a plain dict.

        Only the element *count* is exported (``total_elements``), not the
        elements themselves.
        """
        serialized = []
        for item in self.fields:
            serialized.append(
                {
                    "field_name": item.field_name,
                    "field_value": item.field_value,
                    "bbox": item.bbox,
                    "confidence": item.confidence,
                    "extraction_method": item.extraction_method,
                }
            )

        payload = {}
        payload["file_path"] = self.file_path
        payload["image_size"] = self.image_size
        payload["ocr_available"] = self.ocr_available
        payload["fields"] = serialized
        payload["total_elements"] = len(self.all_elements)
        payload["errors"] = self.errors
        return payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OcrExtractor:
    """Precise field extractor for OCR'd documents.

    Two-stage pipeline:
        Stage 1: run OCR on the uploaded image and produce text elements with
                 bounding-box coordinates.
        Stage 2: for every target field, try four extraction strategies in
                 priority order and keep the first one that yields a value.
    """

    def __init__(
        self,
        use_gpu: bool = False,
        confidence_threshold: float = 0.5,
    ):
        """Initialize the extractor.

        Args:
            use_gpu: Whether to run OCR on a GPU (requires matching drivers).
                A falsy value defers to the ``OCR_USE_GPU`` environment
                setting, so ``False`` cannot override an env value of true.
            confidence_threshold: Minimum OCR confidence; lower-scoring
                elements are ignored. The default ``0.5`` means "not
                specified" and is replaced by ``OCR_CONFIDENCE_THRESHOLD``.
        """
        # The default argument values double as "unset" sentinels so that the
        # environment configuration applies when callers pass the defaults.
        self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU
        self.confidence_threshold = (
            confidence_threshold
            if confidence_threshold != 0.5
            else OCR_CONFIDENCE_THRESHOLD
        )
        # Lazily-built OCR engines, reused across calls (see _try_easyocr /
        # _try_paddleocr). Built with the use_gpu value current at first use.
        self._easyocr_reader = None
        self._paddleocr_engine = None

    # ========================================================================
    # Public interface
    # ========================================================================

    def extract(
        self,
        file_path: str,
        target_fields: list[str],
    ) -> dict:
        """Run the two-stage OCR field extraction on an image file.

        Args:
            file_path: Image path (png/jpg/jpeg/bmp/webp).
            target_fields: Field names to extract, e.g.
                ["发票代码", "发票号码", "合计金额"].

        Returns:
            Result dict in the format of ``ExtractionResult.to_dict()``. Every
            target field gets an entry; unmatched ones have an empty value and
            ``extraction_method == "none"``.
        """
        result = ExtractionResult(file_path=file_path, image_size=(0, 0))

        if not Path(file_path).exists():
            result.errors.append(f"文件不存在: {file_path}")
            return result.to_dict()

        elements, image_size = self._analyze_document(file_path)
        result.image_size = image_size
        result.all_elements = elements

        if not elements:
            result.ocr_available = self._check_ocr_availability()
            if image_size == (0, 0):
                # _analyze_document reports an image that could not be loaded
                # as size (0, 0); say so instead of blaming OCR or the image.
                result.errors.append(f"无法读取图片: {file_path}")
            elif not result.ocr_available:
                result.errors.append(
                    "OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr"
                )
            else:
                result.errors.append("图片未检测到文字元素")
            return result.to_dict()

        result.ocr_available = True
        self._fill_fields(result, elements, target_fields)
        return result.to_dict()

    def extract_from_layout_result(
        self,
        layout_result: dict,
        target_fields: list[str],
    ) -> dict:
        """Extract fields directly from an existing layout-analysis result.

        Skips stage 1 (no OCR is re-run) and goes straight to stage 2.
        Elements carrying a ``"confidence"`` below ``self.confidence_threshold``
        are dropped, exactly as in :meth:`extract`. Elements without a
        confidence are kept.

        Args:
            layout_result: Dict with a ``"rows"`` list whose items hold an
                ``"elements"`` list of ``{x, y, w, h, text[, confidence]}``
                dicts, and optionally an ``"image_size"``.
            target_fields: Field names to extract.

        Returns:
            Result dict in the format of ``ExtractionResult.to_dict()``.
        """
        image_size = layout_result.get("image_size", (0, 0))
        rows = layout_result.get("rows", [])
        if not rows:
            return ExtractionResult(
                file_path="(from layout)",
                image_size=image_size,
                errors=["版面分析结果中没有文本行"],
            ).to_dict()

        elements = [
            self._element_from_dict(elem_data)
            for row in rows
            for elem_data in row.get("elements", [])
            if elem_data.get("confidence", 1.0) >= self.confidence_threshold
        ]
        # Same reading order as the OCR path so "first match" strategies agree.
        elements.sort(key=lambda e: (e.y_min, e.x_min))

        result = ExtractionResult(
            file_path="(from layout)",
            image_size=image_size,
            all_elements=elements,
            ocr_available=True,
        )
        self._fill_fields(result, elements, target_fields)
        return result.to_dict()

    def _fill_fields(
        self,
        result: ExtractionResult,
        elements: list[OcrTextElement],
        target_fields: list[str],
    ) -> None:
        """Run stage 2 for each target field and append the outcome to *result*."""
        for field_name in target_fields:
            extracted = self._extract_field(field_name, elements)
            if extracted is None:
                # Keep one entry per requested field, even when nothing matched.
                extracted = ExtractedField(
                    field_name=field_name,
                    field_value="",
                    bbox=[],
                    confidence=0.0,
                    extraction_method="none",
                )
            result.fields.append(extracted)

    @staticmethod
    def _element_from_dict(data: dict) -> OcrTextElement:
        """Build an OcrTextElement from an ``{x, y, w, h, text, confidence}`` dict."""
        x = data.get("x", 0)
        y = data.get("y", 0)
        return OcrTextElement(
            text=data.get("text", ""),
            x_min=x,
            y_min=y,
            x_max=x + data.get("w", 0),
            y_max=y + data.get("h", 0),
            confidence=data.get("confidence", 1.0),
        )

    # ========================================================================
    # Stage 1: document analysis
    # ========================================================================

    def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]:
        """Stage 1: load the image and OCR it into positioned text elements.

        Returns:
            ``(elements, image_size)``. Elements are filtered by the confidence
            threshold and sorted by ``(y_min, x_min)``. ``([], (0, 0))`` is
            returned when the image cannot be loaded.
        """
        from backend.layout_analyzer import _load_image

        img = _load_image(Path(file_path))
        if img is None:
            return [], (0, 0)

        raw_elements = self._ocr_elements_enhanced(img, file_path)
        elements = [
            self._element_from_dict(data)
            for data in raw_elements
            if data.get("confidence", 1.0) >= self.confidence_threshold
        ]
        elements.sort(key=lambda e: (e.y_min, e.x_min))
        return elements, img.size

    def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]:
        """Run OCR with EasyOCR, falling back to PaddleOCR.

        Returns:
            Element dicts with confidence, or ``[]`` when NumPy is missing, the
            image cannot be converted, or neither engine produced any text.
        """
        try:
            import numpy as np

            np_img = np.array(img)
        except Exception:
            # Both engines consume NumPy arrays; without one no OCR is possible.
            return []

        easyocr_result = self._try_easyocr(np_img)
        if easyocr_result:
            return easyocr_result
        return self._try_paddleocr(img, file_path) or []

    @staticmethod
    def _polygon_to_element(box, text: str, confidence) -> dict:
        """Convert an engine's polygon box + text + score into an element dict.

        Coordinates and the score are cast to built-in ``float`` so the
        resulting bbox stays JSON-serializable even if the engine hands back
        NumPy scalar types.
        """
        xs = [float(point[0]) for point in box]
        ys = [float(point[1]) for point in box]
        x_min, x_max = min(xs), max(xs)
        y_min, y_max = min(ys), max(ys)
        return {
            "x": round(x_min, 1),
            "y": round(y_min, 1),
            "w": round(x_max - x_min, 1),
            "h": round(y_max - y_min, 1),
            "text": text.strip(),
            "confidence": round(float(confidence), 4),
        }

    def _try_easyocr(self, np_img) -> Optional[list[dict]]:
        """OCR *np_img* with EasyOCR (simplified Chinese + English).

        Returns:
            Element dicts sorted by ``(y, x)``, or ``None`` when EasyOCR is not
            installed or fails at runtime.
        """
        try:
            import easyocr
        except ImportError:
            return None

        try:
            reader = getattr(self, "_easyocr_reader", None)
            if reader is None:
                # Building a Reader loads model weights into memory; do it once
                # per extractor instead of once per image.
                reader = easyocr.Reader(
                    ["ch_sim", "en"],
                    gpu=self.use_gpu,
                    verbose=False,
                )
                self._easyocr_reader = reader
            raw_result = reader.readtext(np_img)
            elements = [
                self._polygon_to_element(box, text, confidence)
                for box, text, confidence in raw_result
                if text.strip()
            ]
        except Exception:
            # Best effort: a runtime failure lets the caller try the next engine.
            return None

        elements.sort(key=lambda e: (e["y"], e["x"]))
        return elements

    def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]:
        """OCR *img* with PaddleOCR (Chinese model).

        ``file_path`` is accepted for interface compatibility but not used.

        Returns:
            Element dicts sorted by ``(y, x)``, or ``None`` when PaddleOCR is
            not installed or fails at runtime.
        """
        try:
            from paddleocr import PaddleOCR
            import numpy as np
        except ImportError:
            return None

        try:
            engine = getattr(self, "_paddleocr_engine", None)
            if engine is None:
                # Same reasoning as EasyOCR: construct the engine only once.
                engine = PaddleOCR(lang="ch")
                self._paddleocr_engine = engine
            raw_result = engine.ocr(np.array(img))

            # Assumes the result layout ``[[box, (text, score)], ...]`` for the
            # first page — TODO confirm against the installed PaddleOCR version.
            page = raw_result[0] if raw_result else None
            elements = []
            for line in page or []:
                if len(line) < 2:
                    continue
                box, text_info = line[0], line[1]
                if isinstance(text_info, (list, tuple)):
                    text = str(text_info[0])
                    confidence = text_info[1] if len(text_info) > 1 else 1.0
                else:
                    text = str(text_info)
                    confidence = 1.0
                if text.strip():
                    elements.append(self._polygon_to_element(box, text, confidence))
        except Exception:
            return None

        elements.sort(key=lambda e: (e["y"], e["x"]))
        return elements

    def _check_ocr_availability(self) -> bool:
        """Return True if easyocr or paddleocr can be imported."""
        import importlib

        for module_name in ("easyocr", "paddleocr"):
            try:
                importlib.import_module(module_name)
            except ImportError:
                continue
            return True
        return False

    # ========================================================================
    # Stage 2: precise field extraction
    # ========================================================================

    def _extract_field(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Extract one field by trying four strategies in priority order.

        Priority:
            1. exact key/value match inside one element
            2. fuzzy key/value match across neighbouring elements
            3. predefined regex pattern match
            4. table header/column match

        Returns ``None`` when no strategy finds a non-empty value, or when the
        field name is blank (an empty label would match every element).
        """
        if not field_name.strip():
            return None

        strategies = (
            ("exact_match", self._exact_kv_match),
            ("kv_pair", self._fuzzy_kv_match),
            ("regex", self._regex_match),
            ("table_match", self._table_match),
        )
        for method_name, strategy_fn in strategies:
            result = strategy_fn(field_name, elements)
            if result is not None and result.field_value:
                result.extraction_method = method_name
                return result
        return None

    # -----------------------------------------------------------------------
    # Strategy 1: exact key/value match
    # -----------------------------------------------------------------------

    def _exact_kv_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Find "label<sep>value" (or "label value") inside a single element.

        Example: OCR produced the single element "发票代码: 12345678".
        A separator match scores 0.95; a whitespace-only match scores 0.85.
        """
        # The full-width colon (U+FF1A) is the usual separator in Chinese text.
        separators = (":", "\uff1a", "=", "-", "\u2014", "\t", "|")
        label = field_name.strip()
        escaped = re.escape(label)

        # Compile once per call rather than once per element.
        sep_patterns = [
            re.compile(escaped + r"\s*" + re.escape(sep) + r"\s*(.+)")
            for sep in separators
        ]
        space_pattern = re.compile(escaped + r"\s+(.+)")

        for elem in elements:
            text = elem.text
            if label not in text:
                continue

            for pattern in sep_patterns:
                m = pattern.search(text)
                if m:
                    value = m.group(1).strip()
                    if value:
                        return ExtractedField(
                            field_name=field_name,
                            field_value=value,
                            bbox=elem.bbox,
                            confidence=0.95,
                            extraction_method="",
                        )

            m = space_pattern.search(text)
            if m:
                value = m.group(1).strip()
                if value and value != label:
                    return ExtractedField(
                        field_name=field_name,
                        field_value=value,
                        bbox=elem.bbox,
                        confidence=0.85,
                        extraction_method="",
                    )

        return None

    # -----------------------------------------------------------------------
    # Strategy 2: fuzzy key/value match
    # -----------------------------------------------------------------------

    def _fuzzy_kv_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Match a label and its value that live in separate, nearby elements.

        The label element is the first one containing the field name, or else
        the first element with the highest similarity above 0.6. The value is
        the nearest element to its right on the same line (score 0.75),
        otherwise the closest element below it within three label heights
        (score 0.6).
        """
        label = field_name.strip()

        field_elem = next((e for e in elements if label in e.text), None)
        if field_elem is None:
            best_sim = 0.6
            for elem in elements:
                sim = self._text_similarity(label, elem.text)
                if sim > best_sim:
                    best_sim, field_elem = sim, elem

        if field_elem is None:
            return None

        candidates = [e for e in elements if e is not field_elem]

        # Same visual line: centers within 1.5 label heights vertically.
        row_tolerance = field_elem.height * 1.5
        same_row = sorted(
            (
                e
                for e in candidates
                if abs(e.center_y - field_elem.center_y) < row_tolerance
            ),
            key=lambda e: e.x_min,
        )
        right = next((e for e in same_row if e.x_min > field_elem.x_max), None)
        if right is not None:
            return ExtractedField(
                field_name=field_name,
                field_value=right.text,
                bbox=right.bbox,
                confidence=0.75,
                extraction_method="",
            )

        # Below the label: vertical gap weighted more than horizontal offset.
        max_gap = field_elem.height * 3
        nearest = None
        nearest_dist = float("inf")
        for elem in candidates:
            if elem.y_min <= field_elem.y_max:
                continue
            dy = elem.y_min - field_elem.y_max
            dist = dy + abs(elem.center_x - field_elem.center_x) * 0.3
            if dist < nearest_dist and dy < max_gap:
                nearest, nearest_dist = elem, dist

        if nearest is not None:
            return ExtractedField(
                field_name=field_name,
                field_value=nearest.text,
                bbox=nearest.bbox,
                confidence=0.6,
                extraction_method="",
            )

        return None

    # -----------------------------------------------------------------------
    # Strategy 3: regex pattern match
    # -----------------------------------------------------------------------

    PREDEFINED_PATTERNS: dict[str, str] = {
        "发票代码": r"[0-9A-Za-z]{10,12}",
        "发票号码": r"\d{8}",
        "合计金额": r"[\d,]+\.?\d*",
        "金额": r"[\d,]+\.?\d*",
        "开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
        "日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
        "校验码": r"[0-9A-Fa-f]{5,20}",
        "总价": r"[\d,]+\.?\d*",
        "总金额": r"[\d,]+\.?\d*",
        "价税合计": r"[\d,]+\.?\d*",
        "数量": r"\d+\.?\d*",
        "单价": r"[\d,]+\.?\d*",
        "税率": r"\d+\.?\d*%?",
    }

    @staticmethod
    def _has_alnum(text: str) -> bool:
        """True if *text* has at least one alphanumeric character (rejects e.g. ",")."""
        return any(ch.isalnum() for ch in text)

    def _regex_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Search all elements with the predefined pattern for *field_name*.

        The pattern is looked up by exact name, else by the first key that
        contains or is contained in the name. An element consisting solely of
        a match scores 0.7. A match embedded in a longer element scores 0.6.
        """
        name = field_name.strip()
        pattern = self.PREDEFINED_PATTERNS.get(name)
        if not pattern:
            pattern = next(
                (
                    pat
                    for key, pat in self.PREDEFINED_PATTERNS.items()
                    if key in name or name in key
                ),
                None,
            )
        if not pattern:
            return None

        whole = re.compile(r"^\s*(?:" + pattern + r")\s*$")
        for elem in elements:
            if whole.match(elem.text) and self._has_alnum(elem.text):
                return ExtractedField(
                    field_name=field_name,
                    field_value=elem.text.strip(),
                    bbox=elem.bbox,
                    confidence=0.7,
                    extraction_method="",
                )

        # Digit look-arounds stop e.g. the 8-digit 发票号码 pattern from
        # matching a slice of a longer digit run such as a 12-digit 发票代码.
        partial = re.compile(r"(?<!\d)(?:" + pattern + r")(?!\d)")
        for elem in elements:
            for m in partial.finditer(elem.text):
                value = m.group(0)
                if self._has_alnum(value):
                    return ExtractedField(
                        field_name=field_name,
                        field_value=value,
                        bbox=elem.bbox,
                        confidence=0.6,
                        extraction_method="",
                    )

        return None

    # -----------------------------------------------------------------------
    # Strategy 4: table structure match
    # -----------------------------------------------------------------------

    def _locate_header(
        self,
        rows: list[list[OcrTextElement]],
        name: str,
    ) -> Optional[tuple[int, int]]:
        """Return ``(row_idx, col_idx)`` of the header cell for *name*, or None.

        An exact substring hit anywhere wins over a similarity (> 0.5) hit.
        """
        for ri, row in enumerate(rows):
            for ci, elem in enumerate(row):
                if name in elem.text:
                    return ri, ci
        for ri, row in enumerate(rows):
            for ci, elem in enumerate(row):
                if self._text_similarity(name, elem.text) > 0.5:
                    return ri, ci
        return None

    def _table_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Treat the elements as a table and read the cell under the header.

        Logic:
            1. Group elements into rows by Y coordinate.
            2. Find the header cell containing (or resembling) the field name.
            3. In the next row, take the cell whose horizontal center is closest
               to the header's. Matching by position rather than column index
               stays correct when a data row has missing cells.
               If the header is in the last row, take the cell right of it.
        """
        if len(elements) < 3:
            return None

        rows = self._group_elements_by_rows(elements)
        if len(rows) < 2:
            return None

        name = field_name.strip()
        location = self._locate_header(rows, name)
        if location is None:
            return None
        header_row_idx, header_col_idx = location
        header_row = rows[header_row_idx]
        header = header_row[header_col_idx]

        if header_row_idx + 1 < len(rows):
            data_row = rows[header_row_idx + 1]
            matched_elem = min(
                data_row, key=lambda e: abs(e.center_x - header.center_x)
            )
        elif header_col_idx + 1 < len(header_row):
            # No row below: never return the header itself as its own value.
            matched_elem = header_row[header_col_idx + 1]
        else:
            return None

        if matched_elem.text != name:
            return ExtractedField(
                field_name=field_name,
                field_value=matched_elem.text,
                bbox=matched_elem.bbox,
                confidence=0.55,
                extraction_method="",
            )

        return None

    # ========================================================================
    # Utilities
    # ========================================================================

    @staticmethod
    def _group_elements_by_rows(
        elements: list[OcrTextElement],
    ) -> list[list[OcrTextElement]]:
        """Group elements (assumed sorted top-to-bottom) into rows by Y.

        An element joins the current row when its vertical center is within
        half the average element height (at least 5.0) of the row's first
        element. Each row is sorted by ``x_min``.
        """
        if not elements:
            return []

        avg_height = sum(e.height for e in elements) / len(elements)
        tolerance = max(avg_height * 0.5, 5.0)

        rows = []
        current_row = [elements[0]]
        for elem in elements[1:]:
            if abs(elem.center_y - current_row[0].center_y) < tolerance:
                current_row.append(elem)
            else:
                current_row.sort(key=lambda e: e.x_min)
                rows.append(current_row)
                current_row = [elem]

        current_row.sort(key=lambda e: e.x_min)
        rows.append(current_row)
        return rows

    @staticmethod
    def _text_similarity(text1: str, text2: str) -> float:
        """Crude similarity: 1.0 if equal, 0.8 if one contains the other,
        otherwise the fraction of text1's distinct characters found in text2
        (case-insensitive, surrounding whitespace ignored)."""
        if not text1 or not text2:
            return 0.0

        t1 = text1.lower().strip()
        t2 = text2.lower().strip()

        if t1 == t2:
            return 1.0
        if t1 in t2 or t2 in t1:
            return 0.8

        chars1 = set(t1)
        chars2 = set(t2)
        if not chars1:
            return 0.0

        return len(chars1 & chars2) / len(chars1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_ocr_fields(
    file_path: str,
    target_fields: list[str],
    use_gpu: bool = False,
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: OCR an image and extract the requested fields.

    Args:
        file_path: Path of the image to process.
        target_fields: Names of the fields to extract.
        use_gpu: Run OCR on a GPU; ``False`` defers to the ``OCR_USE_GPU``
            environment setting (see ``OcrExtractor.__init__``).
        confidence_threshold: Minimum OCR confidence; the default ``0.5``
            defers to ``OCR_CONFIDENCE_THRESHOLD``.

    Returns:
        The dict produced by ``OcrExtractor.extract``.
    """
    return OcrExtractor(
        use_gpu=use_gpu,
        confidence_threshold=confidence_threshold,
    ).extract(file_path, target_fields)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_from_layout(
    layout_result: dict,
    target_fields: list[str],
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: extract fields from an existing layout analysis.

    Args:
        layout_result: Return value of ``analyze_layout()``.
        target_fields: Names of the fields to extract.
        confidence_threshold: Forwarded to ``OcrExtractor``; the default
            ``0.5`` defers to ``OCR_CONFIDENCE_THRESHOLD``.

    Returns:
        The dict produced by ``OcrExtractor.extract_from_layout_result``.
    """
    return OcrExtractor(
        confidence_threshold=confidence_threshold,
    ).extract_from_layout_result(layout_result, target_fields)
|