优化交互式测试验证脚本,针对场景4修改提示词以及代码
This commit is contained in:
@@ -52,7 +52,7 @@ def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[st
|
||||
"""
|
||||
try:
|
||||
# 使用更精确的正则表达式匹配节点定义部分
|
||||
node_section_pattern = r"#### 2\. 可用节点定义.*?```json\s*({.*?})\s*```"
|
||||
node_section_pattern = r"#### 1\. 可用节点定义.*?```json\s*({.*?})\s*```"
|
||||
match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
if not match:
|
||||
@@ -144,51 +144,12 @@ def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
|
||||
return set(), set()
|
||||
|
||||
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
|
||||
"""递归查找所有指定名称的节点"""
|
||||
nodes_found = []
|
||||
|
||||
if node.get("name") == target_name:
|
||||
nodes_found.append(node)
|
||||
|
||||
# 递归搜索子节点
|
||||
for child in node.get("children", []):
|
||||
nodes_found.extend(_find_nodes_by_name(child, target_name))
|
||||
|
||||
return nodes_found
|
||||
|
||||
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
|
||||
"""验证行为树是否包含必要的安全监控"""
|
||||
root_node = pytree_instance.get("root", {})
|
||||
|
||||
# 查找所有电池监控节点
|
||||
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
|
||||
|
||||
# 检查是否包含安全监控结构
|
||||
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
|
||||
|
||||
if not battery_nodes and not safety_monitors:
|
||||
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
|
||||
return False
|
||||
|
||||
# 检查电池阈值设置是否合理
|
||||
for battery_node in battery_nodes:
|
||||
threshold = battery_node.get("params", {}).get("threshold")
|
||||
if threshold is not None:
|
||||
if threshold < 0.25:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold}),建议不低于0.25")
|
||||
elif threshold > 0.5:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
|
||||
|
||||
logging.info("✅ 安全监控验证通过")
|
||||
return True
|
||||
|
||||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||||
"""
|
||||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||||
"""
|
||||
# 所有可能的节点类型
|
||||
node_types = ["action", "condition", "Sequence", "Selector", "Parallel"]
|
||||
node_types = ["action", "condition", "Sequence", "Selector", "Parallel", "decorator"]
|
||||
|
||||
# 目标检测相关的类别枚举
|
||||
target_classes = [
|
||||
@@ -201,38 +162,44 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
|
||||
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
|
||||
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon"
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon","trash","window"
|
||||
]
|
||||
|
||||
# 递归节点定义
|
||||
node_definition = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string", "enum": node_types},
|
||||
# 修改:手动构造不区分大小写的正则,避免使用不支持的 (?i) 标志
|
||||
# 匹配: action, condition, sequence, selector, parallel, decorator (忽略大小写)
|
||||
"type": {
|
||||
"type": "string",
|
||||
"pattern": "^([Aa][Cc][Tt][Ii][Oo][Nn]|[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]|[Ss][Ee][Qq][Uu][Ee][Nn][Cc][Ee]|[Ss][Ee][Ll][Ee][Cc][Tt][Oo][Rr]|[Pp][Aa][Rr][Aa][Ll][Ll][Ee][Ll]|[Dd][Ee][Cc][Oo][Rr][Aa][Tt][Oo][Rr])$"
|
||||
},
|
||||
"name": {"type": "string"},
|
||||
"params": {"type": "object"},
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/node"}
|
||||
}
|
||||
},
|
||||
"child": {"$ref": "#/definitions/node"}
|
||||
},
|
||||
"required": ["type", "name"],
|
||||
"allOf": [
|
||||
# 动作节点验证
|
||||
# 动作节点验证 (忽略大小写)
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "action"}}},
|
||||
"if": {"properties": {"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
|
||||
},
|
||||
# 条件节点验证
|
||||
# 条件节点验证 (忽略大小写)
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "condition"}}},
|
||||
"if": {"properties": {"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
|
||||
},
|
||||
# 目标检测动作节点的参数验证
|
||||
# 目标检测动作节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "action"},
|
||||
"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "object_detect"}
|
||||
}
|
||||
},
|
||||
@@ -251,11 +218,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# 目标检测条件节点的参数验证
|
||||
# 目标检测条件节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "object_detected"}
|
||||
}
|
||||
},
|
||||
@@ -274,11 +241,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# 电池监控节点的参数验证
|
||||
# 电池监控节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "battery_above"}
|
||||
}
|
||||
},
|
||||
@@ -295,11 +262,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# GPS状态节点的参数验证
|
||||
# GPS状态节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "gps_status"}
|
||||
}
|
||||
},
|
||||
@@ -372,10 +339,7 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
jsonschema.validate(instance=pytree_instance, schema=schema)
|
||||
logging.info("✅ JSON Schema验证成功")
|
||||
|
||||
# 额外验证安全监控
|
||||
safety_valid = _validate_safety_monitoring(pytree_instance)
|
||||
|
||||
return True and safety_valid
|
||||
return True
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning("❌ Pytree验证失败")
|
||||
logging.warning(f"错误信息: {e.message}")
|
||||
@@ -503,6 +467,10 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
elif node_type == 'decorator':
|
||||
shape = 'doubleoctagon'
|
||||
style = 'filled'
|
||||
fillcolor = '#f8cecc' # 浅红
|
||||
|
||||
# 特别标记安全相关节点
|
||||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||||
@@ -515,28 +483,28 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
# 递归处理子节点 (Sequence, Selector, Parallel 等)
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
|
||||
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
|
||||
if children:
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 递归处理单子节点 (Decorator)
|
||||
child = node.get("child")
|
||||
if child:
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
return current_id
|
||||
|
||||
@@ -588,7 +556,7 @@ class PyTreeGenerator:
|
||||
self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'rag','vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
|
||||
# Explicitly use the remote embedding function for queries
|
||||
@@ -725,7 +693,7 @@ class PyTreeGenerator:
|
||||
pytree_str = combined_text if combined_text else (msg_content or "")
|
||||
raw_full_text_for_logging = pytree_str # 保存完整原文(含 <think>)以便失败时完整打印
|
||||
|
||||
# 提取 <think> 推理链内容(若存在)
|
||||
# 提取 <think> 推理链内容(若有)
|
||||
reasoning_text = None
|
||||
try:
|
||||
think_match = re.search(r"<think>([\s\S]*?)</think>", pytree_str)
|
||||
@@ -827,49 +795,7 @@ class PyTreeGenerator:
|
||||
pytree_dict['final_prompt'] = final_prompt
|
||||
return pytree_dict
|
||||
|
||||
# 复杂模式回退:若模型误返回简单结构(root是单个action),则自动包装为含安全监控的行为树
|
||||
if mode == "complex" and isinstance(pytree_dict, dict) and 'root' in pytree_dict:
|
||||
root_node = pytree_dict.get('root', {})
|
||||
# 检查是否是简单结构(root是单个action节点,没有children)
|
||||
if (root_node.get('type') == 'action' and
|
||||
('children' not in root_node or not root_node.get('children'))):
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||
logging.warning("⚠️ 复杂模式生成了简单结构(单个action),触发自动包装为完整行为树的回退逻辑。")
|
||||
action_name = root_node.get('name')
|
||||
action_params = root_node.get('params') if isinstance(root_node.get('params'), dict) else {}
|
||||
|
||||
safety_selector = {
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitor",
|
||||
"params": {"memory": True},
|
||||
"children": [
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
|
||||
{"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
|
||||
{"type": "Sequence", "name": "EmergencyHandler", "children": [
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]}
|
||||
]
|
||||
}
|
||||
|
||||
main_children = [{"type": "action", "name": action_name, "params": action_params}]
|
||||
if action_name != "land":
|
||||
main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
|
||||
|
||||
root_parallel = {
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{"type": "Sequence", "name": "MainTask", "children": main_children},
|
||||
safety_selector
|
||||
]
|
||||
}
|
||||
pytree_dict = {"root": root_parallel}
|
||||
except jsonschema.ValidationError:
|
||||
# 不符合简单结构,按正常复杂验证继续
|
||||
pass
|
||||
# 验证生成的复杂行为树
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("✅ 成功生成并验证了Pytree")
|
||||
plan_id = str(uuid.uuid4())
|
||||
|
||||
Reference in New Issue
Block a user