新增测试方法,优化系统提示词

This commit is contained in:
2025-09-08 09:21:00 +08:00
parent 9bbfd9186e
commit 161389fa1e
5 changed files with 925 additions and 190 deletions

View File

@@ -3,7 +3,7 @@ import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set
from typing import Dict, Any, Optional, Set, List
import chromadb
import openai
from openai import OpenAIError
@@ -144,6 +144,45 @@ def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
return set(), set()
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
"""递归查找所有指定名称的节点"""
nodes_found = []
if node.get("name") == target_name:
nodes_found.append(node)
# 递归搜索子节点
for child in node.get("children", []):
nodes_found.extend(_find_nodes_by_name(child, target_name))
return nodes_found
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
"""验证行为树是否包含必要的安全监控"""
root_node = pytree_instance.get("root", {})
# 查找所有电池监控节点
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
# 检查是否包含安全监控结构
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
if not battery_nodes and not safety_monitors:
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
return False
# 检查电池阈值设置是否合理
for battery_node in battery_nodes:
threshold = battery_node.get("params", {}).get("threshold")
if threshold is not None:
if threshold < 0.25:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold})建议不低于0.25")
elif threshold > 0.5:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
logging.info("✅ 安全监控验证通过")
return True
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
@@ -234,6 +273,48 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
}
}
}
},
# 电池监控节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "battery_above"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
},
"required": ["threshold"],
"additionalProperties": False
}
}
}
},
# GPS状态节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "gps_status"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
},
"required": ["min_satellites"],
"additionalProperties": False
}
}
}
}
]
}
@@ -260,14 +341,26 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
"""
try:
jsonschema.validate(instance=pytree_instance, schema=schema)
logging.info("验证成功Pytree格式和内容均符合规范。")
return True
logging.info("✅ JSON Schema验证成功")
# 额外验证安全监控
safety_valid = _validate_safety_monitoring(pytree_instance)
return True and safety_valid
except jsonschema.ValidationError as e:
logging.warning("--- Pytree验证失败 ---")
logging.warning(" Pytree验证失败")
logging.warning(f"错误信息: {e.message}")
error_path = list(e.path)
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
logging.warning(f"出错的实例部分: {e.instance}")
# 提供更具体的错误信息
if "object_detect" in str(e.message) or "object_detected" in str(e.message):
logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
elif "battery_above" in str(e.message):
logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
elif "gps_status" in str(e.message):
logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
return False
except Exception as e:
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
@@ -316,10 +409,10 @@ def _visualize_pytree(node: Dict, file_path: str):
# 保存为 .png 文件,并自动删除源码 .gv 文件
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
logging.info("--- 任务树可视化成功 ---")
logging.info(" 任务树可视化成功")
logging.info(f"图形已保存到: {output_path}")
except Exception as e:
logging.error("--- 错误:生成可视化图形失败 ---")
logging.error("生成可视化图形失败")
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
logging.error(f"错误详情: {e}")
@@ -382,6 +475,11 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
@@ -505,7 +603,7 @@ class PyTreeGenerator:
pytree_str = response.choices[0].message.content
pytree_dict = json.loads(pytree_str)
if _validate_pytree_with_schema(pytree_dict, self.schema):
logging.info("成功生成并验证了Pytree")
logging.info("成功生成并验证了Pytree")
plan_id = str(uuid.uuid4())
pytree_dict['plan_id'] = plan_id