2026-05-14 23:39:00 +08:00
|
|
|
"""大语言模型工厂:支持 OpenAI 兼容的云端 API、Anthropic 兼容 API 和本地 Ollama。"""
|
2026-05-14 23:20:56 +08:00
|
|
|
|
|
|
|
|
import os
|
2026-05-15 00:35:41 +08:00
|
|
|
from typing import Any
|
|
|
|
|
|
2026-05-14 23:20:56 +08:00
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
# Pull KEY=VALUE pairs from a .env file (found by python-dotenv's default
# search) into os.environ, so the os.getenv() lookups below can see them.
# override=False is the library default: variables already set win.
load_dotenv(override=False)
|
|
|
|
|
|
|
|
|
|
|
2026-05-19 15:02:53 +08:00
|
|
|
class _BaseLLM:
    """Common interface for every LLM backend returned by ``get_llm()``.

    Concrete backends override ``invoke`` (one-shot request) and ``stream``
    (incremental output). The base versions only raise ``NotImplementedError``.
    """

    def invoke(self, prompt: str) -> Any:
        """Send *prompt* and return a response object carrying the reply in ``.content``."""
        raise NotImplementedError

    def stream(self, prompt: str):
        """Produce the reply to *prompt* piece by piece (overridden as a generator)."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
2026-05-14 23:20:56 +08:00
|
|
|
class _TextResponse:
    """Minimal stand-in for a chat message: the reply text lives in ``.content``."""

    def __init__(self, content: str) -> None:
        self.content = content


def _join_text_blocks(blocks: Any) -> str:
    """Concatenate ``.text`` of every content block whose ``.type`` is ``"text"``.

    Other block types (e.g. ``thinking``, ``tool_use``) are skipped. Joining all
    text blocks, instead of returning only the first one, keeps ``invoke``
    consistent with ``stream``, whose ``text_stream`` emits text from every
    text block of the response.
    """
    return "".join(block.text for block in blocks if block.type == "text")


def _wrap_langchain(raw: Any) -> Any:
    """Adapt a LangChain chat model to the ``_BaseLLM`` interface.

    ``invoke`` returns the model's own message object (reply in ``.content``),
    ``stream`` yields each chunk's ``.content``, and ``get_num_tokens`` is
    delegated so this wrapper exposes the same methods as the Anthropic-backed one.
    """

    class LangChainLLM(_BaseLLM):
        def invoke(self, prompt):
            return raw.invoke(prompt)

        def stream(self, prompt):
            for chunk in raw.stream(prompt):
                yield chunk.content

        def get_num_tokens(self, text: str) -> int:
            return raw.get_num_tokens(text)

    return LangChainLLM()


def _make_ollama_llm() -> Any:
    """Local backend: Ollama via ``langchain_ollama`` (model from LOCAL_LLM_MODEL)."""
    from langchain_ollama import ChatOllama

    model = os.getenv("LOCAL_LLM_MODEL", "qwen2.5-coder:7b")
    return _wrap_langchain(ChatOllama(model=model, temperature=0.1))


def _make_anthropic_llm() -> Any:
    """Anthropic-compatible backend built on the ``anthropic`` SDK.

    Reuses the OPENAI_API_KEY / OPENAI_BASE_URL variable names plus LLM_MODEL;
    the defaults point at MiniMax's Anthropic-compatible endpoint.
    """
    from anthropic import Anthropic

    api_key = os.getenv("OPENAI_API_KEY", "")
    base_url = os.getenv("OPENAI_BASE_URL", "https://api.minimaxi.com/anthropic")
    model = os.getenv("LLM_MODEL", "minimax-2.7")
    temperature = 0.1
    max_tokens = 4096

    # NOTE(review): this disables proxies for *every* HTTP client in the
    # process and overwrites any NO_PROXY the user already set. Kept because
    # other code may rely on it; consider scoping it to this client only.
    os.environ["NO_PROXY"] = "*"
    client = Anthropic(api_key=api_key, base_url=base_url, timeout=120)

    def _user_message(text: str) -> list:
        # Single-turn conversation: one user message with one text block.
        return [{"role": "user", "content": [{"type": "text", "text": text}]}]

    class MiniMaxLLM(_BaseLLM):
        def invoke(self, prompt: str) -> Any:
            resp = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=_user_message(prompt),
            )
            # Empty string when the response has no text block at all.
            return _TextResponse(_join_text_blocks(resp.content))

        def stream(self, prompt: str):
            with client.messages.stream(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=_user_message(prompt),
            ) as s:
                for text in s.text_stream:
                    yield text

        def get_num_tokens(self, text: str) -> int:
            resp = client.messages.count_tokens(
                model=model,
                messages=_user_message(text),
            )
            return resp.input_tokens

    return MiniMaxLLM()


def _make_openai_llm() -> Any:
    """OpenAI-compatible backend via ``langchain_openai.ChatOpenAI``."""
    from langchain_openai import ChatOpenAI

    raw = ChatOpenAI(
        model=os.getenv("LLM_MODEL", "gpt-4o"),
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        temperature=0.1,
    )
    return _wrap_langchain(raw)


def get_llm():
    """Build the chat model selected by environment variables (read on every call).

    - LLM_BACKEND: "local" selects Ollama; any other value (default "cloud")
      selects a cloud provider.
    - LLM_PROVIDER (cloud only): "anthropic" selects the Anthropic-compatible
      SDK client; any other value (default "openai") selects ChatOpenAI.

    Returns:
        A ``_BaseLLM`` instance with ``invoke(prompt)`` (returns an object whose
        ``.content`` holds the reply), ``stream(prompt)`` (yields text chunks)
        and ``get_num_tokens(text)``.
    """
    if os.getenv("LLM_BACKEND", "cloud") == "local":
        return _make_ollama_llm()
    if os.getenv("LLM_PROVIDER", "openai") == "anthropic":
        return _make_anthropic_llm()
    return _make_openai_llm()
|
|
|
|
|
|
2026-05-14 23:20:56 +08:00
|
|
|
|
|
|
|
|
def get_llm_for_correction():
    """Return the LLM used for correction.

    Currently this is exactly ``get_llm()``: the same environment variables
    choose the backend. Presumably kept as a separate hook so correction can
    later use a different model without changing callers.
    """
    correction_llm = get_llm()
    return correction_llm
|