Files
agent_jrxml/app.py
T

927 lines
32 KiB
Python
Raw Normal View History

"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。
支持:
- 流式输出(LLM 逐字展示)
- 节点平铺展开(每个处理阶段独立展示)
- 完成后自动折叠节点区
- 过程总结卡片
"""
import os
import sys
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
try:
import torchvision
except Exception:
pass
import base64
import tempfile
import time
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
load_dotenv(override=True)
from agent.graph import build_graph, create_initial_state
from backend.session import (
create_session,
load_session,
delete_session,
list_all_sessions,
)
from backend.logger import get_logger, set_trace_id, generate_trace_id
_app_log = get_logger("app")

st.set_page_config(
    page_title="JRXML 代理",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Swallow the bare 'c' keypress at capture time on the parent document (the
# original note says this is Streamlit's clear-cache shortcut) while leaving
# Ctrl/Cmd/Alt+C and typing inside inputs, textareas and content-editable
# elements untouched. The focused element is read once and null-checked before
# both `tagName` and `isContentEditable` are used.
# NOTE(review): confirm that the deployed Streamlit version actually executes
# <script> tags passed to st.html; if it sanitises them this handler is inert.
st.html("""
<script>
(function() {
const parent = window.parent.document;
parent.addEventListener('keydown', function(e) {
// 仅拦截裸 'c' 键(非 Ctrl/Cmd 组合)
if (e.key === 'c' && !e.ctrlKey && !e.metaKey && !e.altKey) {
const active = parent.activeElement;
const tag = active ? active.tagName : '';
const editable = active ? active.isContentEditable : false;
if (tag !== 'INPUT' && tag !== 'TEXTAREA' && !editable) {
e.stopImmediatePropagation();
e.preventDefault();
}
}
}, true);
})();
</script>
""")
# ---- Graph node name -> label shown in the progress list ----
NODE_LABELS = {
    # session / bookkeeping
    "load_session": "📂 加载会话",
    "process_input": "📝 记录输入",
    "manage_context": "🧠 管理上下文",
    "save_state_snapshot": "💾 保存快照",
    # main pipeline
    "classify_intent": "🔍 识别意图",
    "retrieve": "📚 检索模板",
    "generate": "⚙️ 生成 JRXML",
    "modify_jrxml": "🔧 修改 JRXML",
    "validate": "✅ 验证",
    "explain_error": "🔎 分析错误",
    "correct_jrxml": "🛠 自动修正",
    "finalize": "📋 完成",
    # direct handlers
    "handle_consult": "💬 咨询回答",
    "handle_undo": "↩ 撤销操作",
    "handle_reset": "🔄 重置会话",
    "save_session": "💾 保存会话",
    # staged generation
    "generate_skeleton": "🏗 生成骨架",
    "refine_layout": "📐 精调布局",
    "map_fields": "🏷 映射字段",
}

# ---- Intent identifier -> label shown next to the classify step ----
INTENT_LABELS = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}

# Nodes that are never listed in the progress display.
SKIP_NODES = {
    "load_session",
    "process_input",
    "manage_context",
    "save_state_snapshot",
    "save_session",
}
def _render_jrxml(jrxml: str, max_lines: int = 30):
    """Render JRXML as an XML code block, truncated to ``max_lines`` lines.

    When the stripped text has more lines than ``max_lines``, a trailing
    marker line stating the total line count is appended to the preview.
    """
    all_lines = jrxml.strip().split("\n")
    total = len(all_lines)
    shown = "\n".join(all_lines[:max_lines])
    if total > max_lines:
        shown = f"{shown}\n... (共 {total} 行)"
    st.code(shown, language="xml")
# ---- URL parameters ----
query_params = st.query_params
url_session_id = query_params.get("session_id", "")

# ---- Session-state initialisation (only on the first run of a browser session) ----
if "messages" not in st.session_state:
    st.session_state.messages = []
if "graph" not in st.session_state:
    st.session_state.graph = build_graph()
if "pending_action" not in st.session_state:
    st.session_state.pending_action = None

if "agent_state" not in st.session_state:
    # Try to restore the session named in the URL; anything that cannot be
    # restored (no id, unknown id, no stored agent state) gets a new session.
    data = load_session(url_session_id) if url_session_id else None
    if data and data.get("agent_state"):
        st.session_state.agent_state = data["agent_state"]
        st.session_state.agent_state["session_id"] = url_session_id
    else:
        _fresh_state = create_initial_state()
        st.session_state.agent_state = _fresh_state
        new_data = create_session(name="", agent_state=_fresh_state)
        # Copy the identity fields handed back by create_session into the state.
        for _field in ("session_id", "session_name", "created_at"):
            _fresh_state[_field] = new_data[_field]

current_session_id = st.session_state.agent_state.get("session_id", "")
# Nodes whose progress detail reports the size of `current_jrxml`.
_JRXML_WRITER_NODES = (
    "generate", "modify_jrxml", "correct_jrxml",
    "generate_skeleton", "refine_layout", "map_fields",
)


def _node_detail(node_name, node_state):
    """Return the progress detail text for a finished node ("" when none).

    ``node_state`` is the update dict reported for the node; missing or
    ``None`` values are treated as empty so slicing/len never fails.
    """
    if node_name == "classify_intent":
        intent = node_state.get("intent") or ""
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context") or ""
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in _JRXML_WRITER_NODES:
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return ""


def _render_ocr_fields(ocr_result):
    """Show the OCR field-extraction result in a collapsed expander.

    Nothing is rendered unless the result is marked ``ocr_available`` and
    carries a non-empty ``fields`` list.
    """
    if not (ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields")):
        return
    with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
        fields = ocr_result.get("fields", [])
        non_empty = [f for f in fields if f.get("field_value")]
        empty = [f for f in fields if not f.get("field_value")]
        if non_empty:
            st.markdown("**已提取字段:**")
            for f in non_empty:
                method = f.get("extraction_method", "")
                # A missing/None confidence is shown as 0% instead of crashing
                # the percent format.
                conf = f.get("confidence") or 0
                st.markdown(
                    f"- **{f['field_name']}**: `{f['field_value']}` "
                    f"(置信度: {conf:.0%}, 方法: {method})"
                )
        if empty:
            st.caption(
                f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}"
            )
        st.caption(
            f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
        )


def run_agent(user_input: str):
    """Run the agent graph for one user turn and render the outcome.

    Streams ``st.session_state.graph`` over the current agent state, showing
    per-node progress and streamed text in temporary placeholders, then
    replaces them with a summary card and appends the corresponding assistant
    messages to ``st.session_state.messages``.

    Args:
        user_input: The full prompt for this turn.

    Returns:
        None. On an exception from the graph an error is shown and the
        function returns without rendering a summary.
    """
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    agent_state = st.session_state.agent_state
    session_id = agent_state.get("session_id", "")
    _app_log.info(
        "代理执行开始",
        extra={
            "session_id": session_id,
            "trace_id": trace_id,
            "user_input_preview": user_input[:200],
            "user_input_length": len(user_input),
            "has_jrxml": bool((agent_state.get("current_jrxml") or "").strip()),
            "intent": agent_state.get("intent", ""),
        },
    )

    # A turn on top of an already-valid report is also recorded as a
    # modification request.
    if agent_state.get("current_jrxml") and agent_state.get("status") == "pass":
        agent_state["user_modification_request"] = user_input
    agent_state["user_input"] = user_input
    agent_state["retry_count"] = 0

    # ---- UI placeholders ----
    progress_placeholder = st.empty()   # live node progress
    streaming_placeholder = st.empty()  # streamed text
    summary_placeholder = st.empty()    # summary card
    progress_placeholder.info("⏳ 正在分析您的需求...")

    executed_nodes: list[dict] = []
    stream_text = ""
    stream_active = False

    def _render_progress(nodes: list[dict]):
        """Redraw the node progress list into the progress placeholder."""
        if not nodes:
            return
        lines = []
        for i, node in enumerate(nodes):
            # NOTE(review): both branches are empty strings as found in the
            # source — presumably "running"/"done" glyphs were lost; confirm.
            icon = "" if i == len(nodes) - 1 else ""
            detail = f"{node['detail']}" if node.get("detail") else ""
            lines.append(f"{icon} {node['label']}{detail}")
        progress_placeholder.markdown("\n\n".join(lines))

    try:
        for mode, data in st.session_state.graph.stream(
            agent_state, stream_mode=["updates", "custom"]
        ):
            if mode == "updates":
                for node_name, node_state in data.items():
                    if node_name in SKIP_NODES:
                        continue
                    # Guard against non-dict payloads — presumably possible
                    # when a node returns nothing or the graph reports an
                    # interrupt; verify against the LangGraph streaming docs.
                    if not isinstance(node_state, dict):
                        node_state = {}
                    entry = {
                        "name": node_name,
                        "label": NODE_LABELS.get(node_name, node_name),
                    }
                    detail = _node_detail(node_name, node_state)
                    if detail:
                        entry["detail"] = detail
                    executed_nodes.append(entry)
                # Refresh the progress list as soon as an update arrives.
                _render_progress(executed_nodes)
            elif mode == "custom":
                if isinstance(data, dict) and data.get("type") == "stream":
                    stream_text += data.get("text") or ""
                    stream_active = True
                    streaming_placeholder.code(stream_text, language="xml")
    except Exception as e:
        progress_placeholder.empty()
        _app_log.error(
            f"代理执行异常: {e}",
            extra={"session_id": session_id, "error": str(e)},
        )
        st.error(f"工作流异常: {e}")
        return

    # ---- Drop the temporary placeholders ----
    progress_placeholder.empty()
    if stream_active:
        streaming_placeholder.empty()

    # ---- Summary card ----
    # Per the original author's note, the per-node updates only carry changed
    # fields, so the complete state is read from agent_state (described there
    # as being modified in place by the nodes).
    final_state = agent_state
    if final_state:
        st.session_state.agent_state = final_state
        intent = final_state.get("intent", "")
        status = final_state.get("status", "")
        with summary_placeholder.container(border=True):
            if intent == "consult_question":
                answer = final_state.get("consult_answer", "")
                st.info(answer)
                st.session_state.messages.append({
                    "role": "assistant", "content": answer, "type": "consult",
                })
            elif intent in ("undo_modification", "reset_session"):
                st.success("操作已完成")
            elif intent in ("preview_report", "export_pdf", "export_jrxml"):
                jrxml = final_state.get("current_jrxml", "")
                if jrxml:
                    st.success("✅ 当前报表")
                    _render_jrxml(jrxml)
                    st.session_state.messages.append({
                        "role": "assistant", "content": jrxml, "type": "jrxml",
                    })
                else:
                    st.warning("⚠ 当前没有报表可以展示。")
            elif status == "pass":
                jrxml = final_state.get("current_jrxml", "")
                st.success("✅ JRXML 生成成功")
                st.markdown("**生成结果:**")
                _render_jrxml(jrxml)
                st.caption("您可以从侧边栏下载文件,或继续对话进行修改。")
                st.session_state.messages.append({
                    "role": "assistant", "content": jrxml, "type": "jrxml",
                })
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。",
                    "type": "success",
                })
            else:
                jrxml = final_state.get("current_jrxml", "")
                error_msg = final_state.get("error_msg", "未知错误")
                explanation = final_state.get("natural_explanation", "")
                retries = final_state.get("retry_count", 0)
                st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML")
                st.markdown(f"**错误:** {error_msg}")
                if explanation:
                    st.markdown(f"**原因:** {explanation}")
                if jrxml:
                    with st.expander("查看当前 JRXML"):
                        _render_jrxml(jrxml, max_lines=80)
                st.caption("💡 下次输入修改需求时,系统会自动加载失败上下文继续修复。")
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": (
                        f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。"
                        f"\n\n**错误:** {error_msg}"
                        "\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。"
                    ),
                    "type": "error_explanation",
                })

            # OCR field-extraction result, shown below the outcome.
            _render_ocr_fields(agent_state.get("ocr_extraction_result", {}))
    else:
        st.error("未产生结果,请重试。")

    _app_log.info(
        "代理执行完成",
        extra={
            "session_id": session_id,
            "intent": final_state.get("intent", ""),
            "status": final_state.get("status", ""),
            "jrxml_length": len(final_state.get("current_jrxml", "")),
            "retry_count": final_state.get("retry_count", 0),
        },
    )
# ---- Sidebar ----
def _start_fresh_session():
    """Make a brand-new session current and clear the chat history.

    Builds an initial agent state, registers it through ``create_session``,
    copies the returned identity fields into the state and stores it in
    ``st.session_state``. Returns the dict handed back by ``create_session``.
    """
    fresh_state = create_initial_state()
    created = create_session(name="", agent_state=fresh_state)
    for field in ("session_id", "session_name", "created_at"):
        fresh_state[field] = created[field]
    st.session_state.agent_state = fresh_state
    st.session_state.messages = []
    return created


with st.sidebar:
    st.title("📊 JRXML 代理")
    st.markdown("通过自然语言生成 JasperReports 模板。")
    st.divider()

    # Session management
    st.markdown("### 会话管理")
    sessions = list_all_sessions() or []
    session_options = {}
    for s in sessions:
        sid = s["session_id"]
        name = s.get("session_name", sid)
        updated = (s.get("updated_at") or "")[:16]
        label = f"{name} ({updated})"
        # Two sessions with the same name and minute would otherwise share a
        # label and one of them could never be selected.
        if label in session_options:
            label = f"{label} [{sid}]"
        session_options[label] = sid

    selected_label = next(
        (label for label, sid in session_options.items() if sid == current_session_id),
        None,
    )
    option_labels = list(session_options)
    selected = st.selectbox(
        "切换会话",
        options=option_labels,
        index=option_labels.index(selected_label) if selected_label else 0,
        key="session_selector",
    )

    if selected and session_options.get(selected) != current_session_id:
        new_sid = session_options[selected]
        if st.session_state.get("_last_switched_to") == new_sid:
            # Avoid an endless rerun loop from switching to the same session
            # again.
            st.session_state._last_switched_to = ""
        else:
            data = load_session(new_sid)
            if data and data.get("agent_state"):
                _app_log.info(
                    "切换会话",
                    extra={"from_session": current_session_id, "to_session": new_sid},
                )
                data["agent_state"]["session_id"] = new_sid
                st.session_state.agent_state = data["agent_state"]
                st.session_state.messages = []
                st.session_state._last_switched_to = new_sid
                st.rerun()

    col1, col2 = st.columns(2)
    with col1:
        if st.button(" 新建", use_container_width=True):
            new_data = _start_fresh_session()
            _app_log.info(
                "新建会话",
                extra={"session_id": new_data["session_id"]},
            )
            st.rerun()
    with col2:
        if st.button("🗑 删除", use_container_width=True):
            if current_session_id:
                _app_log.info(
                    "删除会话",
                    extra={"session_id": current_session_id},
                )
                delete_session(current_session_id)
                # Deleting the current session immediately opens a new one.
                _start_fresh_session()
                st.rerun()

    current_name = st.session_state.agent_state.get("session_name", "")
    st.caption(f"当前: {current_name} (`{current_session_id}`)")
    st.divider()

    # Quick actions: each one sends a canned request through the agent.
    st.markdown("### 快捷操作")
    has_jrxml = bool((st.session_state.agent_state.get("current_jrxml") or "").strip())
    has_history = bool(st.session_state.agent_state.get("history_states", []))
    qcol1, qcol2 = st.columns(2)
    with qcol1:
        if st.button("👁 预览", use_container_width=True, disabled=not has_jrxml):
            with st.spinner("正在准备预览..."):
                run_agent("预览报表")
            st.rerun()
    with qcol2:
        if st.button("↩ 撤销", use_container_width=True, disabled=not has_history):
            with st.spinner("正在撤销..."):
                run_agent("撤销上一步修改")
            st.rerun()
    if st.button("🔄 重置会话", use_container_width=True):
        with st.spinner("正在重置..."):
            run_agent("重新来,清空当前报表")
        st.rerun()
    st.divider()

    # Read-only view of the environment-driven configuration.
    st.markdown("### 配置")
    llm_backend = os.getenv("LLM_BACKEND", "cloud")
    llm_model = os.getenv("LLM_MODEL", os.getenv("LOCAL_LLM_MODEL", "gpt-4o"))
    st.caption(f"大语言模型: {llm_backend} / {llm_model}")
    st.caption(f"最大重试次数: {os.getenv('MAX_RETRY', '5')}")
    st.caption(f"验证服务: {os.getenv('VALIDATION_SERVICE_URL', 'http://localhost:8001/validate')}")
    st.divider()

    # Downloads: latest JRXML plus every recorded version, newest first.
    st.markdown("### 下载")
    final = st.session_state.agent_state.get("final_jrxml", "")
    versions = st.session_state.agent_state.get("jrxml_versions", [])
    if final:
        st.download_button(
            label="📥 下载最新 JRXML",
            data=final,
            file_name="report.jrxml",
            mime="application/xml",
            use_container_width=True,
        )
    if versions:
        with st.expander("📋 历史版本", expanded=False):
            for i, v in enumerate(reversed(versions)):
                version_no = len(versions) - i
                ts = (v.get("ts") or "")[:16]
                label = v.get("label", "版本")
                status = v.get("status", "")
                # NOTE(review): both branches are empty strings as found in the
                # source — presumably pass/fail glyphs were lost; confirm.
                icon = "" if status == "pass" else ""
                dl_label = f"{icon} v{version_no}{label} ({ts})"
                st.download_button(
                    label=dl_label,
                    data=v.get("jrxml") or "",
                    file_name=f"report_v{version_no}.jrxml",
                    mime="application/xml",
                    use_container_width=True,
                    key=f"dl_v{i}",
                )
# ---- Page title ----
st.title("📝 JRXML 报表生成器")
st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。")

# ---- Chat history ----
# Message "type" -> Streamlit call used to show its content; JRXML messages
# are handled separately (collapsed expander) and anything else is markdown.
_HISTORY_RENDERERS = {
    "error_explanation": st.warning,
    "success": st.success,
    "consult": st.info,
}
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        kind = msg.get("type")
        if kind == "jrxml":
            with st.expander("查看生成的 JRXML", expanded=False):
                st.code(msg["content"], language="xml")
            continue
        _HISTORY_RENDERERS.get(kind, st.markdown)(msg["content"])
# ---- Unified chat input component ----
# Self-contained HTML/CSS/JS for the chat box: a textarea, an attach button,
# drag-and-drop / paste support and removable file chips. On send, attached
# files are read as data URLs and {text, files} is handed to
# `Streamlit.setComponentValue`.
# File chips are built with textContent so a file name is never parsed as HTML.
# NOTE(review): nothing in this markup defines the `Streamlit` global used in
# handleSend — confirm the hosting component provides it.
UNIFIED_CHAT_HTML = r"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
background: transparent;
padding: 4px 0;
}
.chat-container {
position: relative;
border: 1px solid #d1d5db;
border-radius: 12px;
padding: 8px 12px;
background: #ffffff;
transition: border-color 0.2s, box-shadow 0.2s;
}
.chat-container:focus-within {
border-color: #3b82f6;
box-shadow: 0 0 0 2px rgba(59,130,246,0.15);
}
.chat-container.drag-active {
border-color: #3b82f6;
background: rgba(59,130,246,0.04);
}
.file-chips {
display: flex;
flex-wrap: wrap;
gap: 6px;
margin-bottom: 6px;
}
.file-chips:empty { display: none; }
.file-chip {
display: inline-flex;
align-items: center;
gap: 4px;
padding: 2px 8px;
background: #f3f4f6;
border-radius: 14px;
font-size: 12px;
color: #374151;
max-width: 200px;
}
.file-chip .chip-icon { font-size: 13px; }
.file-chip .chip-name {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.file-chip .chip-remove {
border: none;
background: none;
cursor: pointer;
color: #9ca3af;
font-size: 14px;
line-height: 1;
padding: 0 2px;
flex-shrink: 0;
}
.file-chip .chip-remove:hover { color: #ef4444; }
.input-row {
display: flex;
align-items: flex-end;
gap: 8px;
}
.attach-btn {
border: none;
background: none;
cursor: pointer;
padding: 4px 6px;
font-size: 20px;
line-height: 1;
color: #6b7280;
border-radius: 6px;
transition: background 0.15s, color 0.15s;
flex-shrink: 0;
}
.attach-btn:hover { background: #f3f4f6; color: #374151; }
textarea {
flex: 1;
border: none;
outline: none;
resize: none;
font-size: 15px;
line-height: 1.5;
font-family: inherit;
color: #111827;
background: transparent;
padding: 4px 0;
min-height: 24px;
max-height: 120px;
overflow-y: auto;
}
textarea::placeholder { color: #9ca3af; }
.send-btn {
border: none;
cursor: pointer;
padding: 4px 10px;
font-size: 16px;
background: #e5e7eb;
color: #9ca3af;
border-radius: 8px;
transition: all 0.15s;
flex-shrink: 0;
}
.send-btn.active { background: #3b82f6; color: #fff; }
.send-btn.active:hover { background: #2563eb; }
.send-btn:disabled { opacity: 0.5; cursor: default; }
.error-toast {
position: fixed;
bottom: 12px;
left: 50%;
transform: translateX(-50%);
background: #ef4444;
color: #fff;
padding: 6px 16px;
border-radius: 8px;
font-size: 13px;
z-index: 9999;
animation: toastOut 2.5s forwards;
pointer-events: none;
}
@keyframes toastOut {
0%, 70% { opacity: 1; }
100% { opacity: 0; }
}
@media (prefers-color-scheme: dark) {
.chat-container { background: #1f2937; border-color: #374151; }
.chat-container:focus-within { border-color: #3b82f6; }
.file-chip { background: #374151; color: #e5e7eb; }
.file-chip .chip-remove { color: #6b7280; }
.attach-btn { color: #9ca3af; }
.attach-btn:hover { background: #374151; color: #e5e7eb; }
textarea { color: #f9fafb; }
textarea::placeholder { color: #6b7280; }
.send-btn { background: #374151; }
}
</style>
</head>
<body>
<div class="chat-container" id="container">
<div class="file-chips" id="chips"></div>
<div class="input-row">
<button class="attach-btn" id="attachBtn" title="附加文件">&#x1F4CE;</button>
<textarea id="textInput" placeholder="描述您的报表需求..." rows="1"></textarea>
<button class="send-btn" id="sendBtn" title="发送">&#x27A4;</button>
</div>
<input type="file" id="fileInput" multiple hidden
accept=".png,.jpg,.jpeg,.bmp,.webp,.pdf,.docx,.xlsx,.xls,.doc,.txt">
</div>
<script>
const container = document.getElementById('container');
const chipsEl = document.getElementById('chips');
const textInput = document.getElementById('textInput');
const sendBtn = document.getElementById('sendBtn');
const attachBtn = document.getElementById('attachBtn');
const fileInput = document.getElementById('fileInput');
let attachedFiles = [];
const MAX_FILES = 10;
const MAX_SIZE = 20 * 1024 * 1024;
function getIcon(type) {
if (type.startsWith('image/')) return '🖼';
if (type.includes('pdf')) return '📄';
if (type.includes('document')) return '📝';
if (type.includes('spreadsheet') || type.includes('excel')) return '📊';
return '📎';
}
function updateSendBtn() {
var canSend = textInput.value.trim() || attachedFiles.length > 0;
sendBtn.classList.toggle('active', canSend);
}
function renderChips() {
chipsEl.innerHTML = '';
attachedFiles.forEach(function(f, i) {
var chip = document.createElement('span');
chip.className = 'file-chip';
var name = f.name.length > 16 ? f.name.slice(0,14)+'..' : f.name;
var iconEl = document.createElement('span');
iconEl.className = 'chip-icon';
iconEl.textContent = getIcon(f.type);
var nameEl = document.createElement('span');
nameEl.className = 'chip-name';
nameEl.textContent = name;
var removeEl = document.createElement('button');
removeEl.className = 'chip-remove';
removeEl.textContent = '\u00d7';
removeEl.onclick = function() {
attachedFiles.splice(i, 1);
renderChips();
updateSendBtn();
};
chip.appendChild(iconEl);
chip.appendChild(nameEl);
chip.appendChild(removeEl);
chipsEl.appendChild(chip);
});
updateSendBtn();
}
function addFiles(fileList) {
for (var i = 0; i < fileList.length; i++) {
var file = fileList[i];
if (attachedFiles.length >= MAX_FILES) { showToast('最多附加 '+MAX_FILES+' 个文件'); break; }
if (file.size > MAX_SIZE) { showToast(file.name+' 超过 20MB 限制'); continue; }
if (attachedFiles.some(function(f) { return f.name === file.name && f.size === file.size; })) continue;
attachedFiles.push({name: file.name, type: file.type, file: file});
}
renderChips();
}
function showToast(msg) {
var t = document.createElement('div');
t.className = 'error-toast';
t.textContent = msg;
document.body.appendChild(t);
setTimeout(function() { t.remove(); }, 2600);
}
function readFile(file) {
return new Promise(function(resolve, reject) {
var reader = new FileReader();
reader.onload = function() { resolve(reader.result); };
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
async function handleSend() {
if (sendBtn.disabled) return;
var text = textInput.value.trim();
if (!text && attachedFiles.length === 0) return;
sendBtn.disabled = true;
try {
var files = [];
for (var i = 0; i < attachedFiles.length; i++) {
var f = attachedFiles[i];
try {
var dataUrl = await readFile(f.file);
files.push({name: f.name, type: f.type, data: dataUrl, size: f.file.size});
} catch(e) {
showToast(f.name+' 读取失败');
}
}
Streamlit.setComponentValue({text: text, files: files});
textInput.value = '';
attachedFiles = [];
renderChips();
textInput.style.height = 'auto';
} finally {
sendBtn.disabled = false;
}
}
attachBtn.onclick = function() { fileInput.click(); };
fileInput.onchange = function() { addFiles(fileInput.files); fileInput.value = ''; };
textInput.oninput = function() {
updateSendBtn();
textInput.style.height = 'auto';
textInput.style.height = Math.min(textInput.scrollHeight, 120) + 'px';
};
textInput.onkeydown = function(e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
handleSend();
}
};
sendBtn.onclick = handleSend;
document.addEventListener('paste', function(e) {
var items = e.clipboardData && e.clipboardData.items;
if (!items) return;
var files = [];
for (var i = 0; i < items.length; i++) {
if (items[i].kind === 'file') files.push(items[i].getAsFile());
}
if (files.length) { e.preventDefault(); addFiles(files); }
});
var containerDiv = document.getElementById('container');
containerDiv.addEventListener('dragover', function(e) {
e.preventDefault();
containerDiv.classList.add('drag-active');
});
containerDiv.addEventListener('dragleave', function() {
containerDiv.classList.remove('drag-active');
});
containerDiv.addEventListener('drop', function(e) {
e.preventDefault();
containerDiv.classList.remove('drag-active');
addFiles(e.dataTransfer.files);
});
updateSendBtn();
</script>
</body>
</html>
"""
# MIME type -> temp-file suffix; unknown types fall back to the extension of
# the uploaded file name.
_MIME_TO_SUFFIX = {
    "image/png": ".png", "image/jpeg": ".jpg", "image/bmp": ".bmp",
    "image/webp": ".webp", "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    "application/vnd.ms-excel": ".xls", "application/msword": ".doc",
    "text/plain": ".txt",
}
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

# NOTE(review): confirm that `components.html` can actually return the dict
# sent by the component; if it only renders a static iframe, the branch below
# never runs and a component declared via `components.declare_component` is
# needed instead.
chat_result = components.html(UNIFIED_CHAT_HTML, height=180)
if chat_result and isinstance(chat_result, dict):
    prompt = chat_result.get("text", "")
    files = chat_result.get("files") or []
    from backend.file_parser import parse_file
    from backend.layout_analyzer import analyze_layout, extract_layout_schema

    file_texts = []
    attached_info = []
    first_image_path = None
    temp_paths = []
    try:
        for f in files:
            file_name = f.get("name", "")
            # Attachments arrive as data URLs: "<header>,<base64 payload>".
            _, comma, b64data = (f.get("data") or "").partition(",")
            try:
                if not comma:
                    raise ValueError("attachment is not a data URL")
                raw = base64.b64decode(b64data)
            except ValueError as exc:  # binascii.Error is a ValueError
                _app_log.warning(
                    f"附件解码失败: {exc}",
                    extra={"session_id": current_session_id, "file_name": file_name},
                )
                st.warning(f"无法读取附件: {file_name}")
                continue

            mime = f.get("type", "")
            suffix = _MIME_TO_SUFFIX.get(mime, Path(file_name).suffix.lower())
            # delete=False: the file must outlive this `with` so it can be
            # reopened by path; it is removed in the `finally` below.
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(raw)
                tmp_path = tmp.name
            temp_paths.append(tmp_path)

            result = parse_file(tmp_path, suffix)
            text = result["text"]
            file_type = result["file_type"]

            if suffix in _IMAGE_SUFFIXES and result.get("method") not in ("metadata_only", None):
                # Layout analysis is best-effort: on failure the plain parse
                # result is kept, but the failure is logged, not hidden.
                try:
                    layout = analyze_layout(tmp_path)
                    tt = layout.get("template_type", "unknown")
                    if tt == "full_a4":
                        text = layout["description"]
                        file_type = "a4_template"
                        schema = extract_layout_schema(layout)
                        st.session_state.agent_state["layout_schema"] = schema
                        st.session_state.agent_state["ocr_elements"] = layout.get("rows", [])
                    elif tt == "partial_rows":
                        file_type = "a4_partial"
                except Exception as exc:
                    _app_log.warning(
                        f"版面分析失败: {exc}",
                        extra={"session_id": current_session_id, "file_name": file_name},
                    )

            file_texts.append(f"[附加文件: {file_name} ({file_type})]\n{text}")
            attached_info.append({"name": file_name, "type": file_type, "length": len(text)})
            if not first_image_path and file_type in ("image", "a4_template", "a4_partial"):
                first_image_path = tmp_path

        if file_texts:
            full_prompt = "\n\n".join(file_texts) + "\n\n---\n用户需求:\n" + prompt
        else:
            full_prompt = prompt
        if first_image_path:
            # NOTE(review): this path is deleted in the `finally` below, so the
            # stored value dangles after this turn — confirm nothing reads it
            # on later turns.
            st.session_state.agent_state["uploaded_file_path"] = first_image_path

        _app_log.info(
            "收到用户输入",
            extra={
                "session_id": current_session_id,
                "prompt_preview": prompt[:200],
                "prompt_length": len(prompt),
                "has_uploaded_files": bool(attached_info),
                "uploaded_files": attached_info,
            },
        )
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        run_agent(full_prompt)
    finally:
        # Always remove the temp files, including when parsing or the agent
        # run raised.
        for p in temp_paths:
            try:
                Path(p).unlink(missing_ok=True)
            except OSError:
                pass
    st.rerun()