215 lines
8.9 KiB
Python
215 lines
8.9 KiB
Python
|
|
"""kb_searcher.py 测试 — KBChromaSearcher 创建, 搜索, 模板检索。"""
|
||
|
|
|
||
|
|
import sys
|
||
|
|
import tempfile
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import patch, MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
|
|
||
|
|
from backend.kb_searcher import (
|
||
|
|
KBChromaSearcher, get_kb_searcher, search_kb, search_templates_in_kb,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_chromadb(monkeypatch):
    """Replace ChromaDB and SentenceTransformer with in-memory mocks.

    ``chromadb.PersistentClient`` is patched so that constructing a client
    returns one shared ``MagicMock`` whose ``get_collection`` and
    ``create_collection`` both return the same mocked collection.
    ``sentence_transformers.SentenceTransformer`` is patched so that
    constructing a model returns a mock whose ``encode(...).tolist()`` yields
    ``[0.1, 0.2, 0.3]``.

    Yields:
        dict: the mocks, under the keys ``"client"``, ``"collection"``,
        ``"st_model"`` and ``"st"``.
    """
    mock_collection = MagicMock()
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_collection
    mock_client.create_collection.return_value = mock_collection

    # Accept any call signature: a stub that only takes ``path`` would raise
    # TypeError as soon as the code under test passes further arguments
    # (for example ``settings=...``) to PersistentClient.
    monkeypatch.setattr(
        "chromadb.PersistentClient",
        lambda *args, **kwargs: mock_client)

    mock_st_model = MagicMock()
    # MagicMock creates ``encode.return_value`` on first access, so only the
    # final ``tolist`` result has to be configured.
    mock_st_model.encode.return_value.tolist.return_value = [0.1, 0.2, 0.3]
    mock_st = MagicMock(return_value=mock_st_model)
    monkeypatch.setattr("sentence_transformers.SentenceTransformer", mock_st)

    yield {
        "client": mock_client,
        "collection": mock_collection,
        "st_model": mock_st_model,
        "st": mock_st,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def searcher(mock_chromadb):
    """Yield a ``KBChromaSearcher`` wired to the ``mock_chromadb`` mocks.

    The searcher is created on a temporary directory with the collection name
    ``"test_kb"``. Its private ``_model``, ``_client`` and ``_collection``
    attributes are then overwritten with the mocked objects (presumably so the
    searcher never builds real ones lazily; confirm against the
    implementation). The directory is removed when the fixture finishes.
    """
    with tempfile.TemporaryDirectory(prefix="test_chroma_") as chroma_dir:
        instance = KBChromaSearcher(
            chroma_path=chroma_dir, collection_name="test_kb")
        # (private attribute on the searcher, key in the mock dictionary)
        wiring = (
            ("_model", "st_model"),
            ("_client", "client"),
            ("_collection", "collection"),
        )
        for attribute, mock_key in wiring:
            setattr(instance, attribute, mock_chromadb[mock_key])
        yield instance
|
||
|
|
|
||
|
|
|
||
|
|
# ── 创建 & 就绪检查 ─────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestKBChromaSearcherInit:
    """Construction and readiness checks for ``KBChromaSearcher``."""

    def test_creates_with_defaults(self, mock_chromadb):
        """Without a collection name the searcher uses ``"kb_chunks"``."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as chroma_dir:
            created = KBChromaSearcher(chroma_path=chroma_dir)
            assert created.collection_name == "kb_chunks"
            assert created.chroma_path == str(chroma_dir)

    def test_custom_collection_name(self, mock_chromadb):
        """An explicit ``collection_name`` is stored unchanged."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as chroma_dir:
            created = KBChromaSearcher(
                chroma_path=chroma_dir, collection_name="custom")
            assert created.collection_name == "custom"

    def test_model_lazy_loaded(self, mock_chromadb):
        """Right after construction ``_model`` is still ``None``."""
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as chroma_dir:
            created = KBChromaSearcher(chroma_path=chroma_dir)
            assert created._model is None

    def test_is_ready_true(self, searcher):
        """With the mocked client in place ``is_ready`` reports ``True``."""
        ready = searcher.is_ready()
        assert ready is True

    def test_is_ready_false(self, searcher, mock_chromadb):
        """``is_ready`` reports ``False`` when ``get_collection`` raises."""
        lookup = mock_chromadb["client"].get_collection
        lookup.side_effect = Exception("no collection")
        ready = searcher.is_ready()
        assert ready is False
|
||
|
|
|
||
|
|
|
||
|
|
# ── 搜索 ────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestSearch:
    """Exercise ``KBChromaSearcher.search`` against the mocked collection."""

    @staticmethod
    def _stub_query(mocks, ids, documents, metadatas, distances):
        """Make the mocked collection's ``query`` return one result row.

        Each argument is a flat list; it is wrapped in an outer list to form
        the nested payload the tests hand to the searcher.
        """
        mocks["collection"].query.return_value = {
            "ids": [ids],
            "documents": [documents],
            "metadatas": [metadatas],
            "distances": [distances],
        }

    def test_search_returns_empty_when_not_ready(self, searcher, mock_chromadb):
        """A failing ``get_collection`` lookup is expected to give no hits."""
        lookup = mock_chromadb["client"].get_collection
        lookup.side_effect = Exception("no collection")
        hits = searcher.search("test query")
        assert hits == []

    def test_search_calls_collection_query(self, searcher, mock_chromadb):
        """Query rows come back as dicts with id/content/metadata/distance."""
        self._stub_query(
            mock_chromadb,
            ids=["chunk_0", "chunk_1"],
            documents=["doc1", "doc2"],
            metadatas=[{"chunk_type": "md"}, {"chunk_type": "txt"}],
            distances=[0.1, 0.3],
        )
        hits = searcher.search("query", k=5)
        assert len(hits) == 2
        first = hits[0]
        assert first["id"] == "chunk_0"
        assert first["content"] == "doc1"
        assert first["metadata"]["chunk_type"] == "md"
        assert first["distance"] == 0.1

    def test_search_respects_threshold(self, searcher, mock_chromadb):
        """With ``threshold=0.5`` only the row at distance 0.2 is kept."""
        self._stub_query(
            mock_chromadb,
            ids=["chunk_0", "chunk_1"],
            documents=["doc1", "doc2"],
            metadatas=[{}, {}],
            distances=[0.2, 0.8],
        )
        hits = searcher.search("query", threshold=0.5)
        assert len(hits) == 1
        assert hits[0]["id"] == "chunk_0"

    def test_search_empty_results(self, searcher, mock_chromadb):
        """An empty query payload yields an empty result list."""
        self._stub_query(
            mock_chromadb, ids=[], documents=[], metadatas=[], distances=[])
        hits = searcher.search("query")
        assert hits == []
|
||
|
|
|
||
|
|
|
||
|
|
# ── 模板搜索 ────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestSearchTemplates:
    """Exercise ``KBChromaSearcher.search_templates`` with mocked rows."""

    def test_filters_jrxml_chunks(self, searcher, mock_chromadb):
        """Every returned template is a jrxml chunk or carries a report name."""
        # One tuple per stored chunk: (id, document, metadata, distance).
        rows = [
            ("c0", "t1",
             {"chunk_type": "jrxml_template", "report_name": "R1"}, 0.1),
            ("c1", "t2", {"chunk_type": "md_section"}, 0.2),
            ("c2", "t3",
             {"chunk_type": "jrxml_template", "report_name": "R2"}, 0.3),
        ]
        # Transpose the rows into the column lists of the query payload.
        ids, documents, metadatas, distances = (
            list(column) for column in zip(*rows))
        mock_chromadb["collection"].query.return_value = {
            "ids": [ids],
            "documents": [documents],
            "metadatas": [metadatas],
            "distances": [distances],
        }

        templates = searcher.search_templates("query", k=3)

        assert len(templates) >= 1
        for hit in templates:
            meta = hit["metadata"]
            is_jrxml = "jrxml" in meta.get("chunk_type", "").lower()
            assert is_jrxml or meta.get("report_name")

    def test_no_jrxml_chunks_returns_empty(self, searcher, mock_chromadb):
        """A payload holding only an ``md_section`` chunk gives no templates."""
        payload = {
            "ids": [["c0"]],
            "documents": [["text"]],
            "metadatas": [[{"chunk_type": "md_section"}]],
            "distances": [[0.1]],
        }
        mock_chromadb["collection"].query.return_value = payload

        templates = searcher.search_templates("query")

        assert templates == []
|
||
|
|
|
||
|
|
|
||
|
|
# ── search_as_context ───────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestSearchAsContext:
    """Exercise ``KBChromaSearcher.search_as_context`` with mocked rows."""

    def test_returns_formatted_context(self, searcher, mock_chromadb):
        """The context text holds both documents, the labels and a divider."""
        metadatas = [
            {"chunk_type": "md", "report_name": "R1"},
            {"chunk_type": "txt"},
        ]
        mock_chromadb["collection"].query.return_value = {
            "ids": [["c0", "c1"]],
            "documents": [["内容1", "内容2"]],
            "metadatas": [metadatas],
            "distances": [[0.1, 0.2]],
        }

        context = searcher.search_as_context("q", k=2)

        # Checked in the original order: documents, labels, then the divider.
        for fragment in ("内容1", "内容2", "类型", "报表", "---"):
            assert fragment in context

    def test_empty_returns_empty_string(self, searcher, mock_chromadb):
        """An empty query payload produces an empty context string."""
        empty_payload = {
            "ids": [[]],
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }
        mock_chromadb["collection"].query.return_value = empty_payload

        context = searcher.search_as_context("q")

        assert context == ""
|
||
|
|
|
||
|
|
|
||
|
|
# ── add_chunks ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestAddChunks:
    """Exercise ``KBChromaSearcher.add_chunks`` with a mocked collection."""

    def test_add_chunks_calls_upsert(self, searcher, mock_chromadb):
        """One chunk leads to one ``upsert`` carrying its id and content."""
        chunk = {
            "id": "c0",
            "content": "test content",
            "metadata": {"chunk_type": "md"},
        }
        searcher.add_chunks([chunk])

        upsert = mock_chromadb["collection"].upsert
        upsert.assert_called_once()
        # ``call_args`` unpacks as (positional args, keyword args).
        _, kwargs = upsert.call_args
        assert kwargs["ids"] == ["c0"]
        assert kwargs["documents"] == ["test content"]

    def test_empty_chunks_noop(self, searcher, mock_chromadb):
        """An empty chunk list must not reach ``upsert``."""
        searcher.add_chunks([])
        upsert = mock_chromadb["collection"].upsert
        upsert.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
# ── 工厂函数 ────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestGetKbSearcher:
    """Exercise the ``get_kb_searcher`` factory function."""

    def test_returns_cached_instance(self, monkeypatch, mock_chromadb):
        """Two lookups of the same knowledge-base id give the same object."""
        known_id = "abcdef1234567890abcd"
        with tempfile.TemporaryDirectory(prefix="test_chroma_") as chroma_dir:

            def fake_chroma_path(kb_id):
                # Only the known id resolves to a storage directory.
                if kb_id == known_id:
                    return Path(chroma_dir)
                return None

            monkeypatch.setattr(
                "backend.kb_manager.get_kb_chroma_path", fake_chroma_path)

            first = get_kb_searcher(known_id)
            second = get_kb_searcher(known_id)
            assert first is second

    def test_returns_none_for_invalid_kb(self, monkeypatch):
        """With no chroma path for the id the factory returns ``None``."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr(
            "backend.kb_manager.get_kb_chroma_path", no_chroma_path)
        found = get_kb_searcher("deadbeef1234567890abcd")
        assert found is None
|
||
|
|
|
||
|
|
|
||
|
|
class TestSearchKbFunction:
    """Exercise the module-level ``search_kb`` / ``search_templates_in_kb``."""

    def test_returns_empty_for_invalid_kb(self, monkeypatch):
        """``search_kb`` gives ``""`` when the id has no chroma path."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr(
            "backend.kb_manager.get_kb_chroma_path", no_chroma_path)
        context = search_kb("deadbeef1234567890abcd", "query")
        assert context == ""

    def test_returns_empty_for_invalid_template_search(self, monkeypatch):
        """``search_templates_in_kb`` gives ``[]`` when the id has no path."""

        def no_chroma_path(kb_id):
            return None

        monkeypatch.setattr(
            "backend.kb_manager.get_kb_chroma_path", no_chroma_path)
        templates = search_templates_in_kb("deadbeef1234567890abcd", "query")
        assert templates == []
|