88 lines
2.2 KiB
Python
88 lines
2.2 KiB
Python
|
|
"""初始化 Chroma 知识库,加载示例 JRXML 模板和错误修正案例。
|
||
|
|
|
||
|
|
用法: python scripts/init_kb.py
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
|
||
|
|
# Put the directory two levels above this file (the parent of scripts/) at the
# front of sys.path so the `backend` package imported below resolves when the
# script is run directly, e.g. `python scripts/init_kb.py`.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Load variables from a .env file into os.environ before `backend` is imported
# and before main() reads CHROMA_PERSIST_DIR.
# NOTE(review): presumably backend.embeddings reads its configuration from the
# environment, which is why this must precede that import — confirm.
load_dotenv()
|
||
|
|
|
||
|
|
from backend.embeddings import get_embeddings
|
||
|
|
|
||
|
|
|
||
|
|
def load_templates(template_dir: Path) -> list[dict]:
    """Load every ``*.jrxml`` file directly inside *template_dir* as a report doc.

    Only the top level of the directory is scanned (the glob is non-recursive),
    and entries that are not regular files are skipped.

    Args:
        template_dir: Directory containing sample JRXML report templates.

    Returns:
        One ``{'content': str, 'metadata': dict}`` record per file, ordered by
        file path. ``metadata`` holds ``source`` (the file path as a string),
        ``type`` (always ``'full_report'``) and ``name`` (the file stem).
    """
    docs = []
    # Path.glob yields entries in filesystem order; sort so the indexing order
    # is reproducible across machines and runs.
    for fpath in sorted(template_dir.glob('*.jrxml')):
        # A directory whose name ends in .jrxml would make read_text() raise.
        if not fpath.is_file():
            continue
        # 'utf-8-sig' decodes plain UTF-8 identically but strips a leading BOM,
        # which would otherwise end up as '\ufeff' at the start of the content.
        content = fpath.read_text(encoding='utf-8-sig')
        docs.append({
            'content': content,
            'metadata': {
                'source': str(fpath),
                'type': 'full_report',
                'name': fpath.stem,
            },
        })
    return docs
|
||
|
|
|
||
|
|
|
||
|
|
def load_corrections(corrections_dir: Path) -> list[dict]:
    """Load every ``*.jrxml`` error-correction example in *corrections_dir*.

    Only the top level of the directory is scanned (the glob is non-recursive),
    and entries that are not regular files are skipped.

    Args:
        corrections_dir: Directory containing JRXML error-correction cases.

    Returns:
        One ``{'content': str, 'metadata': dict}`` record per file, ordered by
        file path. ``metadata`` holds ``source`` (the file path as a string),
        ``type`` (always ``'correction_case'``) and ``name`` (the file stem).
    """
    docs = []
    # Path.glob yields entries in filesystem order; sort so the indexing order
    # is reproducible across machines and runs.
    for fpath in sorted(corrections_dir.glob('*.jrxml')):
        # A directory whose name ends in .jrxml would make read_text() raise.
        if not fpath.is_file():
            continue
        # 'utf-8-sig' decodes plain UTF-8 identically but strips a leading BOM,
        # which would otherwise end up as '\ufeff' at the start of the content.
        content = fpath.read_text(encoding='utf-8-sig')
        docs.append({
            'content': content,
            'metadata': {
                'source': str(fpath),
                'type': 'correction_case',
                'name': fpath.stem,
            },
        })
    return docs
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Index the sample JRXML templates and correction cases into Chroma.

    Reads ``data/sample_templates/*.jrxml`` and ``data/corrections/*.jrxml``
    from the directory two levels above this script, embeds them with
    ``get_embeddings()`` and writes them to the Chroma store at
    ``$CHROMA_PERSIST_DIR`` (default ``./db/chroma``, which is relative to the
    current working directory, not to this script).

    Every document is stored under a deterministic ID of the form
    ``"<type>:<name>"``, so re-running the script overwrites the existing
    entries instead of appending duplicates. Missing source directories are
    skipped; if nothing is found, nothing is indexed.
    """
    persist_dir = os.getenv('CHROMA_PERSIST_DIR', './db/chroma')
    data_dir = Path(__file__).parent.parent / 'data'

    template_dir = data_dir / 'sample_templates'
    corrections_dir = data_dir / 'corrections'

    docs = []
    if template_dir.exists():
        templates = load_templates(template_dir)
        docs.extend(templates)
        print(f'从 {template_dir} 加载了 {len(templates)} 个模板')

    if corrections_dir.exists():
        corr = load_corrections(corrections_dir)
        docs.extend(corr)
        print(f'从 {corrections_dir} 加载了 {len(corr)} 个修正案例')

    if not docs:
        print('未找到文档,无需索引。')
        return

    embeddings = get_embeddings()
    # Imported only once there is something to index.
    from langchain_chroma import Chroma

    texts = [d['content'] for d in docs]
    metadatas = [d['metadata'] for d in docs]
    # Without explicit ids, from_texts assigns random UUIDs, so each run would
    # add a second copy of every document. Stable ids make add_texts upsert.
    # 'type' separates templates from corrections; within one directory the
    # stems of *.jrxml files are unique, so the ids cannot collide.
    # NOTE(review): entries written by earlier runs (random ids) are not
    # removed; clear the store once after upgrading.
    ids = [f"{meta['type']}:{meta['name']}" for meta in metadatas]

    Chroma.from_texts(
        texts=texts,
        embedding=embeddings,
        metadatas=metadatas,
        ids=ids,
        persist_directory=persist_dir,
    )
    print(f'已将 {len(docs)} 个文档索引到 Chroma,存储位置: {persist_dir}')
|
||
|
|
|
||
|
|
|
||
|
|
# Build the index only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
|