新增测试方法,优化系统提示词
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from typing import Dict, Any, Optional, Set, List
|
||||
import chromadb
|
||||
import openai
|
||||
from openai import OpenAIError
|
||||
@@ -144,6 +144,45 @@ def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
|
||||
return set(), set()
|
||||
|
||||
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
|
||||
"""递归查找所有指定名称的节点"""
|
||||
nodes_found = []
|
||||
|
||||
if node.get("name") == target_name:
|
||||
nodes_found.append(node)
|
||||
|
||||
# 递归搜索子节点
|
||||
for child in node.get("children", []):
|
||||
nodes_found.extend(_find_nodes_by_name(child, target_name))
|
||||
|
||||
return nodes_found
|
||||
|
||||
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
|
||||
"""验证行为树是否包含必要的安全监控"""
|
||||
root_node = pytree_instance.get("root", {})
|
||||
|
||||
# 查找所有电池监控节点
|
||||
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
|
||||
|
||||
# 检查是否包含安全监控结构
|
||||
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
|
||||
|
||||
if not battery_nodes and not safety_monitors:
|
||||
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
|
||||
return False
|
||||
|
||||
# 检查电池阈值设置是否合理
|
||||
for battery_node in battery_nodes:
|
||||
threshold = battery_node.get("params", {}).get("threshold")
|
||||
if threshold is not None:
|
||||
if threshold < 0.25:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold}),建议不低于0.25")
|
||||
elif threshold > 0.5:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
|
||||
|
||||
logging.info("✅ 安全监控验证通过")
|
||||
return True
|
||||
|
||||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||||
"""
|
||||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||||
@@ -234,6 +273,48 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
# 电池监控节点的参数验证
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"name": {"const": "battery_above"}
|
||||
}
|
||||
},
|
||||
"then": {
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
|
||||
},
|
||||
"required": ["threshold"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
# GPS状态节点的参数验证
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"name": {"const": "gps_status"}
|
||||
}
|
||||
},
|
||||
"then": {
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
|
||||
},
|
||||
"required": ["min_satellites"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -260,14 +341,26 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_instance, schema=schema)
|
||||
logging.info("验证成功!Pytree格式和内容均符合规范。")
|
||||
return True
|
||||
logging.info("✅ JSON Schema验证成功")
|
||||
|
||||
# 额外验证安全监控
|
||||
safety_valid = _validate_safety_monitoring(pytree_instance)
|
||||
|
||||
return True and safety_valid
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning("--- Pytree验证失败 ---")
|
||||
logging.warning("❌ Pytree验证失败")
|
||||
logging.warning(f"错误信息: {e.message}")
|
||||
error_path = list(e.path)
|
||||
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
|
||||
logging.warning(f"出错的实例部分: {e.instance}")
|
||||
|
||||
# 提供更具体的错误信息
|
||||
if "object_detect" in str(e.message) or "object_detected" in str(e.message):
|
||||
logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
|
||||
elif "battery_above" in str(e.message):
|
||||
logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
|
||||
elif "gps_status" in str(e.message):
|
||||
logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
|
||||
@@ -316,10 +409,10 @@ def _visualize_pytree(node: Dict, file_path: str):
|
||||
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
|
||||
logging.info("--- 任务树可视化成功 ---")
|
||||
logging.info("✅ 任务树可视化成功")
|
||||
logging.info(f"图形已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
logging.error("--- 错误:生成可视化图形失败 ---")
|
||||
logging.error("❌ 生成可视化图形失败")
|
||||
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
|
||||
logging.error(f"错误详情: {e}")
|
||||
|
||||
@@ -382,6 +475,11 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
|
||||
# 特别标记安全相关节点
|
||||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||||
border_color = '#ff0000' # 红色边框突出显示安全节点
|
||||
style = 'filled,bold' # 加粗
|
||||
|
||||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||||
|
||||
# 连接父节点
|
||||
@@ -505,7 +603,7 @@ class PyTreeGenerator:
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("成功生成并验证了Pytree。")
|
||||
logging.info("✅ 成功生成并验证了Pytree")
|
||||
plan_id = str(uuid.uuid4())
|
||||
pytree_dict['plan_id'] = plan_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user