162 lines
6.3 KiB
Python
162 lines
6.3 KiB
Python
|
|
"""datasource.py 测试 — 数据源模式解析, JDBC 检测, 上下文构建。"""
|
||
|
|
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
|
|
||
|
|
from agent.datasource import (
|
||
|
|
resolve_datasource_mode, _detect_jdbc_intent,
|
||
|
|
build_datasource_context, configure_jdbc, ask_db_config,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_state(**overrides):
|
||
|
|
s = {
|
||
|
|
"user_input": "",
|
||
|
|
"conversation_history": [],
|
||
|
|
"current_jrxml": "",
|
||
|
|
"status": "",
|
||
|
|
"error_msg": "",
|
||
|
|
"natural_explanation": "",
|
||
|
|
"retry_count": 0,
|
||
|
|
"user_modification_request": "",
|
||
|
|
"final_jrxml": "",
|
||
|
|
"stage": "",
|
||
|
|
"retrieved_context": "",
|
||
|
|
**overrides,
|
||
|
|
}
|
||
|
|
return s
|
||
|
|
|
||
|
|
|
||
|
|
# ── JDBC 意图检测 ───────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestDetectJdbcIntent:
    """Keyword-based detection of a JDBC / direct-database intent in user input."""

    def test_direct_connect_keywords(self):
        # Phrases built around "直连" (direct connection) count as JDBC intent.
        for text in ("我想从数据库直连查询", "直连数据库获取数据"):
            assert _detect_jdbc_intent(text) is True, text

    def test_db_name_mentions(self):
        # Naming a concrete database product counts as JDBC intent.
        for text in (
            "从MySQL数据库查询用户表",
            "在PostgreSQL中执行查询",
            "从Oracle读取数据",
        ):
            assert _detect_jdbc_intent(text) is True, text

    def test_jdbc_explicit_mention(self):
        text = "使用JDBC连接"
        assert _detect_jdbc_intent(text) is True, text

    def test_sql_keywords(self):
        # Raw SQL, and "query" + "database" wording (even with a space), count.
        for text in (
            "SELECT * FROM users",
            "从数据库查询用户表",
            "先查询 数据库",
        ):
            assert _detect_jdbc_intent(text) is True, text

    def test_normal_request_is_not_jdbc(self):
        # Ordinary report-generation / edit requests must not trigger JDBC.
        for text in ("生成一个员工报表", "修改标题为XX公司"):
            assert _detect_jdbc_intent(text) is False, text

    def test_empty_input(self):
        text = ""
        assert _detect_jdbc_intent(text) is False
|
||
|
|
|
||
|
|
|
||
|
|
# ── 模式解析 ────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestResolveDatasourceMode:
    """Choosing between the "parameter" and "jdbc" datasource modes."""

    @staticmethod
    def _mode(**state_fields):
        """Resolve the datasource mode for a test state built from *state_fields*."""
        return resolve_datasource_mode(_make_state(**state_fields))

    def test_defaults_to_parameter_mode(self):
        assert self._mode(user_input="生成报表") == "parameter"

    def test_detects_jdbc_from_input(self):
        assert self._mode(user_input="从数据库直连查询") == "jdbc"

    def test_respects_existing_mode_in_state(self):
        # An explicit "jdbc" mode already in state is kept, even for plain input.
        assert self._mode(datasource_mode="jdbc", user_input="生成报表") == "jdbc"

    def test_existing_parameter_overrides_jdbc_input(self):
        # An explicit "parameter" mode wins over input that looks like JDBC.
        assert (
            self._mode(datasource_mode="parameter", user_input="从数据库直连")
            == "parameter"
        )

    def test_ignores_invalid_mode_in_state(self):
        # An unrecognised stored mode is ignored; detection falls back to input.
        assert self._mode(datasource_mode="unknown", user_input="从数据库直连") == "jdbc"
|
||
|
|
|
||
|
|
|
||
|
|
# ── 上下文构建 ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestBuildDatasourceContext:
    """Prompt-context text produced by build_datasource_context() per mode."""

    def test_parameter_mode_with_fields(self):
        fields = [
            {"name": "billNo", "description": "工单号", "type": "java.lang.String"},
            {"name": "amount", "description": "金额", "type": "java.math.BigDecimal"},
        ]
        ctx = build_datasource_context("parameter", fields)
        assert "[数据源模式: 参数]" in ctx
        assert "$P{xxx}" in ctx
        # Every supplied field name should be surfaced in the context.
        for field in fields:
            assert field["name"] in ctx

    def test_parameter_mode_without_fields(self):
        ctx = build_datasource_context("parameter", [])
        assert "[数据源模式: 参数]" in ctx
        assert "$P{xxx}" in ctx

    def test_jdbc_mode_with_config(self):
        db_config = {"url": "jdbc:mysql://localhost:3306/mydb",
                     "driver": "com.mysql.cj.jdbc.Driver"}
        ctx = build_datasource_context("jdbc", [], db_config)
        assert "[数据源模式: JDBC]" in ctx
        assert "jdbc:mysql://" in ctx
        assert "CDATA" in ctx

    def test_jdbc_mode_without_config_shows_warning(self):
        ctx = build_datasource_context("jdbc", [])
        assert "尚未配置数据库连接" in ctx
        # Check the full "$P{xxx}" parameter syntax, as the other tests in this
        # class do. The previous "P{xxx}" check dropped the "$", so it would
        # also have accepted malformed placeholders. Presumably the warning
        # falls back to parameter-style guidance; confirm against the
        # implementation.
        assert "$P{xxx}" in ctx
|
||
|
|
|
||
|
|
|
||
|
|
# ── JDBC 配置 ───────────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestConfigureJdbc:
    """configure_jdbc() returns a state-update dict that switches to JDBC mode."""

    def test_configure_returns_update_dict(self):
        update = configure_jdbc(
            _make_state(),
            url="jdbc:mysql://localhost/db",
            driver="com.mysql.cj.jdbc.Driver",
            username="root",
            password="pass",
        )
        assert update["datasource_mode"] == "jdbc"
        db_config = update["db_config"]
        assert db_config["url"] == "jdbc:mysql://localhost/db"
        assert db_config["username"] == "root"

    def test_default_driver_is_mysql(self):
        # With no driver argument, a MySQL driver is expected even though the
        # URL is PostgreSQL, i.e. the test pins that the default is not
        # inferred from the URL.
        update = configure_jdbc(_make_state(), url="jdbc:postgresql://localhost/db")
        driver = update["db_config"]["driver"]
        assert "mysql" in driver
|
||
|
|
|
||
|
|
|
||
|
|
# ── ask_db_config ───────────────────────────────────────────────
|
||
|
|
|
||
|
|
class TestAskDbConfig:
    """ask_db_config() returns a prompt only when JDBC mode lacks a usable config."""

    def test_returns_none_for_parameter_mode(self):
        state = _make_state(datasource_mode="parameter")
        assert ask_db_config(state) is None

    def test_returns_none_when_jdbc_configured(self):
        state = _make_state(datasource_mode="jdbc",
                            db_config={"url": "jdbc:mysql://localhost/db"})
        assert ask_db_config(state) is None

    def test_returns_prompt_when_jdbc_missing_config(self):
        state = _make_state(datasource_mode="jdbc")
        msg = ask_db_config(state)
        assert msg is not None
        # The prompt should ask for URL, username ("用户名") and password ("密码").
        for expected in ("JDBC URL", "用户名", "密码"):
            assert expected in msg

    # Previously named ``test_returns_none_when_db_config_empty``, which
    # contradicted the assertion below. An empty db_config is treated like a
    # missing one, so a prompt is expected.
    def test_returns_prompt_when_db_config_empty(self):
        state = _make_state(datasource_mode="jdbc", db_config={})
        msg = ask_db_config(state)
        assert msg is not None
|