138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
|
|
"""
|
||
|
|
config.py
|
||
|
|
统一配置管理,从 .env 文件读取环境变量
|
||
|
|
所有脚本通过此模块获取配置,避免硬编码
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
|
||
|
|
def _get_project_root() -> Path:
    """Return the directory containing this module, with symlinks resolved.

    Every relative path in the configuration is anchored here, so results do
    not depend on the caller's current working directory.
    """
    this_module = Path(__file__).resolve()
    return this_module.parents[0]
|
||
|
|
|
||
|
|
|
||
|
|
def _strip_matching_quotes(value: str) -> str:
    """Remove one pair of surrounding quotes, but only if both ends match."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        return value[1:-1]
    return value


def _load_dotenv(env_file: "Path | None" = None) -> None:
    """Minimal ``.env`` loader that does not depend on python-dotenv.

    Reads ``KEY=VALUE`` lines and copies them into ``os.environ``. Variables
    that already exist in the process environment are never overwritten, so
    real environment variables take precedence over the file.

    Supported syntax:
      * blank lines and lines starting with ``#`` are ignored;
      * lines without ``=`` are ignored;
      * an optional ``export`` prefix (``export KEY=value``) is accepted;
      * one pair of matching surrounding quotes (``"..."`` or ``'...'``) is
        removed from the value; unmatched quote characters are kept.

    Args:
        env_file: Path of the file to load. Defaults to ``.env`` in the
            project root (the directory containing this module).

    Returns:
        None. A missing file (or a path that is not a regular file) is
        silently skipped.
    """
    if env_file is None:
        env_file = _get_project_root() / ".env"
    env_file = Path(env_file)

    # is_file() rather than exists(): a directory named ".env" would make
    # open() raise instead of being skipped.
    if not env_file.is_file():
        return

    # "utf-8-sig" drops a leading BOM (as written by some Windows editors),
    # which would otherwise become part of the first key.
    with open(env_file, "r", encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue

            key, _, value = line.partition("=")
            key = key.strip()
            if key.startswith("export "):
                key = key[len("export "):].strip()
            value = _strip_matching_quotes(value.strip())

            # Existing environment variables win over the file.
            if key and key not in os.environ:
                os.environ[key] = value
|
||
|
|
|
||
|
|
|
||
|
|
# Absolute project directory; _get_path anchors relative settings here.
PROJECT_ROOT = _get_project_root()

# Merge the project's .env into os.environ before any setting below is read.
_load_dotenv()
|
||
|
|
|
||
|
|
|
||
|
|
def _get_path(key: str, default: str) -> Path:
    """Read a filesystem path setting from the environment.

    - An unset or blank variable falls back to *default*. Without this, an
      empty entry such as ``KEY=`` in ``.env`` would silently resolve to the
      project root. This matches the int/float/bool getters, which also fall
      back to their default on an empty value.
    - A leading ``~`` is expanded to the user's home directory.
    - Relative paths are anchored at ``PROJECT_ROOT`` rather than the current
      working directory.

    Args:
        key: Environment variable name.
        default: Path string used when the variable is unset or blank.

    Returns:
        The path as given if it is absolute after ``~`` expansion; otherwise
        ``PROJECT_ROOT / path``.
    """
    raw = os.environ.get(key, "").strip()
    if not raw:
        raw = default
    # os.path.expanduser returns the input unchanged when no home directory
    # is known, whereas Path.expanduser() would raise RuntimeError.
    path = Path(os.path.expanduser(raw))
    if not path.is_absolute():
        path = PROJECT_ROOT / path
    return path
|
||
|
|
|
||
|
|
|
||
|
|
def _get_str(key: str, default: str) -> str:
    """Return the environment variable *key*, or *default* when it is unset.

    An empty value is returned as-is (it is not replaced by *default*).
    """
    found = os.environ.get(key)
    if found is None:
        return default
    return found
|
||
|
|
|
||
|
|
|
||
|
|
def _get_bool(key: str, default: bool) -> bool:
    """Parse a boolean flag from the environment.

    Matching is case-insensitive and ignores surrounding whitespace.
    "true"/"1"/"yes" give True and "false"/"0"/"no" give False. Anything else,
    including an unset or empty variable, yields *default*.
    """
    word_to_flag = {
        "true": True,
        "1": True,
        "yes": True,
        "false": False,
        "0": False,
        "no": False,
    }
    token = os.environ.get(key, "").strip().lower()
    return word_to_flag.get(token, default)
|
||
|
|
|
||
|
|
|
||
|
|
def _get_int(key: str, default: int) -> int:
    """Parse an integer setting from the environment.

    Returns *default* when the variable is unset or cannot be parsed by
    ``int()`` (e.g. an empty string or "3.5").
    """
    text = os.environ.get(key, str(default))
    try:
        return int(text)
    except ValueError:
        pass
    return default
|
||
|
|
|
||
|
|
|
||
|
|
def _get_float(key: str, default: float) -> float:
    """Parse a floating-point setting from the environment.

    Returns *default* when the variable is unset or cannot be parsed by
    ``float()`` (e.g. an empty string).
    """
    text = os.environ.get(key, str(default))
    try:
        return float(text)
    except ValueError:
        pass
    return default
|
||
|
|
|
||
|
|
|
||
|
|
# ==================== Model settings ====================
# Hub model id; resolve_model_path() falls back to it when the local
# directory below does not exist.
EMBEDDING_MODEL_NAME: str = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B")
# Local copy of the model; relative values are anchored at PROJECT_ROOT.
EMBEDDING_MODEL_PATH: Path = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B")
# Hugging Face endpoint (defaults to a mirror).
# NOTE(review): this value is only read here. Unless .env sets HF_ENDPOINT,
# it is not present in os.environ; confirm that consumers pass it on.
HF_ENDPOINT: str = _get_str("HF_ENDPOINT", "https://hf-mirror.com")

# ==================== Hardware settings ====================
USE_GPU: bool = _get_bool("USE_GPU", True)
USE_FP16: bool = _get_bool("USE_FP16", True)
BATCH_SIZE: int = _get_int("BATCH_SIZE", 64)

# ==================== Directory settings ====================
# Relative values below are anchored at PROJECT_ROOT by _get_path.
JRXML_SOURCE_DIR: Path = _get_path("JRXML_SOURCE_DIR", "jrxml_source")
CHUNKER_OUTPUT_DIR: Path = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output")
EMBEDDINGS_DIR: Path = _get_path("EMBEDDINGS_DIR", "embeddings")
CHROMA_DB_PATH: Path = _get_path("CHROMA_DB_PATH", "chroma_db")
CHROMA_COLLECTION_NAME: str = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks")

# ==================== Chunking settings ====================
MAX_CHUNK_SIZE: int = _get_int("MAX_CHUNK_SIZE", 2000)

# ==================== Query settings ====================
DEFAULT_N_RESULTS: int = _get_int("DEFAULT_N_RESULTS", 5)
SIMILARITY_THRESHOLD: float = _get_float("SIMILARITY_THRESHOLD", 0.3)
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_model_path() -> str:
    """Choose where to load the embedding model from.

    Prefers the local EMBEDDING_MODEL_PATH if it exists on disk, returned as a
    string. Otherwise returns EMBEDDING_MODEL_NAME, which is treated as a Hub
    model id.
    """
    local_copy = EMBEDDING_MODEL_PATH
    return str(local_copy) if local_copy.exists() else EMBEDDING_MODEL_NAME
|
||
|
|
|
||
|
|
|
||
|
|
def print_config():
    """Print the effective configuration to stdout (debugging aid).

    The output consists of a 60-character '=' rule, a title, another rule,
    one " label: value" line per setting, and a closing rule.
    """
    rule = "=" * 60
    print(rule)
    print("JRXML RAG 当前配置")
    print(rule)

    settings = [
        ("项目根目录", PROJECT_ROOT),
        ("嵌入模型名称", EMBEDDING_MODEL_NAME),
        ("嵌入模型路径", EMBEDDING_MODEL_PATH),
        ("模型解析结果", resolve_model_path()),
        ("HF 镜像", HF_ENDPOINT),
        ("GPU 加速", USE_GPU),
        ("FP16 半精度", USE_FP16),
        ("批处理大小", BATCH_SIZE),
        ("JRXML 源目录", JRXML_SOURCE_DIR),
        ("分块输出目录", CHUNKER_OUTPUT_DIR),
        ("向量输出目录", EMBEDDINGS_DIR),
        ("Chroma 数据库", CHROMA_DB_PATH),
        ("Chroma 集合名", CHROMA_COLLECTION_NAME),
        ("最大 Chunk 大小", MAX_CHUNK_SIZE),
        ("默认返回结果数", DEFAULT_N_RESULTS),
        ("相似度阈值", SIMILARITY_THRESHOLD),
    ]
    for label, value in settings:
        print(f" {label}: {value}")

    print(rule)


# Running this module directly dumps the resolved configuration.
if __name__ == "__main__":
    print_config()
|