116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
|
|
"""数据源模式解析模块。
|
|||
|
|
|
|||
|
|
默认使用 $P{xxx} 参数模式;用户可选择 JDBC 直连模式。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
from agent.state import AgentState
|
|||
|
|
|
|||
|
|
load_dotenv()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_datasource_mode(state: AgentState) -> str:
|
|||
|
|
"""返回数据源模式: "parameter" 或 "jdbc"。
|
|||
|
|
|
|||
|
|
优先读取 state 中已设定的模式,否则根据用户输入检测。
|
|||
|
|
"""
|
|||
|
|
existing = state.get("datasource_mode", "")
|
|||
|
|
if existing in ("parameter", "jdbc"):
|
|||
|
|
return existing
|
|||
|
|
|
|||
|
|
user_input = state.get("user_input", "")
|
|||
|
|
if _detect_jdbc_intent(user_input):
|
|||
|
|
return "jdbc"
|
|||
|
|
return "parameter"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _detect_jdbc_intent(user_input: str) -> bool:
|
|||
|
|
"""检测用户是否想要 JDBC 直连数据库模式。"""
|
|||
|
|
patterns = [
|
|||
|
|
r"(直连|直连数据库|数据库直连)",
|
|||
|
|
r"(从|在)(数据库|DB|MySQL|PostgreSQL|Oracle|SQL Server)\w*",
|
|||
|
|
r"(jdbc|JDBC)",
|
|||
|
|
r"(连接|连)(数据库|DB)",
|
|||
|
|
r"(查询|select|SELECT)\s",
|
|||
|
|
]
|
|||
|
|
for pat in patterns:
|
|||
|
|
if re.search(pat, user_input):
|
|||
|
|
return True
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sanitize_url(url: str) -> str:
|
|||
|
|
"""剥离 JDBC URL 中的 user:password@ 片段,防止泄露到 LLM prompt。"""
|
|||
|
|
return re.sub(r"://[^@]*@", "://***:***@", url)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_datasource_context(mode: str, kb_fields: list, db_config: Optional[dict] = None) -> str:
|
|||
|
|
"""构建数据源上下文字符串,注入生成 prompt。"""
|
|||
|
|
if mode == "jdbc":
|
|||
|
|
if not db_config or not db_config.get("url"):
|
|||
|
|
return (
|
|||
|
|
"[数据源模式: JDBC]\n"
|
|||
|
|
"⚠ 用户想要 JDBC 直连模式,但尚未配置数据库连接信息。\n"
|
|||
|
|
"请先生成带 $P{xxx} 参数占位符的 JRXML,并提醒用户配置 JDBC 连接。"
|
|||
|
|
)
|
|||
|
|
safe_url = _sanitize_url(db_config.get("url", ""))
|
|||
|
|
return (
|
|||
|
|
"[数据源模式: JDBC]\n"
|
|||
|
|
f"连接URL: {safe_url}\n"
|
|||
|
|
f"驱动: {db_config.get('driver', '')}\n"
|
|||
|
|
"请使用 <queryString><![CDATA[...]]></queryString> 中的 SQL 查询。"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# parameter mode
|
|||
|
|
if kb_fields:
|
|||
|
|
field_list = "\n".join(
|
|||
|
|
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'java.lang.String')} |"
|
|||
|
|
for f in kb_fields
|
|||
|
|
)
|
|||
|
|
return (
|
|||
|
|
"[数据源模式: 参数]\n"
|
|||
|
|
"使用 $P{xxx} 参数模式,以下为可用参数:\n"
|
|||
|
|
f"| 参数名 | 含义 | 类型 |\n|---|---|---|\n{field_list}"
|
|||
|
|
)
|
|||
|
|
return "[数据源模式: 参数]\n使用 $P{xxx} 参数模式生成 JRXML。"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def configure_jdbc(state: AgentState, url: str = "", driver: str = "",
|
|||
|
|
username: str = "", password: str = "") -> dict:
|
|||
|
|
"""配置 JDBC 连接并返回更新字段。
|
|||
|
|
|
|||
|
|
注意:db_config 会被存入 AgentState 并持久化到会话文件。
|
|||
|
|
生产环境应使用外部密钥管理服务,避免明文存储密码。
|
|||
|
|
"""
|
|||
|
|
return {
|
|||
|
|
"datasource_mode": "jdbc",
|
|||
|
|
"db_config": {
|
|||
|
|
"url": url,
|
|||
|
|
"driver": driver or "com.mysql.cj.jdbc.Driver",
|
|||
|
|
"username": username,
|
|||
|
|
"password": password,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def ask_db_config(state: AgentState) -> Optional[str]:
|
|||
|
|
"""如果用户选了 JDBC 模式但未配置 DB 连接,返回反问消息。"""
|
|||
|
|
mode = resolve_datasource_mode(state)
|
|||
|
|
if mode == "jdbc":
|
|||
|
|
db_config = state.get("db_config", {})
|
|||
|
|
if not db_config or not db_config.get("url"):
|
|||
|
|
return (
|
|||
|
|
"您选择了数据库直连模式,请提供以下信息:\n"
|
|||
|
|
"1. JDBC URL(如 jdbc:mysql://localhost:3306/dbname)\n"
|
|||
|
|
"2. 数据库用户名\n"
|
|||
|
|
"3. 数据库密码\n"
|
|||
|
|
"4. 驱动类名(可选,默认 com.mysql.cj.jdbc.Driver)"
|
|||
|
|
)
|
|||
|
|
return None
|