增加输出数量约束

This commit is contained in:
2025-09-21 22:33:54 +08:00
parent afd170c451
commit d32520d83f
6 changed files with 87 additions and 267 deletions

View File

@@ -574,6 +574,10 @@ class PyTreeGenerator:
self.simple_base_url = os.getenv("SIMPLE_BASE_URL", f"http://{self.orin_ip}:8081/v1")
self.complex_base_url = os.getenv("COMPLEX_BASE_URL", f"http://{self.orin_ip}:8081/v1")
self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
# 直接在代码中指定最大输出token数不通过环境变量
self.classifier_max_tokens = 512
self.simple_max_tokens = 8192
self.complex_max_tokens = 8192
# 为不同用途分别创建客户端
self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
@@ -636,7 +640,8 @@ class PyTreeGenerator:
{"role": "user", "content": user_prompt}
],
temperature=0.0,
response_format={"type": "json_object"}
response_format={"type": "json_object"},
max_tokens=self.classifier_max_tokens
)
class_str = classifier_resp.choices[0].message.content
class_obj = json.loads(class_str)
@@ -682,6 +687,8 @@ class PyTreeGenerator:
}
if not self.enable_reasoning_capture:
response_kwargs["response_format"] = {"type": "json_object"}
# 基于模式设定最大输出token数直接在代码中配置
response_kwargs["max_tokens"] = self.simple_max_tokens if mode == "simple" else self.complex_max_tokens
response = client.chat.completions.create(**response_kwargs)
# 兼容可能存在的 reasoning_content 字段
try:
@@ -700,6 +707,7 @@ class PyTreeGenerator:
if isinstance(msg_content, str) and msg_content.strip():
combined_text += msg_content
pytree_str = combined_text if combined_text else (msg_content or "")
raw_full_text_for_logging = pytree_str # 保存完整原文(含 <think>)以便失败时完整打印
# 提取 <think> 推理链内容(若存在)
reasoning_text = None
@@ -715,7 +723,41 @@ class PyTreeGenerator:
try:
pytree_dict = json.loads(pytree_str)
except json.JSONDecodeError as e:
logging.error(f"❌ JSON解析失败{attempt + 1}/3 次)。原始响应如下:\n{pytree_str}")
logging.error(f"❌ JSON解析失败{attempt + 1}/3 次)。\n—— 完整原始文本(含<think>) ——\n{raw_full_text_for_logging}")
# 尝试打印响应对象的完整结构
try:
raw_response_dump = None
if hasattr(response, 'model_dump_json'):
raw_response_dump = response.model_dump_json(indent=2, exclude_none=False)
elif hasattr(response, 'dict'):
raw_response_dump = json.dumps(response.dict(), ensure_ascii=False, indent=2, default=str)
else:
# 兜底尝试将choices与关键字段展开
safe_obj = {
"id": getattr(response, 'id', None),
"model": getattr(response, 'model', None),
"object": getattr(response, 'object', None),
"usage": getattr(response, 'usage', None),
"choices": [
{
"index": getattr(c, 'index', None),
"finish_reason": getattr(c, 'finish_reason', None),
"message": {
"role": getattr(getattr(c, 'message', None), 'role', None),
"content": getattr(getattr(c, 'message', None), 'content', None),
"reasoning_content": getattr(getattr(c, 'message', None), 'reasoning_content', None)
} if getattr(c, 'message', None) is not None else None
}
for c in getattr(response, 'choices', [])
] if hasattr(response, 'choices') else None
}
raw_response_dump = json.dumps(safe_obj, ensure_ascii=False, indent=2, default=str)
logging.error(f"—— 完整响应对象 ——\n{raw_response_dump}")
except Exception as dump_e:
try:
logging.error(f"响应对象转储失败repr如下\n{repr(response)}")
except Exception:
pass
continue
# 简单/复杂分别验证与返回