修正了行为树可视化的逻辑,优化了系统提示此
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
你是一个无人机任务规划专家。你的唯一任务是根据用户提供的任务指令和参考知识,生成一个结构化、可执行的行为树(Pytree)JSON描述。
|
||||
您是一个无人机任务规划专家。您的唯一任务是根据用户提供的任务指令和参考知识,生成一个结构化、可执行的行为树(Pytree)JSON描述。
|
||||
|
||||
你的输出必须是一个严格的、单一的JSON对象,不包含任何形式的解释、总结或自然语言描述。
|
||||
您的输出必须是一个严格的、单一的JSON对象,不包含任何形式的解释、总结或自然语言描述。
|
||||
|
||||
**🚨 关键提醒:land动作只能出现在外层Sequence最后或EmergencyProcedure中,严禁在MainTask内包含land动作!**
|
||||
|
||||
---
|
||||
#### 1. 物理约束与安全原则 (必须遵守)
|
||||
在规划任何任务前,你必须遵守以下物理现实性和安全约束:
|
||||
在规划任何任务前,您必须遵守以下物理现实性和安全约束:
|
||||
|
||||
绝对禁令:
|
||||
- 续航限制:单次任务总时间不得超过2700秒(45分钟)
|
||||
- 高度限制:飞行高度必须在5-5000米范围内
|
||||
- 电池安全:必须包含电池监控,电量低于30%触发返航,低于20%触发紧急降落
|
||||
- 坐标有效:纬度[-90,90],经度[-180,180]
|
||||
- 坐标有效:x,y,z坐标必须在合理范围内(x,y: ±10000米,z: 5-5000米)
|
||||
- 参数合理:速度、加速度等参数必须在无人机性能范围内
|
||||
|
||||
---
|
||||
#### 2. 可用节点定义 (必须遵守)
|
||||
你必须严格从以下JSON定义的列表中选择节点来构建行为树。不允许幻想或使用任何未定义的节点。
|
||||
您必须严格从以下JSON定义的列表中选择节点来构建行为树。不允许幻想或使用任何未定义的节点。
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -63,11 +65,12 @@
|
||||
},
|
||||
{
|
||||
"name": "search_pattern",
|
||||
"description": "在指定区域执行搜索模式。",
|
||||
"description": "在指定区域执行搜索模式。使用相对坐标系(x,y,z),单位为米。",
|
||||
"params": {
|
||||
"pattern_type": "string, 搜索模式类型: 'spiral'(螺旋), 'grid'(栅格)",
|
||||
"center_lat": "float, 搜索中心纬度",
|
||||
"center_lon": "float, 搜索中心经度",
|
||||
"center_x": "float, 搜索中心X坐标(米)",
|
||||
"center_y": "float, 搜索中心Y坐标(米)",
|
||||
"center_z": "float, 搜索中心Z坐标(米)",
|
||||
"radius": "float, 搜索半径(米)[10,1000]",
|
||||
"target_object": "string, 可选,要搜索的目标类型"
|
||||
}
|
||||
@@ -178,10 +181,79 @@
|
||||
```
|
||||
|
||||
---
|
||||
#### 4. 标准任务范式 (必须参考)
|
||||
你必须根据任务类型参考以下标准范式模板:
|
||||
#### 4. 并行执行设计规范 (必须遵守)
|
||||
|
||||
**通用任务范式:**
|
||||
**重要:Parallel节点的正确使用方法**
|
||||
|
||||
1. **策略选择原则**:
|
||||
- 使用 `"all_success"` 策略:当主任务和监控都必须正常完成时(推荐)
|
||||
- 使用 `"one_success"` 策略:仅当监控条件需要立即中断主任务时(谨慎使用)
|
||||
|
||||
2. **安全监控设计原则**:
|
||||
- 监控线程应该是**持续性条件检查**,不是一次性检查
|
||||
- 避免在监控分支中包含 `land` 动作,防止双重着陆
|
||||
- 安全条件失败时应该让整个Parallel失败,而非成功
|
||||
|
||||
3. **推荐的安全监控模式**:
|
||||
```json
|
||||
{
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "MainTask",
|
||||
"children": [
|
||||
// 主任务步骤(不包含land,在外层处理)
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "condition",
|
||||
"name": "battery_above",
|
||||
"params": {"threshold": 25.0}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
4. **紧急处理模式(仅在必要时使用)**:
|
||||
```json
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "MissionOrEmergency",
|
||||
"children": [
|
||||
{
|
||||
"type": "Parallel",
|
||||
"name": "NormalMission",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{"type": "Sequence", "name": "MainTask", "children": [...]},
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 25.0}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "EmergencyProcedure",
|
||||
"children": [
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "low_battery"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**严格禁止的模式**:
|
||||
- 禁止在Parallel的不同分支中都包含 `land` 动作
|
||||
- 禁止使用一次性条件检查作为持续监控
|
||||
- 禁止让安全条件成功时结束整个Parallel任务
|
||||
- **关键禁令:严禁在MainTask序列中包含land动作,所有着陆必须在外层统一处理**
|
||||
|
||||
---
|
||||
#### 5. 标准任务范式 (必须参考)
|
||||
|
||||
**通用任务范式(推荐模式)**:
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
@@ -191,21 +263,31 @@
|
||||
{"type": "action", "name": "preflight_checks", "params": {"check_level": "comprehensive"}},
|
||||
{"type": "action", "name": "takeoff", "params": {"altitude": 50.0}},
|
||||
{
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"type": "Selector",
|
||||
"name": "MissionOrEmergency",
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "MainTask",
|
||||
"children": []
|
||||
"type": "Parallel",
|
||||
"name": "NormalMission",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "MainTask",
|
||||
"children": [
|
||||
// 具体任务内容(严禁包含land动作)
|
||||
// land动作必须在外层Sequence统一处理
|
||||
]
|
||||
},
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 25.0}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitor",
|
||||
"type": "Sequence",
|
||||
"name": "EmergencyProcedure",
|
||||
"children": [
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 25.0}},
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "low_battery"}}
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "low_battery"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -216,115 +298,63 @@
|
||||
}
|
||||
```
|
||||
|
||||
**搜索救援范式:**
|
||||
**简化任务范式(无需复杂监控时)**:
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "SimpleMission",
|
||||
"children": [
|
||||
{"type": "action", "name": "preflight_checks", "params": {"check_level": "basic"}},
|
||||
{"type": "action", "name": "takeoff", "params": {"altitude": 30.0}},
|
||||
// 具体任务内容(不包含land)
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}} // land统一在最后
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**搜索救援范式(修正版)**:
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "SearchRescue",
|
||||
"children": [
|
||||
{"type": "action", "name": "preflight_checks", "params": {}},
|
||||
{"type": "action", "name": "preflight_checks", "params": {"check_level": "comprehensive"}},
|
||||
{"type": "action", "name": "takeoff", "params": {"altitude": 100.0}},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "SearchUntilFound",
|
||||
"name": "SearchOrEmergency",
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "TargetDetected",
|
||||
"type": "Parallel",
|
||||
"name": "SearchWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{"type": "condition", "name": "object_detected", "params": {"target_class": "person", "description": "穿红色衣服", "count": 1}},
|
||||
{"type": "action", "name": "loiter", "params": {"duration": 30.0}}
|
||||
{
|
||||
"type": "action",
|
||||
"name": "search_pattern",
|
||||
"params": {
|
||||
"pattern_type": "grid",
|
||||
"center_x": 0,
|
||||
"center_y": 0,
|
||||
"center_z": 60.0,
|
||||
"radius": 300.0,
|
||||
"target_object": "person"
|
||||
}
|
||||
}
|
||||
// 注意:搜索任务完成后不在此处添加land,由外层统一处理
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 25.0}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "search_pattern",
|
||||
"params": {
|
||||
"pattern_type": "grid",
|
||||
"center_lat": 31.2304,
|
||||
"center_lon": 121.4737,
|
||||
"radius": 300.0,
|
||||
"target_object": "person"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**物资投送范式:**
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "DeliveryMission",
|
||||
"children": [
|
||||
{"type": "action", "name": "preflight_checks", "params": {}},
|
||||
{"type": "action", "name": "takeoff", "params": {"altitude": 80.0}},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "fly_to_waypoint",
|
||||
"params": {
|
||||
"latitude": 31.2304,
|
||||
"longitude": 121.4737,
|
||||
"altitude": 100.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "DeliveryProcedure",
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "StandardDelivery",
|
||||
"name": "EmergencyProcedure",
|
||||
"children": [
|
||||
{"type": "condition", "name": "at_waypoint", "params": {"latitude": 31.2304, "longitude": 121.4737}},
|
||||
{"type": "action", "name": "deliver_payload", "params": {"payload_type": "medical"}}
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "low_battery"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "find_alternative_site",
|
||||
"params": {"search_radius": 50.0}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"type": "action", "name": "return_to_launch", "params": {}}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**区域巡查范式:**
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "AreaPatrol",
|
||||
"children": [
|
||||
{"type": "action", "name": "preflight_checks", "params": {}},
|
||||
{"type": "action", "name": "takeoff", "params": {"altitude": 120.0}},
|
||||
{
|
||||
"type": "Parallel",
|
||||
"name": "PatrolOperation",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "RouteExecution",
|
||||
"children": [
|
||||
{"type": "action", "name": "fly_to_waypoint", "params": {"latitude": 31.2304, "longitude": 121.4737, "altitude": 120.0}},
|
||||
{"type": "action", "name": "fly_to_waypoint", "params": {"latitude": 31.2315, "longitude": 121.4758, "altitude": 120.0}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "object_detect",
|
||||
"params": {"target_class": "car", "description": "白色车辆", "count": 3}
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -335,11 +365,52 @@
|
||||
```
|
||||
|
||||
---
|
||||
#### 5. 如何使用参考知识 (必须遵守)
|
||||
当系统提供"参考知识"时,你必须使用其中的坐标和其他信息来填充`params`字段。所有参数值必须符合物理约束范围。
|
||||
#### 6. 如何使用参考知识 (必须遵守)
|
||||
当系统提供"参考知识"时,您必须使用其中的坐标和其他信息来填充`params`字段。所有参数值必须符合物理约束范围。
|
||||
|
||||
参考知识中的坐标信息将使用相对坐标系(x,y,z)表示,例如:
|
||||
"目标区域中心坐标: (x: 120.5, y: 80.2, z: 60.0)"
|
||||
|
||||
---
|
||||
#### 6. 输出要求
|
||||
你必须生成符合JSON Schema的严格JSON格式,且必须包含适当的安全监控和异常处理逻辑。
|
||||
#### 7. 行为树设计最佳实践 (必须遵守)
|
||||
|
||||
你的输出只能是单一的JSON对象,不包含任何其他内容。
|
||||
**架构设计原则**:
|
||||
1. **单一责任**:每个节点只负责一个明确的功能
|
||||
2. **避免重复**:不要在不同分支中重复相同的关键动作(如land)
|
||||
3. **清晰层次**:使用明确的命名和合理的嵌套深度
|
||||
4. **安全优先**:始终考虑异常情况和安全退出机制
|
||||
5. **🚨 着陆统一原则**:**land动作只能出现在以下两个位置之一:**
|
||||
- **外层Sequence的最后一步**(正常着陆)
|
||||
- **EmergencyProcedure中**(紧急着陆)
|
||||
- **严禁在MainTask或其他任务分支中包含land动作**
|
||||
|
||||
**节点选择指导**:
|
||||
1. **Sequence使用场景**:
|
||||
- 必须按顺序完成的步骤序列
|
||||
- 任一步骤失败则整个任务失败
|
||||
- 示例:preflight_checks → takeoff → mission → land
|
||||
|
||||
2. **Selector使用场景**:
|
||||
- 有多种达成目标的方法
|
||||
- 提供备选方案或容错机制
|
||||
- 示例:正常任务 OR 紧急程序
|
||||
|
||||
3. **Parallel使用场景**:
|
||||
- 需要同时执行的独立任务
|
||||
- 主任务与持续监控的结合
|
||||
- 谨慎使用,避免资源冲突
|
||||
|
||||
**参数设置指导**:
|
||||
1. **高度参数**:根据任务类型合理设置
|
||||
- 搜索任务:50-100米
|
||||
- 运输任务:30-80米
|
||||
- 侦察任务:80-150米
|
||||
|
||||
2. **安全阈值**:
|
||||
- 电池监控:不低于25%
|
||||
- 接受半径:2-5米
|
||||
- 搜索半径:根据区域大小调整
|
||||
|
||||
3. **坐标参数**:
|
||||
- 必须使用参考知识中的实际坐标
|
||||
- 检查坐标的合理性和可达性
|
||||
|
||||
@@ -4,20 +4,18 @@ import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Set
|
||||
|
||||
import chromadb
|
||||
import openai
|
||||
from openai import OpenAIError
|
||||
import jsonschema
|
||||
import requests
|
||||
import platform # 新增:用于选择合适的中文字体
|
||||
|
||||
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
def __init__(self, api_url: str):
|
||||
self._api_url = api_url
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
|
||||
return []
|
||||
@@ -48,7 +46,6 @@ logging.basicConfig(
|
||||
# ==============================================================================
|
||||
# VALIDATION LOGIC (from utils/validation.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
"""
|
||||
从系统提示词中精确解析出允许的行动和条件节点。
|
||||
@@ -279,7 +276,6 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
# ==============================================================================
|
||||
# VISUALIZATION LOGIC (from utils/visualization.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
@@ -290,15 +286,36 @@ def _visualize_pytree(node: Dict, file_path: str):
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
|
||||
# 选择合适的中文字体,避免中文乱码
|
||||
def _pick_zh_font():
|
||||
sys = platform.system()
|
||||
if sys == "Windows":
|
||||
return "Microsoft YaHei"
|
||||
elif sys == "Darwin":
|
||||
return "PingFang SC"
|
||||
else:
|
||||
return "Noto Sans CJK SC"
|
||||
|
||||
fontname = _pick_zh_font()
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
|
||||
dot.attr('edge', fontname=fontname)
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 确保输出目录存在,并避免生成 .png.png
|
||||
base_path, ext = os.path.splitext(file_path)
|
||||
render_path = base_path if ext.lower() == '.png' else file_path
|
||||
|
||||
out_dir = os.path.dirname(render_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
output_path = dot.render(file_path, format='png', cleanup=True, view=False)
|
||||
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
|
||||
logging.info("--- 任务树可视化成功 ---")
|
||||
logging.info(f"图形已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
@@ -309,44 +326,96 @@ def _visualize_pytree(node: Dict, file_path: str):
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
|
||||
# 为每个节点创建一个唯一的ID
|
||||
current_id = str(id(node))
|
||||
# 为每个节点创建一个唯一的ID(加上随机数避免冲突)
|
||||
import random
|
||||
import html
|
||||
|
||||
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
|
||||
|
||||
# 准备节点标签(HTML-like,正确换行与转义)
|
||||
name = html.escape(str(node.get('name', '')))
|
||||
ntype = html.escape(str(node.get('type', '')))
|
||||
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
|
||||
|
||||
# 格式化参数显示
|
||||
params = node.get('params') or {}
|
||||
if params:
|
||||
params_lines = []
|
||||
for key, value in params.items():
|
||||
k = html.escape(str(key))
|
||||
if isinstance(value, float):
|
||||
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
|
||||
else:
|
||||
value_str = str(value)
|
||||
v = html.escape(value_str)
|
||||
params_lines.append(f"{k}: {v}")
|
||||
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
|
||||
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
|
||||
|
||||
node_label = f"<{'<BR/>'.join(label_parts)}>"
|
||||
|
||||
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
|
||||
node_type = (node.get('type') or '').lower()
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e6e6e6' # 默认灰色填充
|
||||
border_color = '#666666' # 默认描边色
|
||||
|
||||
# 准备节点标签
|
||||
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
|
||||
if node.get('params'):
|
||||
params_str = json.dumps(node.get('params'))
|
||||
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
|
||||
node_label += ">"
|
||||
|
||||
# 根据类型设置节点样式
|
||||
node_type = node.get('type', '').lower()
|
||||
if node_type == 'action':
|
||||
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
|
||||
shape = 'box'
|
||||
style = 'rounded,filled'
|
||||
fillcolor = "#cde4ff" # 浅蓝
|
||||
elif node_type == 'condition':
|
||||
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
|
||||
else: # Sequence, Selector, etc.
|
||||
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
|
||||
|
||||
shape = 'diamond'
|
||||
style = 'filled'
|
||||
fillcolor = "#fff2cc" # 浅黄
|
||||
elif node_type == 'sequence':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#d5e8d4' # 绿色
|
||||
elif node_type == 'selector':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#ffe6cc' # 橙色
|
||||
elif node_type == 'parallel':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
|
||||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
last_child_id = current_id
|
||||
for child in node.get("children", []):
|
||||
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
|
||||
if node_type in ['sequence']:
|
||||
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
|
||||
else: # Selector, Parallel
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
|
||||
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
|
||||
|
||||
return current_id
|
||||
|
||||
# ==============================================================================
|
||||
# CORE PYTREE GENERATOR CLASS
|
||||
# ==============================================================================
|
||||
|
||||
class PyTreeGenerator:
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -355,7 +424,6 @@ class PyTreeGenerator:
|
||||
# Updated output directory for visualizations
|
||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||
os.makedirs(self.vis_dir, exist_ok=True)
|
||||
|
||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
||||
|
||||
self.orin_ip = os.getenv("ORIN_IP", "localhost")
|
||||
@@ -363,7 +431,7 @@ class PyTreeGenerator:
|
||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
||||
)
|
||||
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
@@ -371,12 +439,11 @@ class PyTreeGenerator:
|
||||
# Explicitly use the remote embedding function for queries
|
||||
embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
|
||||
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
|
||||
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name="drone_docs",
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||
|
||||
@@ -423,7 +490,6 @@ class PyTreeGenerator:
|
||||
final_user_prompt += augmentation
|
||||
else:
|
||||
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
|
||||
|
||||
for attempt in range(3):
|
||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||
try:
|
||||
@@ -438,7 +504,6 @@ class PyTreeGenerator:
|
||||
)
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("成功生成并验证了Pytree。")
|
||||
plan_id = str(uuid.uuid4())
|
||||
@@ -449,15 +514,13 @@ class PyTreeGenerator:
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
|
||||
return pytree_dict
|
||||
else:
|
||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||
|
||||
except (OpenAIError, json.JSONDecodeError) as e:
|
||||
logging.error(f"生成Pytree时发生错误: {e}")
|
||||
|
||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||
|
||||
# Create a single instance for the application
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
Reference in New Issue
Block a user