import json
import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set, List
import chromadb
import openai
from openai import OpenAIError
import jsonschema
import requests
import platform # 新增:用于选择合适的中文字体
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """ChromaDB embedding function backed by a remote HTTP embeddings endpoint.

    POSTs ``{"input": [...texts]}`` as JSON to the configured URL and returns
    the ``embedding`` field of every item in the response's ``data`` list.
    Mirrors the definition in ingest.py.
    NOTE(review): ingest.py's copy may not pass a timeout; keep the two in sync.
    """

    def __init__(self, api_url: str, timeout: float = 30.0):
        """
        Args:
            api_url: full URL of the embeddings endpoint.
            timeout: seconds passed to ``requests.post``; without a timeout an
                unresponsive server would block the caller indefinitely.
        """
        self._api_url = api_url
        self._timeout = timeout

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed a list of strings.

        Returns an empty list when *input* is not a list of strings, when the
        response has no ``data``, or when the request/parsing fails (the error
        is logged).
        """
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            return []
        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            if not data:
                return []
            return [item['embedding'] for item in data]
        except Exception as e:
            logging.error(f"调用远程嵌入API时出错: {e}")
            return []
# --- Logging setup: INFO and above, timestamped, via a StreamHandler (stderr by default) ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
    """Extract the allowed action and condition node names from the system prompt.

    Looks for the first ```json fenced block after the ``#### 1. 可用节点定义``
    heading and collects the ``name`` of every object in its ``actions`` and
    ``conditions`` lists (non-object entries are ignored). If the heading/block
    is missing, or its JSON does not parse, ``_fallback_parse_nodes`` is used.

    Returns:
        ``(actions, conditions)``; two empty sets if nothing usable is found or
        an unexpected error occurs (the error is logged).
    """
    def _names(entries) -> Set[str]:
        # Collect "name" from dict entries of a list; anything else yields nothing.
        if not isinstance(entries, list):
            return set()
        return {entry["name"] for entry in entries if isinstance(entry, dict) and "name" in entry}

    try:
        # Anchor on the node-definition heading, then take its first JSON code block.
        node_section_pattern = r"#### 1\. 可用节点定义.*?```json\s*({.*?})\s*```"
        match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
        if not match:
            logging.error("在系统提示词中未找到'可用节点定义'部分的JSON代码块。")
            # Fallback: scan every JSON block for the node-definition structure.
            return _fallback_parse_nodes(prompt_text)
        json_str = match.group(1)
        logging.info("成功找到节点定义JSON代码块")
        try:
            allowed_nodes = json.loads(json_str)
        except json.JSONDecodeError as e:
            logging.error(f"解析节点定义JSON时失败: {e}")
            # The matched block may be malformed; other blocks may still be usable.
            return _fallback_parse_nodes(prompt_text)
        # The captured text starts with "{", so a successful parse is a dict.
        actions = _names(allowed_nodes.get("actions"))
        conditions = _names(allowed_nodes.get("conditions"))
        if not actions:
            logging.warning("关键错误:从提示词解析出的行动节点列表为空。")
        logging.info(f"成功解析出动作节点: {sorted(actions)}")
        logging.info(f"成功解析出条件节点: {sorted(conditions)}")
        return actions, conditions
    except Exception as e:
        logging.error(f"解析可用节点时发生未知错误: {e}")
        return set(), set()
def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
    """Fallback node parser, used when the dedicated section cannot be located.

    Scans every ```json fenced block in *prompt_text* and returns the names
    from the first block whose object has list-valued ``actions``,
    ``conditions`` and ``control_flow`` keys and yields at least one action
    name. Blocks that are not valid JSON are skipped.

    Returns:
        ``(actions, conditions)``, or two empty sets when no block qualifies.
    """
    logging.warning("使用备用方案解析节点定义...")
    json_blocks = re.findall(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
    if not json_blocks:
        logging.error("在系统提示词中未找到任何JSON代码块。")
        return set(), set()

    required_lists = ("actions", "conditions", "control_flow")
    for position, raw_block in enumerate(json_blocks, start=1):
        try:
            candidate = json.loads(raw_block)
            # A node definition block carries all three keys, each holding a list.
            if not all(key in candidate and isinstance(candidate[key], list) for key in required_lists):
                continue
            action_names = {
                entry["name"] for entry in candidate["actions"]
                if isinstance(entry, dict) and "name" in entry
            }
            condition_names = {
                entry["name"] for entry in candidate["conditions"]
                if isinstance(entry, dict) and "name" in entry
            }
            if not action_names:
                continue
            logging.info(f"从第{position}个JSON块中成功解析出节点定义")
            logging.info(f"动作节点: {sorted(action_names)}")
            logging.info(f"条件节点: {sorted(condition_names)}")
            return action_names, condition_names
        except json.JSONDecodeError:
            continue  # not valid JSON; move on to the next block

    logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
    return set(), set()
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
    """Build the draft-07 JSON Schema for a complex py_tree document.

    The document must be ``{"root": <node>}``. A node is an object with a
    string ``type`` (action / condition / sequence / selector / parallel /
    decorator, case-insensitive) and a string ``name``, plus an optional
    ``params`` object, ``children`` array of nodes and ``child`` node.
    Additional rules:

    * action names must be in *allowed_actions* and condition names in
      *allowed_conditions* (an empty set therefore rejects every such node);
    * ``object_detect`` actions and ``object_detected`` conditions require
      ``params.target_class`` from the detection class list, allow
      ``description`` and integer ``count >= 1``, and nothing else;
    * ``battery_above`` requires ``params.threshold`` in [0.0, 1.0];
    * ``gps_status`` requires integer ``params.min_satellites`` in [6, 15].
    """
    def _ci(word: str) -> str:
        # Case-insensitive regex built from character classes ("ab" -> "[Aa][Bb]"),
        # avoiding the inline (?i) flag, which the validator may not support.
        return "".join(f"[{ch.upper()}{ch.lower()}]" for ch in word)

    action_pattern = f"^{_ci('action')}$"
    condition_pattern = f"^{_ci('condition')}$"
    any_type_pattern = "^(" + "|".join(
        _ci(t) for t in ("action", "condition", "sequence", "selector", "parallel", "decorator")
    ) + ")$"

    # Object-detection target classes.
    target_classes = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
        "traffic_light", "fire_hydrant", "stop_sign", "parking_meter", "bench", "bird", "cat",
        "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
        "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball",
        "kite", "baseball_bat", "baseball_glove", "skateboard", "surfboard", "tennis_racket",
        "bottle", "wine_glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
        "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
        "couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
        "keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
        "clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush", "balloon", "trash", "window"
    ]

    def _params_rule(type_pattern: str, node_name: str, params_schema: dict) -> dict:
        # "If this node has the given type and name, its params must match params_schema."
        return {
            "if": {"properties": {"type": {"pattern": type_pattern}, "name": {"const": node_name}}},
            "then": {"properties": {"params": params_schema}},
        }

    def _detection_params() -> dict:
        # Shared params shape for object_detect / object_detected (fresh dict per use).
        return {
            "type": "object",
            "properties": {
                "target_class": {"type": "string", "enum": target_classes},
                "description": {"type": "string"},
                "count": {"type": "integer", "minimum": 1},
            },
            "required": ["target_class"],
            "additionalProperties": False,
        }

    node_definition = {
        "type": "object",
        "properties": {
            "type": {"type": "string", "pattern": any_type_pattern},
            "name": {"type": "string"},
            "params": {"type": "object"},
            "children": {"type": "array", "items": {"$ref": "#/definitions/node"}},
            "child": {"$ref": "#/definitions/node"},
        },
        "required": ["type", "name"],
        "allOf": [
            # Action names must come from the allowed set.
            {
                "if": {"properties": {"type": {"pattern": action_pattern}}},
                "then": {"properties": {"name": {"enum": sorted(allowed_actions)}}},
            },
            # Condition names must come from the allowed set.
            {
                "if": {"properties": {"type": {"pattern": condition_pattern}}},
                "then": {"properties": {"name": {"enum": sorted(allowed_conditions)}}},
            },
            _params_rule(action_pattern, "object_detect", _detection_params()),
            _params_rule(condition_pattern, "object_detected", _detection_params()),
            _params_rule(condition_pattern, "battery_above", {
                "type": "object",
                "properties": {"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}},
                "required": ["threshold"],
                "additionalProperties": False,
            }),
            _params_rule(condition_pattern, "gps_status", {
                "type": "object",
                "properties": {"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}},
                "required": ["min_satellites"],
                "additionalProperties": False,
            }),
        ],
    }

    return {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "Pytree",
        "definitions": {"node": node_definition},
        "type": "object",
        "properties": {"root": {"$ref": "#/definitions/node"}},
        "required": ["root"],
    }
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
    """Build the JSON Schema for a simple-mode plan: ``{"root": <single action>}``.

    Simple mode uses the same ``root`` wrapper as complex mode, but ``root``
    must be a node whose ``type`` is exactly ``"action"`` (case-sensitive,
    so control-flow types are rejected) and whose ``name`` is one of
    *allowed_actions*. ``params``, if present, only has to be an object; its
    contents are left to the prompt and runtime. Extra keys on ``root`` are
    allowed; extra top-level keys (such as ``mode``) are not.

    ``children`` is deliberately not constrained here; callers are expected to
    reject a root with non-empty children after validation.
    """
    root_node_schema = {
        "type": "object",
        "properties": {
            "type": {"type": "string", "const": "action"},
            "name": {"type": "string", "enum": sorted(allowed_actions)},
            "params": {"type": "object"},
        },
        "required": ["type", "name"],
        "additionalProperties": True,
    }
    return {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "SimpleMode",
        "type": "object",
        "properties": {"root": root_node_schema},
        "required": ["root"],
        "additionalProperties": False,
    }
def _node_name_at_path(instance: Any, path) -> Optional[str]:
    """Return the ``name`` of the deepest tree node along *path* inside *instance*.

    A "tree node" is a dict with a ``type`` key and a string ``name``. The
    walk stops at the first path element that cannot be followed. Returns
    ``None`` when no node is encountered.
    """
    def _name_of(obj: Any) -> Optional[str]:
        if isinstance(obj, dict) and "type" in obj and isinstance(obj.get("name"), str):
            return obj["name"]
        return None

    found = _name_of(instance)
    current = instance
    for key in path:
        try:
            current = current[key]
        except (KeyError, IndexError, TypeError):
            break
        name = _name_of(current)
        if name is not None:
            found = name
    return found


def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
    """Validate *pytree_instance* against the JSON Schema *schema*.

    On a validation failure, logs the message and the instance path. When the
    failure lies inside a node's ``params``, it also logs a parameter hint for
    object_detect / object_detected / battery_above / gps_status nodes.

    Returns:
        True if valid. False on a validation failure or any other error raised
        while validating (e.g. an invalid schema), which is logged.
    """
    try:
        jsonschema.validate(instance=pytree_instance, schema=schema)
    except jsonschema.ValidationError as e:
        logging.warning("❌ Pytree验证失败")
        logging.warning(f"错误信息: {e.message}")
        error_path = list(e.path)
        logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
        # Pick the hint from the node that owns the failing params rather than
        # searching e.message. Value errors ("'x' is not one of [...]") never name
        # the node. Name-enum errors list every allowed name, so a message search
        # would give a misleading hint.
        node_name = _node_name_at_path(pytree_instance, error_path) if "params" in error_path else None
        if node_name in ("object_detect", "object_detected"):
            logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
        elif node_name == "battery_above":
            logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
        elif node_name == "gps_status":
            logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
        return False
    except Exception as e:
        logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
        return False
    logging.info("✅ JSON Schema验证成功")
    return True
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
    """Render a py_tree node and its subtree to a PNG image with Graphviz.

    Best effort: a missing ``graphviz`` Python package, an error while building
    the graph from *node*, or a failed render (e.g. Graphviz binaries not
    installed) is logged, and the function returns ``None`` without raising.

    Args:
        node: root node dict, walked by ``_add_nodes_and_edges``.
        file_path: output path. A trailing ``.png`` is dropped because
            ``Digraph.render`` appends the format extension itself (avoids
            ``x.png.png``). The parent directory is created if missing.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
        return

    # Pick a CJK-capable font per OS so Chinese labels render correctly.
    os_name = platform.system()
    fontname = {"Windows": "Microsoft YaHei", "Darwin": "PingFang SC"}.get(os_name, "Noto Sans CJK SC")

    try:
        dot = Digraph('Pytree', comment='Drone Mission Plan')
        dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
        dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
        dot.attr('edge', fontname=fontname)
        _add_nodes_and_edges(node, dot)
    except Exception as e:
        # A malformed tree must not escape this best-effort helper.
        logging.error("❌ 构建任务树图形失败")
        logging.error(f"错误详情: {e}")
        return

    try:
        base_path, ext = os.path.splitext(file_path)
        render_path = base_path if ext.lower() == '.png' else file_path
        out_dir = os.path.dirname(render_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Write the .png and delete the intermediate .gv source.
        output_path = dot.render(render_path, format='png', cleanup=True, view=False)
        logging.info("✅ 任务树可视化成功")
        logging.info(f"图形已保存到: {output_path}")
    except Exception as e:
        logging.error("❌ 生成可视化图形失败")
        logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
        logging.error(f"错误详情: {e}")
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID(加上随机数避免冲突)
import random
import html
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
# 准备节点标签(HTML-like,正确换行与转义)
name = html.escape(str(node.get('name', '')))
ntype = html.escape(str(node.get('type', '')))
label_parts = [f"{name} ({ntype})"]
# 格式化参数显示
params = node.get('params') or {}
if params:
params_lines = []
for key, value in params.items():
k = html.escape(str(key))
if isinstance(value, float):
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
else:
value_str = str(value)
v = html.escape(value_str)
params_lines.append(f"{k}: {v}")
params_text = "
".join(params_lines)
label_parts.append(f"{params_text}")
node_label = f"<{'
'.join(label_parts)}>"
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
node_type = (node.get('type') or '').lower()
shape = 'ellipse'
style = 'filled'
fillcolor = '#e6e6e6' # 默认灰色填充
border_color = '#666666' # 默认描边色
if node_type == 'action':
shape = 'box'
style = 'rounded,filled'
fillcolor = "#cde4ff" # 浅蓝
elif node_type == 'condition':
shape = 'diamond'
style = 'filled'
fillcolor = "#fff2cc" # 浅黄
elif node_type == 'sequence':
shape = 'ellipse'
style = 'filled'
fillcolor = '#d5e8d4' # 绿色
elif node_type == 'selector':
shape = 'ellipse'
style = 'filled'
fillcolor = '#ffe6cc' # 橙色
elif node_type == 'parallel':
shape = 'ellipse'
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
elif node_type == 'decorator':
shape = 'doubleoctagon'
style = 'filled'
fillcolor = '#f8cecc' # 浅红
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点 (Sequence, Selector, Parallel 等)
children = node.get("children", [])
# 兼容 decorator 类型的 child 字段 (处理为单元素列表以便统一逻辑)
if node_type == 'decorator' and 'child' in node:
children = [node['child']]
if children:
# 记录所有子节点的ID
child_ids = []
# 正确的递归连接:每个子节点都连接到当前节点
for child in children:
child_id = _add_nodes_and_edges(child, dot, current_id)
child_ids.append(child_id)
# 子节点同级排列(横向排布,更直观地表现同层)
if len(child_ids) > 1:
with dot.subgraph(name=f"rank_{current_id}") as s:
s.attr(rank='same')
for cid in child_ids:
s.node(cid)
# 递归处理单子节点 (Decorator) - 已合并到 children 处理逻辑中,此处删除旧逻辑
# child = node.get("child")
# if child:
# _add_nodes_and_edges(child, dot, current_id)
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
    """Generate validated py_tree (behavior-tree) JSON from natural-language requests.

    ``generate`` runs the pipeline:
    1. Classify the request as "simple" or "complex" with a classifier model.
    2. Append context retrieved from the ChromaDB ``drone_docs`` collection
       to the user prompt.
    3. Ask the mode's model for a JSON plan, up to ``_MAX_ATTEMPTS`` times.
    4. Validate the plan against a JSON Schema built from the node list in
       ``system_prompt.txt``.
    5. Attach a ``plan_id``, a visualization URL and the exact prompt sent.
    """

    # Number of generation attempts before giving up with RuntimeError.
    _MAX_ATTEMPTS = 3
    # One <think>...</think> reasoning segment (non-greedy, spans newlines).
    _THINK_RE = re.compile(r"<think>([\s\S]*?)</think>")

    def __init__(self):
        """Load prompts, read model/endpoint settings from the environment and create clients.

        Creates the output directories and three OpenAI-compatible clients
        (classifier / simple / complex). Also opens a persistent ChromaDB client
        using the remote embedding function, and builds both validation schemas.
        Errors raised while opening the ``drone_docs`` collection propagate.
        """
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.prompts_dir = os.path.join(self.base_dir, 'prompts')
        # Rendered tree images. The returned visualization_url assumes this
        # directory is served under /static/.
        self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
        os.makedirs(self.vis_dir, exist_ok=True)
        # Captured model reasoning, written as Markdown.
        self.reasoning_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_reasoning_content'))
        os.makedirs(self.reasoning_dir, exist_ok=True)
        # Meant to allow raw (non-JSON) output containing <think> blocks.
        # NOTE(review): nothing in this class reads this flag; requests always send
        # response_format=json_object. Confirm whether it should be wired up.
        self.enable_reasoning_capture = os.getenv("ENABLE_REASONING_CAPTURE", "true").lower() in ("1", "true", "yes")
        # Maximum number of reasoning lines echoed to the log.
        try:
            self.reasoning_preview_lines = int(os.getenv("REASONING_PREVIEW_LINES", "20"))
        except ValueError:
            self.reasoning_preview_lines = 20

        # Complex mode reuses system_prompt.txt; simple mode and the classifier have their own prompts.
        self.complex_prompt = self._load_prompt("system_prompt.txt")
        self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
        self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
        # Backward-compatible alias for the old attribute name.
        self.system_prompt = self.complex_prompt

        self.orin_ip = os.getenv("ORIN_IP", "localhost")
        # Per-role model names and endpoints. Each falls back to OPENAI_MODEL and
        # the ORIN_IP:8081 OpenAI-compatible server.
        default_model = os.getenv("OPENAI_MODEL", "local-model")
        default_base_url = f"http://{self.orin_ip}:8081/v1"
        self.classifier_model = os.getenv("CLASSIFIER_MODEL", default_model)
        self.simple_model = os.getenv("SIMPLE_MODEL", default_model)
        self.complex_model = os.getenv("COMPLEX_MODEL", default_model)
        self.classifier_base_url = os.getenv("CLASSIFIER_BASE_URL", default_base_url)
        self.simple_base_url = os.getenv("SIMPLE_BASE_URL", default_base_url)
        self.complex_base_url = os.getenv("COMPLEX_BASE_URL", default_base_url)
        self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
        # Output token limits are fixed in code (deliberately not read from the environment).
        self.classifier_max_tokens = 512
        self.simple_max_tokens = 8192
        self.complex_max_tokens = 8192
        # One client per role so each can point at a different server.
        self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
        self.simple_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.simple_base_url)
        self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)

        # --- ChromaDB: persistent store queried through the remote embedding service ---
        vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'rag', 'vector_store'))
        self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
        embedding_func = RemoteEmbeddingFunction(api_url=f"http://{self.orin_ip}:8090/v1/embeddings")
        self.collection = self.chroma_client.get_collection(
            name="drone_docs",
            embedding_function=embedding_func,
        )

        # Node names come from the complex prompt so both schemas share one source.
        allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
        self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
        self.simple_schema = _generate_simple_mode_schema(allowed_actions)

    def _load_prompt(self, file_name: str) -> str:
        """Return the UTF-8 text of ``prompts/<file_name>``, or "" (logged) if the file is missing."""
        try:
            with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"提示词文件未找到 -> {file_name}")
            return ""

    def _retrieve_context(self, query: str) -> Optional[str]:
        """Return the top-5 documents for *query* joined by blank lines, or None if none/failed."""
        logging.info("--- 开始从向量数据库检索上下文 ---")
        try:
            results = self.collection.query(query_texts=[query], n_results=5)
            retrieved_docs = results.get("documents", [[]])[0]
            if not retrieved_docs:
                logging.warning("在向量数据库中没有找到相关的上下文信息。")
                return None
            context_str = "\n\n".join(retrieved_docs)
            logging.info("--- 成功检索到上下文信息 ---")
            logging.info(f"📚 检索到的上下文内容:\n{context_str}")
            return context_str
        except Exception as e:
            logging.error(f"从向量数据库检索时发生错误: {e}")
            return None

    def _classify(self, user_prompt: str) -> str:
        """Return "simple" or "complex" for *user_prompt* using the classifier model.

        The classifier must reply with a JSON object whose ``mode`` is one of
        those values. Any other reply, or any exception, yields "complex".
        """
        mode = "complex"
        try:
            classifier_resp = self.classifier_client.chat.completions.create(
                model=self.classifier_model,
                messages=[
                    {"role": "system", "content": self.classifier_prompt or "你是一个分类器,只输出JSON。"},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0,
                response_format={"type": "json_object"},  # force a JSON reply
                max_tokens=self.classifier_max_tokens,
                # Ask Qwen3-style chat templates to skip "thinking"; servers that do
                # not recognise the key presumably ignore it.
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
            )
            class_obj = json.loads(classifier_resp.choices[0].message.content)
            if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"):
                mode = class_obj["mode"]
            logging.info(f"分类结果: {mode}")
        except Exception as e:
            logging.warning(f"分类失败,默认按复杂指令处理: {e}")
        return mode

    def _build_system_prompt(self, mode: str) -> str:
        """Return the system prompt for *mode*.

        Complex mode gets an extra rule that forbids simple-mode output.
        """
        if mode == "simple":
            return self.simple_prompt
        return (
            (self.complex_prompt or "") +
            "\n\n【强制规则】仅生成包含root的复杂行为树JSON,不得输出简单模式(不得包含mode字段或仅有action节点)。"
        )

    def _augment_user_prompt(self, user_prompt: str) -> str:
        """Append retrieved knowledge to *user_prompt*; return it unchanged when retrieval finds nothing."""
        retrieved_context = self._retrieve_context(user_prompt)
        if not retrieved_context:
            logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
            return user_prompt
        augmentation = (
            "\n\n---\n"
            "参考知识:\n"
            "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n"
            f"{retrieved_context}"
            "\n---"
        )
        return user_prompt + augmentation

    @staticmethod
    def _message_parts(response: Any) -> tuple[Optional[str], Optional[str]]:
        """Return ``(content, reasoning_content)`` of the first choice of *response*.

        Accepts SDK objects (attribute access) and plain dict payloads. Missing
        fields, including an empty ``choices`` list, come back as ``None``.
        """
        choices = response.get("choices") if isinstance(response, dict) else getattr(response, "choices", None)
        if not choices:
            return None, None
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else getattr(first, "message", None)
        if isinstance(message, dict):
            return message.get("content"), message.get("reasoning_content")
        return getattr(message, "content", None), getattr(message, "reasoning_content", None)

    @staticmethod
    def _combine_output(content: Any, reasoning: Any) -> str:
        """Join non-blank reasoning (wrapped in <think> tags) and non-blank content into one string."""
        parts = []
        if isinstance(reasoning, str) and reasoning.strip():
            parts.append(f"<think>\n{reasoning}\n</think>\n")
        if isinstance(content, str) and content.strip():
            parts.append(content)
        return "".join(parts)

    @classmethod
    def _split_reasoning(cls, text: str) -> tuple[Optional[str], str]:
        """Split *text* into (first <think> segment, text with all <think> segments removed).

        Both parts are whitespace-stripped. The reasoning part is ``None`` when
        there is no <think> segment.
        """
        match = cls._THINK_RE.search(text)
        if not match:
            return None, text.strip()
        return match.group(1).strip(), cls._THINK_RE.sub("", text).strip()

    @staticmethod
    def _log_raw_response(response: Any) -> None:
        """Log the full response object for debugging, falling back to repr() if serialisation fails."""
        def _message_view(message: Any) -> Optional[Dict[str, Any]]:
            if message is None:
                return None
            return {
                "role": getattr(message, 'role', None),
                "content": getattr(message, 'content', None),
                "reasoning_content": getattr(message, 'reasoning_content', None),
            }

        try:
            if hasattr(response, 'model_dump_json'):
                raw_response_dump = response.model_dump_json(indent=2, exclude_none=False)
            elif hasattr(response, 'dict'):
                raw_response_dump = json.dumps(response.dict(), ensure_ascii=False, indent=2, default=str)
            else:
                # Last resort: expose the key fields and each choice's message by hand.
                choices = getattr(response, 'choices', None)
                safe_obj = {
                    "id": getattr(response, 'id', None),
                    "model": getattr(response, 'model', None),
                    "object": getattr(response, 'object', None),
                    "usage": getattr(response, 'usage', None),
                    "choices": None if choices is None else [
                        {
                            "index": getattr(c, 'index', None),
                            "finish_reason": getattr(c, 'finish_reason', None),
                            "message": _message_view(getattr(c, 'message', None)),
                        }
                        for c in choices
                    ],
                }
                raw_response_dump = json.dumps(safe_obj, ensure_ascii=False, indent=2, default=str)
            logging.error(f"—— 完整响应对象 ——\n{raw_response_dump}")
        except Exception:
            try:
                logging.error(f"响应对象转储失败,repr如下:\n{repr(response)}")
            except Exception:
                pass

    def _is_valid_simple_tree(self, pytree_dict: Any) -> bool:
        """Validate a simple-mode reply: it must match ``self.simple_schema`` and root must have no children."""
        try:
            jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
        except jsonschema.ValidationError as e:
            logging.warning(f"❌ 简单模式验证失败: {e.message}")
            return False
        # The simple schema does not constrain "children", so enforce "single action" here.
        children = pytree_dict.get('root', {}).get('children', [])
        if isinstance(children, list) and len(children) > 0:
            logging.warning(f"❌ 简单模式验证失败: root节点不能有children,但发现 {len(children)} 个子节点")
            return False
        logging.info("✅ 简单模式JSON Schema验证成功")
        return True

    def _save_reasoning(self, reasoning_text: Optional[str]) -> None:
        """Overwrite reasoning_content.md with *reasoning_text* and log a preview.

        If there is no reasoning text, only logs that none was found. Write
        failures are logged, not raised.
        """
        if not reasoning_text:
            logging.info("未在模型输出中发现 推理链片段。若需捕获,请设置 ENABLE_REASONING_CAPTURE=true 以放宽JSON强制格式。")
            return
        try:
            reasoning_path = os.path.join(self.reasoning_dir, "reasoning_content.md")
            with open(reasoning_path, 'w', encoding='utf-8') as rf:
                rf.write(reasoning_text)
            logging.info(f"📝 推理链已保存: {reasoning_path}")
            preview = "\n".join(reasoning_text.splitlines()[: self.reasoning_preview_lines])
            logging.info("🧠 推理链预览(前%d行):\n%s", self.reasoning_preview_lines, preview)
        except Exception as e:
            logging.warning(f"保存推理链Markdown失败: {e}")

    def _finalize(self, pytree_dict: Dict[str, Any], root_node: Any, mode: str,
                  reasoning_text: Optional[str], final_prompt: str) -> Dict[str, Any]:
        """Attach plan_id, visualization_url and final_prompt to *pytree_dict*, and persist reasoning.

        Visualization is best effort. If it raises, the error is logged and
        ``visualization_url`` is left out.
        """
        pytree_dict['plan_id'] = str(uuid.uuid4())
        vis_filename = "py_tree.png"
        vis_path = os.path.join(self.vis_dir, vis_filename)
        try:
            # Pass the path without ".png": Graphviz appends the format extension.
            _visualize_pytree(root_node, os.path.splitext(vis_path)[0])
            pytree_dict['visualization_url'] = f"/static/{vis_filename}"
        except Exception as e:
            prefix = "简单模式" if mode == "simple" else "复杂模式"
            logging.warning(f"{prefix}可视化失败: {e}")
        self._save_reasoning(reasoning_text)
        pytree_dict['final_prompt'] = final_prompt
        return pytree_dict

    async def generate(self, user_prompt: str) -> Dict[str, Any]:
        """Generate a validated py_tree dict for *user_prompt*.

        Classifies the request, which defaults to "complex" on classifier
        failure. It then builds the mode's system prompt, appends retrieved
        knowledge to the user prompt, and calls the mode's model up to
        ``_MAX_ATTEMPTS`` times. A reply is accepted when it parses as JSON
        (after removing any <think>...</think> segment) and passes the mode's
        validation.

        Returns:
            The model's JSON object, plus:
            - ``plan_id``;
            - ``visualization_url``, omitted if building the visualization raised;
            - ``final_prompt``: the system and user prompts exactly as sent.

        Raises:
            RuntimeError: no attempt produced a valid tree. It is chained to the
                last OpenAI API error, if there was one.
        """
        # NOTE(review): although this is a coroutine, every call below is
        # synchronous (OpenAI client, ChromaDB query, Graphviz render, file
        # writes), so it blocks the event loop for the whole request. Moving
        # this work to threads would also need per-request output paths, because
        # py_tree.png and reasoning_content.md are shared files.
        logging.info(f"接收到用户请求: {user_prompt}")

        # Step 1: classify (simple / complex).
        mode = self._classify(user_prompt)
        is_simple = mode == "simple"

        # Step 2: prompts. Retrieval runs in both modes, and RAG results go into
        # the user message only, never the system prompt.
        use_prompt = self._build_system_prompt(mode)
        final_user_prompt = self._augment_user_prompt(user_prompt)
        final_prompt = f"=== System Prompt ===\n{use_prompt}\n\n=== User Prompt ===\n{final_user_prompt}"

        client = self.simple_llm_client if is_simple else self.complex_llm_client
        request_kwargs = {
            "model": self.simple_model if is_simple else self.complex_model,
            "messages": [
                {"role": "system", "content": use_prompt},
                {"role": "user", "content": final_user_prompt},
            ],
            "temperature": 0.0 if is_simple else 0.1,
            # Always force a JSON reply and ask the chat template to skip "thinking".
            "response_format": {"type": "json_object"},
            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
            "max_tokens": self.simple_max_tokens if is_simple else self.complex_max_tokens,
        }

        attempts = self._MAX_ATTEMPTS
        last_error: Optional[Exception] = None
        for attempt in range(attempts):
            logging.info(f"--- 第 {attempt + 1}/{attempts} 次尝试生成Pytree ---")
            try:
                response = client.chat.completions.create(**request_kwargs)
            except OpenAIError as e:
                last_error = e
                logging.error(f"生成Pytree时发生错误: {e}")
                continue

            # Reasoning may arrive in a separate field or inline as <think> tags.
            content, reasoning = self._message_parts(response)
            raw_text = self._combine_output(content, reasoning)
            reasoning_text, pytree_str = self._split_reasoning(raw_text)
            try:
                pytree_dict = json.loads(pytree_str)
            except json.JSONDecodeError:
                logging.error(f"❌ JSON解析失败(第 {attempt + 1}/{attempts} 次)。\n—— 完整原始文本(含) ——\n{raw_text}")
                self._log_raw_response(response)
                continue

            if is_simple:
                if self._is_valid_simple_tree(pytree_dict):
                    return self._finalize(pytree_dict, pytree_dict.get('root', {}), mode, reasoning_text, final_prompt)
                continue

            if _validate_pytree_with_schema(pytree_dict, self.schema):
                logging.info("✅ 成功生成并验证了Pytree")
                return self._finalize(pytree_dict, pytree_dict['root'], mode, reasoning_text, final_prompt)

            # Log the rejected tree to help debugging, then retry.
            preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
            logging.warning(f"❌ 未通过验证的Pytree(第 {attempt + 1}/{attempts} 次尝试):\n{preview}")
            logging.warning("生成的Pytree验证失败,正在重试...")

        raise RuntimeError(f"在{attempts}次尝试后,仍未能生成一个有效的Pytree。") from last_error
# Create a single instance for the application.
# Constructing it at import time runs PyTreeGenerator.__init__, which creates
# the output directories, builds the OpenAI clients, opens the ChromaDB
# persistent client and the "drone_docs" collection, and parses the prompts.
# Any exception raised there propagates out of the import of this module.
py_tree_generator = PyTreeGenerator()