# Source: DronePlanning/backend_service/src/py_tree_generator.py
# Last modified: 2025-09-21 22:12:21 +08:00 (848 lines, 38 KiB, Python)
#
# Note: this file intentionally contains non-ASCII Unicode characters
# (Chinese text and emoji) that some code viewers flag as "ambiguous".

import json
import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set, List
import chromadb
import openai
from openai import OpenAIError
import jsonschema
import requests
import platform # 新增:用于选择合适的中文字体
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    ChromaDB embedding function that delegates to a remote HTTP embeddings endpoint.

    The request body is ``{"input": [...]}`` and the response is expected to carry
    a ``data`` list whose items hold an ``embedding`` vector. It must stay
    consistent with the definition in ingest.py so that query embeddings match
    the stored document embeddings.
    """

    def __init__(self, api_url: str, timeout: float = 30.0):
        """
        Args:
            api_url: Full URL of the embeddings endpoint.
            timeout: Seconds to wait for the HTTP request. This bounds the call so
                a stalled embedding server cannot block the caller indefinitely.
        """
        self._api_url = api_url
        self._timeout = timeout

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Embed a list of strings.

        Returns one vector per item of ``data`` in the server response. Returns
        an empty list when *input* is not a list of strings, when the response
        has no ``data``, or when the request fails (the error is logged).
        """
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            return []
        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            if not data:
                return []
            # OpenAI-style responses tag each item with its "index". Sort on it so
            # the vectors line up with *input* even if the server reorders them.
            if all(isinstance(item, dict) and "index" in item for item in data):
                data = sorted(data, key=lambda item: item["index"])
            return [item['embedding'] for item in data]
        except Exception as e:
            logging.error(f"调用远程嵌入API时出错: {e}")
            return []
# --- Logging setup: INFO and above, timestamped, via a single StreamHandler ---
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
从系统提示词中精确解析出允许的行动和条件节点。
"""
try:
# 使用更精确的正则表达式匹配节点定义部分
node_section_pattern = r"#### 2\. 可用节点定义.*?```json\s*({.*?})\s*```"
match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
if not match:
logging.error("在系统提示词中未找到'可用节点定义'部分的JSON代码块。")
# 备用方案尝试查找所有JSON块并识别节点定义
return _fallback_parse_nodes(prompt_text)
json_str = match.group(1)
logging.info("成功找到节点定义JSON代码块")
# 解析JSON
allowed_nodes = json.loads(json_str)
# 从对象列表中提取节点名称
actions = set()
conditions = set()
# 提取动作节点
if "actions" in allowed_nodes and isinstance(allowed_nodes["actions"], list):
for action in allowed_nodes["actions"]:
if isinstance(action, dict) and "name" in action:
actions.add(action["name"])
# 提取条件节点
if "conditions" in allowed_nodes and isinstance(allowed_nodes["conditions"], list):
for condition in allowed_nodes["conditions"]:
if isinstance(condition, dict) and "name" in condition:
conditions.add(condition["name"])
if not actions:
logging.warning("关键错误:从提示词解析出的行动节点列表为空。")
logging.info(f"成功解析出动作节点: {sorted(actions)}")
logging.info(f"成功解析出条件节点: {sorted(conditions)}")
return actions, conditions
except json.JSONDecodeError as e:
logging.error(f"解析节点定义JSON时失败: {e}")
return set(), set()
except Exception as e:
logging.error(f"解析可用节点时发生未知错误: {e}")
return set(), set()
def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
备用解析方案:当精确匹配失败时使用。
"""
logging.warning("使用备用方案解析节点定义...")
# 查找所有JSON代码块
matches = re.findall(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
if not matches:
logging.error("在系统提示词中未找到任何JSON代码块。")
return set(), set()
# 尝试从每个JSON块中解析节点定义
for i, json_str in enumerate(matches):
try:
data = json.loads(json_str)
# 检查是否是节点定义的结构包含actions、conditions、control_flow
if ("actions" in data and isinstance(data["actions"], list) and
"conditions" in data and isinstance(data["conditions"], list) and
"control_flow" in data and isinstance(data["control_flow"], list)):
actions = set()
conditions = set()
# 提取动作节点
for action in data["actions"]:
if isinstance(action, dict) and "name" in action:
actions.add(action["name"])
# 提取条件节点
for condition in data["conditions"]:
if isinstance(condition, dict) and "name" in condition:
conditions.add(condition["name"])
if actions:
logging.info(f"从第{i+1}个JSON块中成功解析出节点定义")
logging.info(f"动作节点: {sorted(actions)}")
logging.info(f"条件节点: {sorted(conditions)}")
return actions, conditions
except json.JSONDecodeError:
continue # 尝试下一个JSON块
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
return set(), set()
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
"""递归查找所有指定名称的节点"""
nodes_found = []
if node.get("name") == target_name:
nodes_found.append(node)
# 递归搜索子节点
for child in node.get("children", []):
nodes_found.extend(_find_nodes_by_name(child, target_name))
return nodes_found
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
    """
    Check that the tree under ``root`` contains some form of safety monitoring.

    The check passes when at least one ``battery_above`` node or one
    ``SafetyMonitor`` node exists anywhere in the tree. A ``battery_above``
    threshold below 0.25 or above 0.5 only logs a warning; it does not fail the
    check.

    Returns:
        False if neither kind of node is present, True otherwise.
    """
    tree_root = pytree_instance.get("root", {})
    battery_checks = _find_nodes_by_name(tree_root, "battery_above")
    monitors = _find_nodes_by_name(tree_root, "SafetyMonitor")

    if not (battery_checks or monitors):
        logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
        return False

    # Sanity-check configured battery thresholds (advisory only).
    for check in battery_checks:
        level = check.get("params", {}).get("threshold")
        if level is None:
            continue
        if level < 0.25:
            logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({level})建议不低于0.25")
        elif level > 0.5:
            logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({level}),可能影响任务执行")

    logging.info("✅ 安全监控验证通过")
    return True
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
"""
# 所有可能的节点类型
node_types = ["action", "condition", "Sequence", "Selector", "Parallel"]
# 目标检测相关的类别枚举
target_classes = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic_light", "fire_hydrant", "stop_sign", "parking_meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball",
"kite", "baseball_bat", "baseball_glove", "skateboard", "surfboard", "tennis_racket",
"bottle", "wine_glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon"
]
# 递归节点定义
node_definition = {
"type": "object",
"properties": {
"type": {"type": "string", "enum": node_types},
"name": {"type": "string"},
"params": {"type": "object"},
"children": {
"type": "array",
"items": {"$ref": "#/definitions/node"}
}
},
"required": ["type", "name"],
"allOf": [
# 动作节点验证
{
"if": {"properties": {"type": {"const": "action"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
},
# 条件节点验证
{
"if": {"properties": {"type": {"const": "condition"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
},
# 目标检测动作节点的参数验证
{
"if": {
"properties": {
"type": {"const": "action"},
"name": {"const": "object_detect"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"target_class": {"type": "string", "enum": target_classes},
"description": {"type": "string"},
"count": {"type": "integer", "minimum": 1}
},
"required": ["target_class"],
"additionalProperties": False
}
}
}
},
# 目标检测条件节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "object_detected"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"target_class": {"type": "string", "enum": target_classes},
"description": {"type": "string"},
"count": {"type": "integer", "minimum": 1}
},
"required": ["target_class"],
"additionalProperties": False
}
}
}
},
# 电池监控节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "battery_above"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
},
"required": ["threshold"],
"additionalProperties": False
}
}
}
},
# GPS状态节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "gps_status"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
},
"required": ["min_satellites"],
"additionalProperties": False
}
}
}
}
]
}
# 完整的Schema结构
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Pytree",
"definitions": {
"node": node_definition
},
"type": "object",
"properties": {
"root": { "$ref": "#/definitions/node" }
},
"required": ["root"]
}
return schema
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
"""
生成简单模式JSON Schema{"mode":"simple","action":{...}}
仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。
"""
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "SimpleMode",
"type": "object",
"properties": {
"mode": {"type": "string", "const": "simple"},
"action": {
"type": "object",
"properties": {
"name": {"type": "string", "enum": sorted(list(allowed_actions))},
"params": {"type": "object"}
},
"required": ["name"],
"additionalProperties": True
}
},
"required": ["mode", "action"],
"additionalProperties": False
}
return schema
def _innermost_node_name(pytree_instance: Any, path: List[Any]) -> str:
    """
    Return the ``name`` of the deepest tree node along *path*, or "" if there is none.

    *path* is a sequence of dict keys / list indices, as found on a jsonschema
    error. The walk starts at *pytree_instance* and remembers the last dict that
    looks like a node, i.e. has a ``type`` key and a string ``name``.
    """
    current = pytree_instance
    found = ""
    for step in path:
        try:
            current = current[step]
        except (KeyError, IndexError, TypeError):
            break
        if isinstance(current, dict) and "type" in current and isinstance(current.get("name"), str):
            found = current["name"]
    return found


def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
    """
    Validate a Pytree against *schema*, then run the safety-monitoring check.

    Returns True only when the instance passes JSON Schema validation AND
    ``_validate_safety_monitoring`` accepts it. A schema failure is logged with
    its message, its path and, for known nodes, a targeted hint. Nothing is
    raised; any failure is reported as False.
    """
    try:
        jsonschema.validate(instance=pytree_instance, schema=schema)
        logging.info("✅ JSON Schema验证成功")
        # Safety check stays inside the try so its unexpected errors are also reported as False.
        return _validate_safety_monitoring(pytree_instance)
    except jsonschema.ValidationError as e:
        logging.warning("❌ Pytree验证失败")
        logging.warning(f"错误信息: {e.message}")
        error_path = list(e.path)
        logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
        # jsonschema messages describe the failing value (e.g. "'foo' is not one of [...]")
        # and rarely mention the node. Check the offending node's own name first,
        # then fall back to scanning the message text.
        node_name = _innermost_node_name(pytree_instance, error_path)
        for text in (node_name, str(e.message)):
            if "object_detect" in text:  # also matches "object_detected"
                logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
                break
            if "battery_above" in text:
                logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
                break
            if "gps_status" in text:
                logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
                break
        return False
    except Exception as e:
        logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
        return False
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
    """
    Render a Pytree node dict to a PNG file with Graphviz.

    Args:
        node: Root node of the tree to draw.
        file_path: Output path. A trailing ``.png`` is stripped before rendering
            because Graphviz appends the format extension itself; missing parent
            directories are created.

    A missing ``graphviz`` package or a rendering error is logged, never raised.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
        return

    # Choose a CJK-capable font per OS so Chinese labels render correctly.
    cjk_fonts = {"Windows": "Microsoft YaHei", "Darwin": "PingFang SC"}
    fontname = cjk_fonts.get(platform.system(), "Noto Sans CJK SC")

    graph = Digraph('Pytree', comment='Drone Mission Plan')
    graph.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
    graph.attr('node', shape='box', style='rounded,filled', fontname=fontname)
    graph.attr('edge', fontname=fontname)
    _add_nodes_and_edges(node, graph)

    try:
        stem, extension = os.path.splitext(file_path)
        target = stem if extension.lower() == '.png' else file_path
        parent_dir = os.path.dirname(target)
        if parent_dir and not os.path.exists(parent_dir):
            os.makedirs(parent_dir, exist_ok=True)
        # cleanup=True deletes the intermediate DOT source once the PNG is written.
        rendered = graph.render(target, format='png', cleanup=True, view=False)
        logging.info("✅ 任务树可视化成功")
        logging.info(f"图形已保存到: {rendered}")
    except Exception as e:
        logging.error("❌ 生成可视化图形失败")
        logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
        logging.error(f"错误详情: {e}")
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID加上随机数避免冲突
import random
import html
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
# 准备节点标签HTML-like正确换行与转义
name = html.escape(str(node.get('name', '')))
ntype = html.escape(str(node.get('type', '')))
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
# 格式化参数显示
params = node.get('params') or {}
if params:
params_lines = []
for key, value in params.items():
k = html.escape(str(key))
if isinstance(value, float):
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
else:
value_str = str(value)
v = html.escape(value_str)
params_lines.append(f"{k}: {v}")
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
node_label = f"<{'<BR/>'.join(label_parts)}>"
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
node_type = (node.get('type') or '').lower()
shape = 'ellipse'
style = 'filled'
fillcolor = '#e6e6e6' # 默认灰色填充
border_color = '#666666' # 默认描边色
if node_type == 'action':
shape = 'box'
style = 'rounded,filled'
fillcolor = "#cde4ff" # 浅蓝
elif node_type == 'condition':
shape = 'diamond'
style = 'filled'
fillcolor = "#fff2cc" # 浅黄
elif node_type == 'sequence':
shape = 'ellipse'
style = 'filled'
fillcolor = '#d5e8d4' # 绿色
elif node_type == 'selector':
shape = 'ellipse'
style = 'filled'
fillcolor = '#ffe6cc' # 橙色
elif node_type == 'parallel':
shape = 'ellipse'
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
children = node.get("children", [])
if not children:
return current_id
# 记录所有子节点的ID
child_ids = []
# 正确的递归连接:每个子节点都连接到当前节点
for child in children:
child_id = _add_nodes_and_edges(child, dot, current_id)
child_ids.append(child_id)
# 子节点同级排列(横向排布,更直观地表现同层)
if len(child_ids) > 1:
with dot.subgraph(name=f"rank_{current_id}") as s:
s.attr(rank='same')
for cid in child_ids:
s.node(cid)
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
    """
    Turn natural-language drone instructions into validated Pytree JSON.

    ``generate`` runs the pipeline:
    1. An LLM classifies the request as simple or complex.
    2. The request is augmented with context retrieved from the ``drone_docs``
       ChromaDB collection.
    3. The LLM for that mode is asked for JSON.
    4. The result is validated against schemas derived from the system prompt.
    5. A PNG visualization is rendered.
    """

    def __init__(self):
        """
        Load prompts, configure the three LLM clients, open the vector store and
        build the validation schemas.

        Configuration is read from environment variables:
        - ENABLE_REASONING_CAPTURE, REASONING_PREVIEW_LINES, ORIN_IP;
        - CLASSIFIER_/SIMPLE_/COMPLEX_MODEL (default OPENAI_MODEL, then "local-model");
        - CLASSIFIER_/SIMPLE_/COMPLEX_BASE_URL (default http://$ORIN_IP:8081/v1);
        - OPENAI_API_KEY.
        """
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.prompts_dir = os.path.join(self.base_dir, 'prompts')
        # PNG visualizations are written here
        self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
        os.makedirs(self.vis_dir, exist_ok=True)
        # Captured reasoning is saved here as Markdown
        self.reasoning_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_reasoning_content'))
        os.makedirs(self.reasoning_dir, exist_ok=True)
        # When true, JSON mode is not forced, so raw output containing <think> can be captured
        self.enable_reasoning_capture = os.getenv("ENABLE_REASONING_CAPTURE", "true").lower() in ("1", "true", "yes")
        # Maximum number of reasoning lines echoed to the log
        try:
            self.reasoning_preview_lines = int(os.getenv("REASONING_PREVIEW_LINES", "20"))
        except ValueError:
            self.reasoning_preview_lines = 20
        # Prompts: complex mode uses system_prompt.txt; simple mode and the classifier have their own
        self.complex_prompt = self._load_prompt("system_prompt.txt")
        self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
        self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
        # Backward-compatible alias
        self.system_prompt = self.complex_prompt
        self.orin_ip = os.getenv("ORIN_IP", "localhost")
        # Model name and base URL are configurable separately for each of the three roles
        self.classifier_model = os.getenv("CLASSIFIER_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
        self.simple_model = os.getenv("SIMPLE_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
        self.complex_model = os.getenv("COMPLEX_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
        self.classifier_base_url = os.getenv("CLASSIFIER_BASE_URL", f"http://{self.orin_ip}:8081/v1")
        self.simple_base_url = os.getenv("SIMPLE_BASE_URL", f"http://{self.orin_ip}:8081/v1")
        self.complex_base_url = os.getenv("COMPLEX_BASE_URL", f"http://{self.orin_ip}:8081/v1")
        self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
        # One client per role
        self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
        self.simple_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.simple_base_url)
        self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)
        # --- ChromaDB Client Setup ---
        vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
        self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
        # Queries must use the same remote embedding function as ingestion
        embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
        embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
        self.collection = self.chroma_client.get_collection(
            name="drone_docs",
            embedding_function=embedding_func
        )
        # Node names always come from the complex prompt, so both schemas use one stable source
        allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
        self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
        self.simple_schema = _generate_simple_mode_schema(allowed_actions)

    def _load_prompt(self, file_name: str) -> str:
        """Return the text of ``prompts/<file_name>``, or "" (logging an error) if the file is missing."""
        try:
            with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"提示词文件未找到 -> {file_name}")
            return ""

    def _retrieve_context(self, query: str) -> Optional[str]:
        """
        Return the top-5 documents from the vector store for *query*, joined by blank lines.

        Returns None when nothing is found or the query fails (the failure is logged).
        """
        logging.info("--- 开始从向量数据库检索上下文 ---")
        try:
            results = self.collection.query(query_texts=[query], n_results=5)
            # "documents" holds one list per query text; it may be None/empty or contain None entries.
            documents = results.get("documents") or [[]]
            retrieved_docs = [doc for doc in (documents[0] or []) if doc]
            if not retrieved_docs:
                logging.warning("在向量数据库中没有找到相关的上下文信息。")
                return None
            context_str = "\n\n".join(retrieved_docs)
            logging.info("--- 成功检索到上下文信息 ---")
            return context_str
        except Exception as e:
            logging.error(f"从向量数据库检索时发生错误: {e}")
            return None

    async def _classify(self, user_prompt: str) -> str:
        """Return "simple" or "complex" for *user_prompt*; any classifier failure yields "complex"."""
        import asyncio  # local import keeps this edit self-contained

        mode = "complex"
        try:
            # The OpenAI client is synchronous; run it off the event loop.
            classifier_resp = await asyncio.to_thread(
                self.classifier_client.chat.completions.create,
                model=self.classifier_model,
                messages=[
                    {"role": "system", "content": self.classifier_prompt or "你是一个分类器只输出JSON。"},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0,
                response_format={"type": "json_object"},
            )
            class_obj = json.loads(classifier_resp.choices[0].message.content)
            if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"):
                mode = class_obj.get("mode")
            logging.info(f"分类结果: {mode}")
        except Exception as e:
            logging.warning(f"分类失败,默认按复杂指令处理: {e}")
        return mode

    @staticmethod
    def _read_message(response) -> tuple[Any, Any]:
        """
        Return ``(content, reasoning_content)`` from the first choice of *response*.

        SDK objects are read by attribute access; if that raises, a plain-dict
        response shape is tried. A missing field yields None.
        """
        try:
            msg = response.choices[0].message
            return getattr(msg, "content", None), getattr(msg, "reasoning_content", None)
        except Exception:
            first = response.choices[0]
            msg = first["message"] if isinstance(first, dict) else None
            if not isinstance(msg, dict):
                return None, None
            return msg.get("content"), msg.get("reasoning_content")

    @staticmethod
    def _split_reasoning(text: Any) -> tuple[Optional[str], Any]:
        """
        Separate ``<think>...</think>`` reasoning from *text*.

        Returns ``(reasoning, remainder)``:
        - reasoning is the stripped body of the first ``<think>`` block, or None if
          there is no such block;
        - remainder is *text* with every ``<think>`` block removed and then stripped.

        Text without a ``<think>`` block, or non-str input, is returned unchanged.
        """
        if not isinstance(text, str):
            return None, text
        match = re.search(r"<think>([\s\S]*?)</think>", text)
        if not match:
            return None, text
        return match.group(1).strip(), re.sub(r"<think>[\s\S]*?</think>", "", text).strip()

    @staticmethod
    def _strip_code_fences(text: Any) -> Any:
        """
        Unwrap a Markdown code fence around a JSON payload, if there is one.

        When JSON mode is not forced (the default, with reasoning capture on),
        models often answer with a ```json fenced block, which ``json.loads``
        rejects. Text that already starts with '{' or '[' is only stripped, so a
        literal ``` inside a JSON string is never mistaken for a fence.
        Non-str input is returned unchanged.
        """
        if not isinstance(text, str):
            return text
        stripped = text.strip()
        if stripped.startswith(("{", "[")):
            return stripped
        fence = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", stripped, re.IGNORECASE)
        return fence.group(1) if fence else stripped

    @staticmethod
    def _wrap_simple_as_tree(simple_plan: dict) -> dict:
        """
        Wrap a simple-mode plan ``{"action": {...}}`` into a complex tree with safety monitoring.

        The result is a Parallel root holding two children:
        - a MainTask Sequence: the action, followed by ``land`` unless the action
          is itself ``land``;
        - a SafetyMonitor Selector: battery/GPS conditions plus an emergency
          return-and-land sequence.
        """
        simple_action_obj = simple_plan.get('action') or {}
        action_name = simple_action_obj.get('name')
        action_params = simple_action_obj.get('params') if isinstance(simple_action_obj.get('params'), dict) else {}
        safety_selector = {
            "type": "Selector",
            "name": "SafetyMonitor",
            "params": {"memory": True},
            "children": [
                {"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
                {"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
                {"type": "Sequence", "name": "EmergencyHandler", "children": [
                    {"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
                    {"type": "action", "name": "land", "params": {"mode": "home"}}
                ]}
            ]
        }
        main_children = [{"type": "action", "name": action_name, "params": action_params}]
        if action_name != "land":
            main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
        root_parallel = {
            "type": "Parallel",
            "name": "MissionWithSafety",
            "params": {"policy": "all_success"},
            "children": [
                {"type": "Sequence", "name": "MainTask", "children": main_children},
                safety_selector
            ]
        }
        return {"root": root_parallel}

    def _save_reasoning(self, reasoning_text: Optional[str]) -> None:
        """
        Write *reasoning_text* to ``reasoning_content.md`` and log a preview (best effort).

        When there is no reasoning, an informational message is logged instead.
        Write failures are logged as warnings, never raised.
        """
        try:
            if reasoning_text:
                reasoning_path = os.path.join(self.reasoning_dir, "reasoning_content.md")
                with open(reasoning_path, 'w', encoding='utf-8') as rf:
                    rf.write(reasoning_text)
                logging.info(f"📝 推理链已保存: {reasoning_path}")
                # Echo at most N lines to the log
                preview = "\n".join(reasoning_text.splitlines()[: self.reasoning_preview_lines])
                logging.info("🧠 推理链预览(前%d行)\n%s", self.reasoning_preview_lines, preview)
            else:
                logging.info("未在模型输出中发现 <think> 推理链片段。若需捕获,请设置 ENABLE_REASONING_CAPTURE=true 以放宽JSON强制格式。")
        except Exception as e:
            logging.warning(f"保存推理链Markdown失败: {e}")

    async def generate(self, user_prompt: str) -> Dict[str, Any]:
        """
        Generate a validated Pytree structure for *user_prompt*.

        The request is classified first; classifier errors default to "complex".
        Retrieved context is appended, then up to three generation attempts are
        made. Output shapes:
        - simple mode returns ``{"mode": "simple", "action": {...}}``;
        - complex mode returns ``{"root": {...}}``. A simple-shaped answer in
          complex mode is wrapped into a tree with safety monitoring.

        Both shapes get ``plan_id`` and ``visualization_url`` added. Blocking
        LLM calls run in a worker thread so they do not stall the event loop.

        Raises:
            RuntimeError: if no attempt produced a valid result.
        """
        import asyncio  # local import keeps this edit self-contained

        logging.info(f"接收到用户请求: {user_prompt}")

        # Step 1: classify the request as simple or complex.
        mode = await self._classify(user_prompt)

        # Step 2: choose the system prompt. Complex mode appends a hard rule so
        # the model does not emit the simple single-action structure.
        use_prompt = self.simple_prompt if mode == "simple" else (
            (self.complex_prompt or "") +
            "\n\n【强制规则】仅生成包含root的复杂行为树JSON不得输出简单模式不得包含mode字段或仅有action节点"
        )
        # Both modes get retrieval-augmented context when it is available.
        final_user_prompt = user_prompt
        retrieved_context = self._retrieve_context(user_prompt)
        if retrieved_context:
            final_user_prompt += (
                "\n\n---\n"
                "参考知识:\n"
                "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n"
                f"{retrieved_context}"
                "\n---"
            )
        else:
            logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")

        # Simple and complex requests go to different models/clients.
        client = self.simple_llm_client if mode == "simple" else self.complex_llm_client
        model_name = self.simple_model if mode == "simple" else self.complex_model

        # Step 3: up to three generation attempts.
        for attempt in range(3):
            logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
            try:
                response_kwargs = {
                    "model": model_name,
                    "messages": [
                        {"role": "system", "content": use_prompt},
                        {"role": "user", "content": final_user_prompt}
                    ],
                    "temperature": 0.1 if mode == "complex" else 0.0,
                }
                # JSON mode is forced only when reasoning capture is off; with capture
                # on, the model may return raw text containing <think>.
                if not self.enable_reasoning_capture:
                    response_kwargs["response_format"] = {"type": "json_object"}
                # The OpenAI client is synchronous; run it off the event loop.
                response = await asyncio.to_thread(client.chat.completions.create, **response_kwargs)

                msg_content, msg_reasoning = self._read_message(response)
                combined_text = ""
                if isinstance(msg_reasoning, str) and msg_reasoning.strip():
                    # Wrap a separate reasoning_content field in <think> so both sources parse the same way.
                    combined_text += f"<think>\n{msg_reasoning}\n</think>\n"
                if isinstance(msg_content, str) and msg_content.strip():
                    combined_text += msg_content
                raw_text = combined_text if combined_text else (msg_content or "")

                reasoning_text, pytree_str = self._split_reasoning(raw_text)
                try:
                    pytree_dict = json.loads(self._strip_code_fences(pytree_str))
                except json.JSONDecodeError:
                    logging.error(f"❌ JSON解析失败{attempt + 1}/3 次)。原始响应如下:\n{pytree_str}")
                    continue

                if mode == "simple":
                    try:
                        jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
                        logging.info("✅ 简单模式JSON Schema验证成功")
                    except jsonschema.ValidationError as e:
                        logging.warning(f"❌ 简单模式验证失败: {e.message}")
                        continue
                    pytree_dict['plan_id'] = str(uuid.uuid4())
                    # Visualize the single action as a one-node tree.
                    try:
                        vis_filename = "py_tree.png"
                        vis_path = os.path.join(self.vis_dir, vis_filename)
                        simple_node = {
                            "type": "action",
                            "name": pytree_dict.get('action', {}).get('name', 'action'),
                            "params": pytree_dict.get('action', {}).get('params', {})
                        }
                        _visualize_pytree(simple_node, os.path.splitext(vis_path)[0])
                        pytree_dict['visualization_url'] = f"/static/{vis_filename}"
                    except Exception as e:
                        logging.warning(f"简单模式可视化失败: {e}")
                    self._save_reasoning(reasoning_text)
                    return pytree_dict

                # Complex mode: if the model returned a simple-mode object instead,
                # wrap it into a full tree with safety monitoring.
                if isinstance(pytree_dict, dict) and 'root' not in pytree_dict:
                    try:
                        jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
                    except jsonschema.ValidationError:
                        pass  # not a simple-mode object either; full validation below reports it
                    else:
                        logging.warning("⚠️ 复杂模式生成了简单结构,触发自动包装为完整行为树的回退逻辑。")
                        pytree_dict = self._wrap_simple_as_tree(pytree_dict)

                if not _validate_pytree_with_schema(pytree_dict, self.schema):
                    # Log the rejected tree to help debugging, then retry.
                    preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
                    logging.warning(f"❌ 未通过验证的Pytree{attempt + 1}/3 次尝试):\n{preview}")
                    logging.warning("生成的Pytree验证失败正在重试...")
                    continue

                logging.info("✅ 成功生成并验证了Pytree")
                pytree_dict['plan_id'] = str(uuid.uuid4())
                # Render to a fixed static path served under /static/
                vis_filename = "py_tree.png"
                vis_path = os.path.join(self.vis_dir, vis_filename)
                _visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
                pytree_dict['visualization_url'] = f"/static/{vis_filename}"
                self._save_reasoning(reasoning_text)
                return pytree_dict
            except OpenAIError as e:
                logging.error(f"生成Pytree时发生错误: {e}")
        raise RuntimeError("在3次尝试后仍未能生成一个有效的Pytree。")
# Module-level singleton shared by the application. Constructing it at import
# time reads the prompt files, creates the output directories, opens the
# ChromaDB store and builds the OpenAI clients.
py_tree_generator: PyTreeGenerator = PyTreeGenerator()