358 lines
13 KiB
Python
358 lines
13 KiB
Python
|
|
"""
|
|||
|
|
md_chunker.py
|
|||
|
|
Markdown 语义分块器
|
|||
|
|
支持标题层级、代码块、表格等元素的智能分块
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
from typing import List, Dict, Tuple
|
|||
|
|
from pathlib import Path
|
|||
|
|
from dataclasses import dataclass, field, asdict
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MDChunk:
    """Record describing one chunk cut out of a Markdown document.

    Instances are plain data holders; they are converted to dictionaries with
    ``dataclasses.asdict`` before being returned or serialized.
    """

    # Position of the chunk in the sequence produced for one document.
    chunk_id: int
    # Category label of the chunk (for example ``code`` or a section type).
    chunk_type: str
    # Short human-readable summary of the chunk.
    human_description: str
    # The chunk text itself.
    raw_content: str
    # Surrounding context for the chunk (the document title in this module).
    context: str
    # Free-form extra attributes; every instance gets its own fresh dict.
    metadata: Dict = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MarkdownSemanticChunker:
    """Split Markdown documents into typed chunks (v1.0).

    Strategy implemented by this class:

    1. Fenced code blocks are extracted first and emitted as standalone
       ``code`` chunks.
    2. The remaining text is cut at ATX headings (``#`` .. ``######``); the
       body below each heading becomes one section chunk.  The heading line
       itself is kept in the chunk metadata, not in the chunk content, and a
       heading with no body text produces no chunk of its own.
    3. A chunk longer than ``max_chunk_size`` characters is split again: on
       blank lines first, then on line breaks, then by hard slicing.

    NOTE: ``INLINE_CODE_PATTERN``, ``TABLE_PATTERN`` and ``LIST_PATTERN`` are
    defined below but no method of this class uses them, so tables and lists
    stay inside the section that contains them.
    """

    # Heading patterns
    HEADING_PATTERN = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)

    # Code block pattern (fenced)
    CODE_BLOCK_PATTERN = re.compile(r'```(\w*)\n([\s\S]*?)```', re.MULTILINE)

    # Inline code pattern
    INLINE_CODE_PATTERN = re.compile(r'`([^`]+)`')

    # Table pattern
    TABLE_PATTERN = re.compile(r'\|.+\|\n\|[-| :]+\|\n((?:\|.+\|\n)*)', re.MULTILINE)

    # List pattern
    LIST_PATTERN = re.compile(r'^(\s*[-*+]\s+.+)+', re.MULTILINE)

    # Whole fenced block.  The info string after the opening fence may hold
    # any characters except a newline or a backtick (``c++``, ``sh title=x``);
    # restricting it to ``\w*`` made such openers unmatched and mis-paired
    # every following fence.
    _FENCE_PATTERN = re.compile(r'(```[^\n`]*\n[\s\S]*?```)')

    # First token of the info string, used as the language name.
    _FENCE_LANG_PATTERN = re.compile(r'```([^\s`]*)')

    # Keyword groups that refine the type of an H2 section.  Order matters:
    # the first group with a keyword contained in the heading wins.
    _SECTION_KEYWORDS = (
        ('section_installation', ('install', '安装', 'setup', '部署')),
        ('section_configuration', ('config', '配置', 'setting')),
        ('section_api', ('api', '接口')),
        ('section_example', ('example', '示例', 'usage', '使用')),
        ('section_faq', ('faq', 'question', '问题', '常见')),
        ('section_changelog', ('changelog', '更新', 'release')),
    )

    def __init__(self, max_chunk_size: int = 2000):
        """Create a chunker.

        Args:
            max_chunk_size: Maximum number of characters a chunk may hold
                before it is split into ``*_part`` chunks.
        """
        self.max_chunk_size = max_chunk_size

    def chunk_file(self, file_path: str) -> List[Dict]:
        """Chunk a single Markdown file.

        Args:
            file_path: Path of the Markdown file to read (UTF-8, with or
                without a byte-order mark).

        Returns:
            One dictionary per chunk (the ``asdict`` form of ``MDChunk``),
            with ``chunk_id`` counting up from 0 in document order.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # 'utf-8-sig' also reads plain UTF-8 but strips a leading BOM, which
        # would otherwise hide a heading on the very first line.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        doc_title = self._extract_title(content, Path(file_path).stem)

        chunks = []
        for segment in self._split_by_structure(content):
            seg_type = segment['type']
            seg_content = segment['content']

            if not seg_content.strip():
                continue

            seg_meta = segment.get('metadata', {})

            if len(seg_content) > self.max_chunk_size:
                # Too long: emit several "<type>_part" chunks instead.
                parts = self._split_large_chunk(
                    seg_content, seg_type, doc_title, len(chunks), seg_meta
                )
                chunks.extend(asdict(part) for part in parts)
            else:
                description = self._build_description(
                    seg_type, seg_content, doc_title,
                    heading=seg_meta.get('heading', ''),
                )
                chunks.append(asdict(MDChunk(
                    # len(chunks) is the running id: one id per stored chunk.
                    chunk_id=len(chunks),
                    chunk_type=seg_type,
                    human_description=description,
                    raw_content=seg_content.strip(),
                    context=doc_title,
                    metadata=seg_meta,
                )))

        return chunks

    def _extract_title(self, content: str, fallback: str) -> str:
        """Return the text of the first H1 outside code fences, else ``fallback``."""
        # Remove fenced code first so a "# comment" inside a code sample is
        # not mistaken for the document title.
        prose = self._FENCE_PATTERN.sub('', content)
        match = re.search(r'^#\s+(.+)$', prose, re.MULTILINE)
        title = match.group(1).strip() if match else ''
        return title or fallback

    def _fence_language(self, block: str) -> str:
        """Return the language token of a fenced block, or ``''`` if absent."""
        match = self._FENCE_LANG_PATTERN.match(block)
        return match.group(1) if match else ''

    def _split_by_structure(self, content: str) -> List[Dict]:
        """Split a document into code segments and heading-delimited sections.

        Returns:
            A list of ``{'type': ..., 'content': ..., 'metadata': {...}}``
            dictionaries in document order.  ``type`` is ``'code'`` for fenced
            blocks and a ``section_*`` label otherwise.  Section metadata holds
            ``heading`` and ``heading_level``; code metadata additionally holds
            ``language``.
        """
        segments = []
        cursor = 0
        # The active heading is carried across code fences so that text (and
        # code) following a fence still belongs to the section it sits in.
        level = 0
        heading = ''

        for match in self._FENCE_PATTERN.finditer(content):
            before = content[cursor:match.start()]
            if before.strip():
                found, level, heading = self._scan_text(before, level, heading)
                segments.extend(found)

            block = match.group(1)
            segments.append({
                'type': 'code',
                'content': block,
                'metadata': {
                    'language': self._fence_language(block),
                    'heading': heading,
                    'heading_level': level,
                },
            })
            cursor = match.end()

        tail = content[cursor:]
        if tail.strip():
            found, level, heading = self._scan_text(tail, level, heading)
            segments.extend(found)

        return segments

    def _process_text_section(self, text: str) -> List[Dict]:
        """Split fence-free text at headings, starting with no active heading.

        Returns:
            Section segments as described in ``_split_by_structure``.
        """
        return self._scan_text(text)[0]

    def _scan_text(self, text: str, level: int = 0,
                   heading: str = '') -> Tuple[List[Dict], int, str]:
        """Split fence-free text at headings, resuming from a heading state.

        Args:
            text: Markdown text containing no fenced code blocks.
            level: Level (1-6) of the heading active when ``text`` starts, or
                0 when there is none.
            heading: Text of that active heading.

        Returns:
            ``(segments, level, heading)`` where the last two describe the
            heading that is active at the end of ``text``.
        """
        segments = []
        body = []

        for line in text.split('\n'):
            match = self.HEADING_PATTERN.match(line)
            if match is None:
                body.append(line)
                continue

            # A new heading closes the section collected so far.
            segment = self._section_segment(body, level, heading)
            if segment is not None:
                segments.append(segment)
            body = []
            level = len(match.group(1))
            heading = match.group(2).strip()

        segment = self._section_segment(body, level, heading)
        if segment is not None:
            segments.append(segment)

        return segments, level, heading

    def _section_segment(self, body_lines: List[str], level: int, heading: str):
        """Build a section segment from body lines, or ``None`` if they are blank."""
        section_text = '\n'.join(body_lines).strip()
        if not section_text:
            return None
        return {
            'type': self._get_section_type(level, heading),
            'content': section_text,
            'metadata': {'heading': heading, 'heading_level': level},
        }

    def _get_section_type(self, level: int, heading: str) -> str:
        """Map a heading level (and, for H2, its wording) to a section type.

        Returns:
            ``'section_h1'`` for level 1; for level 2 the first matching
            keyword type from ``_SECTION_KEYWORDS`` or ``'section_h2'``;
            ``'section_h3'`` for level 3; ``'section_other'`` for anything
            else, including level 0 (text before the first heading).
        """
        if level == 1:
            return 'section_h1'
        if level == 2:
            lowered = heading.lower()
            for section_type, keywords in self._SECTION_KEYWORDS:
                if any(kw in lowered for kw in keywords):
                    return section_type
            return 'section_h2'
        if level == 3:
            return 'section_h3'
        return 'section_other'

    def _build_description(self, chunk_type: str, content: str, doc_title: str,
                           heading: str = '') -> str:
        """Build the human-readable description of a chunk.

        Args:
            chunk_type: Segment type (``'code'``, ``'section_*'`` or other).
            content: Chunk text; its first five lines feed a preview of at
                most 150 characters.
            doc_title: Title of the document the chunk comes from.
            heading: Heading the section belongs to.  Section content does not
                include its heading line, so pass it here; when empty, the
                first line of ``content`` is used instead.

        Returns:
            The description string.
        """
        first_lines = content.split('\n')[:5]
        preview = ' '.join(line.strip() for line in first_lines if line.strip())[:150]

        if chunk_type == 'code':
            match = self._FENCE_LANG_PATTERN.match(content)
            lang = (match.group(1) or 'text') if match else ''
            return f"Code block (language: {lang}) in {doc_title}. Preview: {preview}"

        if chunk_type.startswith('section_'):
            if heading:
                heading_clean = heading
            else:
                first = content.split('\n')[0] if '\n' in content else content[:50]
                heading_clean = re.sub(r'^#+\s+', '', first)
            type_hint = chunk_type.replace('section_', '')
            return f"[{type_hint.upper()}] {heading_clean}. Content: {preview}"

        return f"Document section in {doc_title}. Content: {preview}"

    def _split_large_chunk(self, content: str, chunk_type: str,
                           doc_title: str, start_id: int,
                           metadata: Dict = None) -> List["MDChunk"]:
        """Split an over-long chunk into consecutive ``<chunk_type>_part`` chunks.

        Args:
            content: Text of the over-long chunk.
            chunk_type: Type of the original chunk.
            doc_title: Title of the document the chunk comes from.
            start_id: ``chunk_id`` assigned to the first part; later parts
                count up from it.
            metadata: Optional metadata of the original segment; it is copied
                into every part next to ``part`` and ``original_type``.

        Returns:
            The parts in order; an empty list when ``content`` is blank.
        """
        inherited = dict(metadata or {})
        return [
            MDChunk(
                chunk_id=start_id + index,
                chunk_type=f"{chunk_type}_part",
                human_description=f"Part of {doc_title} ({chunk_type}): {text[:100]}...",
                raw_content=text,
                context=f"{doc_title} (continued)",
                metadata={**inherited, 'part': index + 1, 'original_type': chunk_type},
            )
            for index, text in enumerate(self._split_text_to_fit(content))
        ]

    def _split_text_to_fit(self, content: str) -> List[str]:
        """Cut text into pieces of at most ``max_chunk_size`` characters.

        Paragraphs (separated by blank lines) are packed greedily and joined
        with a blank line.  A paragraph that alone exceeds the limit is first
        broken up by ``_split_oversized``.
        """
        limit = max(1, self.max_chunk_size)  # guard against a non-positive size

        units = []
        for para in re.split(r'\n\n+', content):
            if not para.strip():
                continue
            if len(para) > limit:
                units.extend(self._split_oversized(para, limit))
            else:
                units.append(para)

        parts = []
        pending = []
        size = 0  # length of '\n\n'.join(pending)
        for unit in units:
            joined = size + 2 + len(unit) if pending else len(unit)
            if pending and joined > limit:
                parts.append('\n\n'.join(pending))
                pending = [unit]
                size = len(unit)
            else:
                pending.append(unit)
                size = joined
        if pending:
            parts.append('\n\n'.join(pending))
        return parts

    @staticmethod
    def _split_oversized(text: str, limit: int) -> List[str]:
        """Break one paragraph into pieces of at most ``limit`` characters.

        Pieces end on line breaks where possible; a single line longer than
        ``limit`` is sliced into ``limit``-sized pieces.
        """
        pieces = []
        pending = []
        size = 0  # length of '\n'.join(pending)
        for line in text.split('\n'):
            if len(line) > limit:
                if pending:
                    pieces.append('\n'.join(pending))
                    pending = []
                    size = 0
                # Emit full-width slices; the remainder continues below.
                while len(line) > limit:
                    pieces.append(line[:limit])
                    line = line[limit:]
            needed = size + 1 + len(line) if pending else len(line)
            if pending and needed > limit:
                pieces.append('\n'.join(pending))
                pending = [line]
                size = len(line)
            else:
                pending.append(line)
                size = needed
        if pending:
            pieces.append('\n'.join(pending))
        return [piece for piece in pieces if piece.strip()]

    def chunk_directory(self, dir_path: str, extensions: tuple = ('.md', '.markdown')) -> List[Dict]:
        """Chunk every Markdown file found under a directory tree.

        Files are visited in sorted order so the result is the same on every
        run.  A file that fails to process is reported on stdout and skipped.

        Args:
            dir_path: Root directory to walk recursively.
            extensions: File suffix or suffixes to accept, compared
                case-insensitively.

        Returns:
            The chunks of all processed files, concatenated.  ``chunk_id``
            restarts at 0 for every file.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        # str.endswith only accepts a tuple, so normalise lists and sets.
        suffixes = tuple(ext.lower() for ext in extensions)

        all_chunks = []
        file_count = 0

        for root, dirs, files in os.walk(dir_path):
            dirs.sort()  # in-place sort fixes the order os.walk descends in
            for file in sorted(files):
                if not file.lower().endswith(suffixes):
                    continue
                file_path = os.path.join(root, file)
                try:
                    chunks = self.chunk_file(file_path)
                except Exception as e:
                    # Deliberately best-effort: one bad file must not abort
                    # the whole batch.
                    print(f"FAIL {file_path}: {e}")
                    continue
                all_chunks.extend(chunks)
                file_count += 1
                print(f"OK {file_path}: {len(chunks)} chunks")

        print(f"\nTotal: {file_count} files, {len(all_chunks)} chunks")
        return all_chunks
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_chunks_to_json(chunks: List[Dict], output_path: str):
    """Write the chunk list to ``output_path`` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than escaped, and a
    one-line confirmation is printed once the file has been written.
    """
    with open(output_path, mode='w', encoding='utf-8') as sink:
        json.dump(chunks, sink, indent=2, ensure_ascii=False)

    total = len(chunks)
    print(f"Saved {total} chunks to {output_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_chunk_summary(chunks: List[Dict]):
    """Print how many chunks there are of each ``chunk_type``.

    Types are listed from most to least frequent; types with equal counts
    keep the order in which they first appear in ``chunks``.
    """
    tally = {}
    for entry in chunks:
        kind = entry["chunk_type"]
        if kind in tally:
            tally[kind] += 1
        else:
            tally[kind] = 1

    print("\nChunk Type Summary:")
    # reverse=True keeps the sort stable, so ties stay in insertion order.
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    for chunk_type, count in ranked:
        print(f" {chunk_type}: {count}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def default_output_path(path: str, is_directory: bool = False) -> str:
    """Derive the JSON output path for an input file or directory.

    Args:
        path: Input path as given on the command line.
        is_directory: True when ``path`` names a directory.

    Returns:
        For a directory: ``<parent>/<dirname>_md_chunks.json`` (parent is
        ``.`` when ``path`` has none).  For a file: the path with its final
        extension replaced by ``_chunks.json``.
    """
    if is_directory:
        trimmed = path.rstrip("/\\") or path
        parent = os.path.dirname(trimmed) or "."
        return os.path.join(parent, os.path.basename(trimmed) + "_md_chunks.json")

    # Only the final extension is replaced.  A plain str.replace(".md", ...)
    # would rewrite ".md" anywhere in the path, and would return the input
    # path unchanged (so the source file gets overwritten) for names such as
    # "README.MD" or "notes.txt".
    stem, _extension = os.path.splitext(path)
    return stem + "_chunks.json"


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list without the program name; ``sys.argv[1:]`` is
            used when omitted.  The first argument is a Markdown file or a
            directory; without arguments a usage message is printed.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)

    if not args:
        print("=" * 60)
        print("Markdown Semantic Chunking v1.0")
        print("=" * 60)
        print("\nUsage: python md_chunker.py <md_file_or_directory>")
        return

    chunker = MarkdownSemanticChunker(max_chunk_size=2000)
    path = args[0]

    if os.path.isdir(path):
        all_chunks = chunker.chunk_directory(path)
        save_chunks_to_json(all_chunks, default_output_path(path, is_directory=True))
        print_chunk_summary(all_chunks)
        return

    chunks = chunker.chunk_file(path)
    save_chunks_to_json(chunks, default_output_path(path))

    print(f"\n{'='*60}")
    print("Chunking Results Preview")
    print(f"{'='*60}")
    # Show at most the first ten chunks to keep the console output short.
    for chunk in chunks[:10]:
        print(f"\n[Chunk {chunk['chunk_id']}] Type: {chunk['chunk_type']}")
        print(f"Description: {chunk['human_description'][:120]}...")
        print(f"Content length: {len(chunk['raw_content'])} chars")
    if len(chunks) > 10:
        print(f"\n... and {len(chunks) - 10} more chunks")

    print_chunk_summary(chunks)


if __name__ == "__main__":
    main()
|