27 lines
868 B
Python
27 lines
868 B
Python
|
|
"""嵌入模型工厂:支持本地 sentence-transformers 和云端 API。"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
|
||
|
|
load_dotenv()
|
||
|
|
|
||
|
|
|
||
|
|
def _env_or_default(name, default):
    """Return environment variable *name* stripped, or *default* if blank.

    ``os.getenv(name, default)`` only falls back when the variable is unset.
    A variable that is set but empty (for example a bare ``KEY=`` line in a
    ``.env`` file) would otherwise replace the default with an empty string.
    """
    value = (os.getenv(name) or "").strip()
    return value or default


def get_embeddings():
    """Build the embeddings object selected by ``EMBED_BACKEND``.

    Environment variables read:
        EMBED_BACKEND: ``"cloud"`` (compared case-insensitively, surrounding
            whitespace ignored) selects ``OpenAIEmbeddings``; any other
            value, or no value, selects ``HuggingFaceEmbeddings``.
        EMBED_CLOUD_MODEL: cloud model name, default
            ``"text-embedding-3-small"``.
        OPENAI_API_KEY: passed through unchanged as ``api_key``.
        OPENAI_BASE_URL: cloud endpoint, default
            ``"https://api.openai.com/v1"``.
        LOCAL_EMBED_MODEL: local model name, default
            ``"Qwen/Qwen3-Embedding-0.6B"``.

    Unset, empty, and whitespace-only values all fall back to the default.

    Returns:
        An ``OpenAIEmbeddings`` or ``HuggingFaceEmbeddings`` instance.

    Raises:
        ImportError: if the package backing the selected backend cannot be
            imported.
    """
    backend = _env_or_default("EMBED_BACKEND", "local").lower()

    if backend == "cloud":
        # Imported here, so langchain_openai is only required when the cloud
        # backend is actually selected.
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(
            model=_env_or_default("EMBED_CLOUD_MODEL", "text-embedding-3-small"),
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=_env_or_default(
                "OPENAI_BASE_URL", "https://api.openai.com/v1"
            ),
        )

    # Fall back to the langchain_community import path when the dedicated
    # langchain_huggingface package is not installed.
    try:
        from langchain_huggingface import HuggingFaceEmbeddings
    except ImportError:
        from langchain_community.embeddings import HuggingFaceEmbeddings

    model = _env_or_default("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
    return HuggingFaceEmbeddings(model_name=model)
|