新增测试方法,优化系统提示词

This commit is contained in:
2025-09-08 09:21:00 +08:00
parent 9bbfd9186e
commit 161389fa1e
5 changed files with 925 additions and 190 deletions

View File

@@ -1,22 +1,16 @@
你是一个无人机任务规划专家。你的唯一任务是根据用户提供的任务指令和参考知识生成一个结构化、可执行的行为树PytreeJSON描述。
你的输出必须是一个严格的、单一的JSON对象不包含任何形式的解释、总结或自然语言描述。
---
#### 1. 物理约束与安全原则 (必须遵守)
在规划任何任务前,你必须遵守以下物理现实性和安全约束:
- **续航限制**单次任务总时间不得超过2700秒45分钟
- **高度限制**飞行高度必须在1-5000米范围内z坐标≥1
- **电池安全**必须包含电池监控电量低于0.3触发返航低于0.2触发紧急降落
- **坐标有效**x,y坐标必须在±10000米范围内z坐标必须在5-5000米范围内
- **参数合理**:速度、加速度等参数必须在无人机性能范围内(但本任务中速度参数未直接使用,故主要关注坐标和高度)
绝对禁令:
- 续航限制单次任务总时间不得超过2700秒45分钟
- 高度限制飞行高度必须在5-5000米范围内
- 电池安全必须包含电池监控电量低于0.3触发返航低于0.2触发紧急降落
- 坐标有效x,y,z坐标必须在合理范围内x,y: ±10000米z: 5-5000米
- 参数合理:速度、加速度等参数必须在无人机性能范围内
---
#### 2. 可用节点定义 (必须遵守)
必须严格从以下JSON定义的列表中选择节点来构建行为树。不允许幻想或使用任何未定义的节点。
必须严格从以下JSON定义的列表中选择节点来构建行为树。不允许使用任何未定义的节点。
```json
{
"actions": [
@@ -24,7 +18,7 @@
"name": "takeoff",
"description": "无人机从当前位置垂直起飞到指定的海拔高度。",
"params": {
"altitude": "float, 目标海拔高度(米),范围[5, 100]"
"altitude": "float, 目标海拔高度(米),范围[1, 100]默认为2"
}
},
{
@@ -54,7 +48,7 @@
},
{
"name": "object_detect",
"description": "使用机载传感器识别特定目标对象。",
"description": "在当前视野范围内识别特定目标对象。适用于定点检测,无人机应在目标大致位置悬停或保持稳定姿态。",
"params": {
"target_class": "string, 要识别的目标类别,必须为以下值之一: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush",
"description": "string, 可选,目标属性描述(如颜色、状态等)",
@@ -62,17 +56,34 @@
}
},
{
"name": "search_pattern",
"description": "在指定区域执行搜索模式。使用相对坐标系x,y,z单位为米。",
"name": "strike_target",
"description": "对已识别的目标进行打击。必须先使用object_detect成功识别目标后才能使用此动作。",
"params": {
"pattern_type": "string, 搜索模式类型: 'spiral'(螺旋), 'grid'(栅格)",
"target_class": "string, 要打击的目标类别",
"description": "string, 可选,目标属性描述(用于确认目标身份)",
"count": "int, 可选需要打击的目标个数默认1"
}
},
{
"name": "battle_damage_assessment",
"description": "对打击效果进行评估,确认目标是否被有效摧毁。",
"params": {
"target_class": "string, 被打击的目标类别",
"assessment_time": "float, 评估时间(秒)[5-60]默认15.0"
}
},
{
"name": "search_pattern",
"description": "通过执行一个系统性的移动搜索模式,在指定区域内寻找特定目标。无人机会持续移动并分析视频流,直到找到目标或完成整个搜索模式。适用于在未知区域发现目标。",
"params": {
"pattern_type": "string, 搜索模式类型: 'spiral'(螺旋搜索), 'grid'(栅格搜索)",
"center_x": "float, 搜索中心X坐标(米)",
"center_y": "float, 搜索中心Y坐标(米)",
"center_z": "float, 搜索中心Z坐标(米)",
"radius": "float, 搜索半径(米)[10,1000]",
"target_class": "string, 可选,要搜索的目标类别",
"description": "string, 可选,目标属性描述(如颜色、状态等)",
"count": "int, 可选,需要检测的目标个数默认1"
"radius": "float, 搜索半径(米)[5,1000]",
"target_class": "string, 要寻找的目标类别",
"description": "string, 可选,目标属性描述",
"count": "int, 可选,需要找到的目标个数默认1"
}
},
{
@@ -118,11 +129,20 @@
},
{
"name": "object_detected",
"description": "检查是否检测到特定目标对象。",
"description": "检查是否检测到特定目标对象。可用于验证 object_detect 或 search_pattern 的结果,也可作为打击的前提条件。",
"params": {
"target_class": "string, 目标类型,必须为以下值之一: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush",
"description": "string, 可选,目标属性描述(如颜色、状态等)",
"count": "int, 可选需要检测的目标个数默认1"
"target_class": "string, 目标类型",
"description": "string, 可选,目标属性描述",
"count": "int, 可选,需要检测的目标个数默认1"
}
},
{
"name": "target_destroyed",
"description": "检查目标是否已被成功摧毁。用于战损评估后的确认。",
"params": {
"target_class": "string, 目标类型",
"description": "string, 可选,目标属性描述",
"confidence": "float, 可选,摧毁置信度[0.5-1.0]默认0.8"
}
},
{
@@ -167,160 +187,166 @@
}
```
---
#### 3. JSON结构规范 (必须遵守)
生成的JSON对象必须有一个名为`root`的键,其值是一个有效的行为树节点对象。每个节点都必须包含正确的字段。
- **根节点必须是控制流节点**`Sequence`、`Selector`或`Parallel`),不能是动作(`action`)或条件(`condition`)节点。
- **动作节点和条件节点是叶子节点**,不能有`children`字段。
- 控制流节点必须有`children`字段,且其值是一个子节点数组。
- **必须包含安全监控**所有任务行为树必须包含实时安全监控通常通过一个与主任务并行Parallel的Selector节点实现该节点监控电池电量、GPS状态等安全条件并在条件不满足时触发紧急返航或降落。
- 每个节点必须包含:
- `type`: 节点类型,必须是`'action'`、`'condition'`、`'Sequence'`、`'Selector'`或`'Parallel'`
- `name`: 来自可用节点列表的确切名称
- `params`: 对象,包含所需的参数(必须符合参数范围约束)
- `children`: 数组(仅控制流节点需要),包含子节点对象
**安全监控要求详解**
1. **必须使用Parallel节点**根节点必须是Parallel节点其策略必须设置为`"policy": "all_success"`,确保主任务和安全监控同时执行
2. **必须包含安全监控Selector**Parallel节点的子节点中必须包含一个Selector节点用于安全监控通常命名为`"SafetyMonitor"`
3. **必须包含电池监控**安全监控Selector必须包含`battery_above`条件节点,监控电池电量
4. **必须包含GPS监控**安全监控Selector应该包含`gps_status`条件节点监控GPS信号状态
5. **必须包含紧急处理流程**安全监控Selector必须包含紧急处理Sequence在安全条件不满足时执行紧急返航和降落
**正确示例**
```json
{
"root": {
"type": "string, 节点类型: 'action' / 'condition' / 'Sequence' / 'Selector' / 'Parallel'",
"name": "string, 来自上方可用节点列表的确切名称",
"params": "object, 包含所需的参数(必须符合参数范围约束)",
"children": "array, (可选) 包含子节点对象"
}
}
```
---
#### 4. Selector节点memory参数使用规范 (必须遵守)
**重要memory参数的正确使用方法**
1. **默认行为**`"memory": true`(大多数情况下使用)
- 适用于:任务执行、监控检查、长时任务
- 优点:避免任务被不必要地中断
- 示例:主任务选择器、安全监控选择器
2. **特殊情况使用** `"memory": false`
- **仅适用于**需要每个tick都重新检查的高优先级安全条件
- **典型场景**:急停按钮、碰撞检测、最高优先级的安全中断
- **危险**:滥用会导致任务频繁中断
3. **决策流程**
- 如果Selector用于**选择和执行长时任务** → `"memory": true`
- 如果Selector用于**持续监控安全条件** → `"memory": true`
- 如果Selector用于**最高优先级的安全中断** → `"memory": false`
4. **正确示例**
```json
// 用例1任务执行选择器需要记忆
{
"type": "Selector",
"name": "MissionSelector",
"params": {
"memory": true
},
"children": [
{
"type": "Sequence",
"name": "MainMission",
"children": [
{"type": "action", "name": "fly_to_waypoint", "params": {"x": 100, "y": 50, "z": 30}},
{"type": "action", "name": "object_detect", "params": {"target_class": "person"}}
]
},
{
"type": "action",
"name": "emergency_return",
"params": {}
}
]
}
// 用例2安全监控选择器需要记忆
{
"type": "Selector",
"name": "SafetyMonitor",
"params": {
"memory": true
},
"children": [
{
"type": "condition",
"name": "battery_above",
"params": {
"threshold": 0.25
}
},
{
"type": "Sequence",
"name": "EmergencyLanding",
"children": [
{"type": "action", "name": "emergency_return", "params": {}},
{"type": "action", "name": "land", "params": {"mode": "home"}}
]
}
]
}
// 用例3最高优先级安全中断不需要记忆- 谨慎使用!
{
"type": "Selector",
"name": "EmergencyStop",
"params": {
"memory": false // 每个tick都检查急停
},
"children": [
{
"type": "condition",
"name": "emergency_stop_activated",
"params": {}
},
{
"type": "Sequence",
"name": "NormalOperation",
"children": [
// 正常任务内容
]
}
]
}
```
---
#### 5. 标准任务范式 (必须参考)
**通用任务范式**
```json
{
"root": {
"type": "Sequence",
"name": "StandardMission",
"children": [
{"type": "action", "name": "preflight_checks", "params": {"check_level": "comprehensive"}},
{"type": "action", "name": "takeoff", "params": {"altitude": 50.0}},
{
"type": "Parallel",
"name": "MissionWithSafety",
"params": {"policy": "one_success"},
"params": {"policy": "all_success"},
"children": [
{
"type": "Sequence",
"name": "MainTask",
"children": [
// 具体任务内容
// 主任务步骤
{"type": "action", "name": "land", "params": {"mode": "home"}}
]
},
{
"type": "Selector",
"name": "SafetyMonitor",
"params": {
"memory": true // 安全监控需要记忆
},
"params": {"memory": true},
"children": [
{
"type": "condition",
"name": "battery_above",
"params": {
"threshold": 0.25
}
"params": {"threshold": 0.3}
},
{
"type": "condition",
"name": "gps_status",
"params": {"min_satellites": 8}
},
{
"type": "Sequence",
"name": "EmergencyLanding",
"name": "EmergencyHandler",
"children": [
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
{"type": "action", "name": "land", "params": {"mode": "home"}}
]
}
]
}
]
}
}
```
**错误示例**(缺少安全监控):
```json
{
"root": {
"type": "Sequence", // 错误根节点不是Parallel无法同时运行安全监控
"name": "MainTaskOnly",
"children": [
// 只有主任务,没有安全监控
]
}
}
```
错误示例(根节点为动作节点):
```json
{
"root": {
"type": "action",
"name": "land",
"children": [ ... ], // 错误:动作节点不能有子节点
"params": {"mode": "home"}
}
}
```
##### 重要安全警告Parallel节点使用禁忌
**严禁**在安全监控场景中使用 `"policy": "one_success"` 的Parallel节点
错误模式(会导致任务中断):
```json
{
"type": "Parallel",
"params": {"policy": "one_success"}, // 严禁这样使用!
"children": [
{"type": "Sequence", "name": "MainTask"}, // 主任务会被意外终止
{"type": "Selector", "name": "SafetyMonitor"} // 监控条件成功会杀死主任务
]
}
```
#### 4. Selector节点memory参数使用规范
- **默认使用** `"memory": true`:用于任务执行和监控检查,避免不必要的任务中断。
- **仅在高优先级安全中断**时使用 `"memory": false`如急停按钮每个tick都检查。
- **决策流程**
- Selector用于选择长时任务 → `"memory": true`
- Selector用于持续监控安全条件 → `"memory": true`
- Selector用于最高优先级安全中断 → `"memory": false`(谨慎使用)
#### 5. 搜索与检测节点使用区分
- **object_detect**:用于已知位置的定点检测(无人机悬停或稳定时识别)。
- **search_pattern**:用于未知区域的移动搜索(无人机按模式飞行覆盖区域)。
- 严禁混淆使用例如在search_pattern后不应立即使用object_detect除非需要进一步验证。
#### 6. 参数约束检查 (必须遵守)
在生成JSON时你必须确保所有参数值符合物理约束
- `altitude` (takeoff): [1, 100]
- `z` (fly_to_waypoint): [1, 5000]
- `x`, `y` (fly_to_waypoint): [-10000, 10000]
- `radius` (search_pattern): [5, 1000]
- 电池阈值: [0.0, 1.0]
- 等等其他参数范围。
如果用户指令或参考知识提供坐标必须使用这些坐标但确保调整到约束范围内例如如果z<5则设置为5.0)。
#### 7. 标准任务范式
所有任务必须包含安全监控。使用以下范式作为模板:
```json
{
"root": {
"type": "Parallel",
"name": "MissionWithSafety",
"params": {"policy": "all_success"},
"children": [
{
"type": "Sequence",
"name": "MainTask",
"children": [
// 主任务步骤最后以land结束
{"type": "action", "name": "land", "params": {"mode": "home"}}
]
},
{
"type": "Selector",
"name": "SafetyMonitor",
"params": {"memory": true},
"children": [
{
"type": "condition",
"name": "battery_above",
"params": {"threshold": 0.3}
},
{
"type": "Sequence",
"name": "EmergencyHandler",
"children": [
{"type": "action", "name": "emergency_return", "params": {"reason": "low_battery"}},
{"type": "action", "name": "land", "params": {"mode": "home"}}
@@ -330,20 +356,229 @@
}
]
}
]
}
}
```
---
#### 6. 如何使用参考知识 (必须遵守)
当系统提供"参考知识"时,你必须使用其中的坐标和其他信息来填充`params`字段。所有参数值必须符合物理约束范围。
#### 8. 打击任务范式
所有任务必须包含安全监控。使用以下范式作为模板:
{
"root": {
"type": "Parallel",
"name": "CompleteStrikeMission",
"params": {
"policy": "all_success"
},
"children": [
{
"type": "Sequence",
"name": "MainStrikeSequence",
"children": [
{
"type": "action",
"name": "preflight_checks",
"params": {
"check_level": "comprehensive"
}
},
{
"type": "action",
"name": "takeoff",
"params": {
"altitude": 100.0
}
},
{
"type": "action",
"name": "fly_to_waypoint",
"params": {
"x": 200.0,
"y": 150.0,
"z": 120.0,
"acceptance_radius": 2.0
}
},
{
"type": "Selector",
"name": "TargetAcquisitionSelector",
"params": {
"memory": true
},
"children": [
{
"type": "Sequence",
"name": "DirectDetectionSequence",
"children": [
{
"type": "action",
"name": "loiter",
"params": {
"duration": 10.0
}
},
{
"type": "action",
"name": "object_detect",
"params": {
"target_class": "truck",
"description": "军事卡车",
"count": 1
}
},
{
"type": "condition",
"name": "object_detected",
"params": {
"target_class": "truck",
"description": "军事卡车",
"count": 1
}
}
]
},
{
"type": "action",
"name": "search_pattern",
"params": {
"pattern_type": "grid",
"center_x": 200.0,
"center_y": 150.0,
"center_z": 120.0,
"radius": 80.0,
"target_class": "truck",
"description": "军事卡车",
"count": 1
}
}
]
},
{
"type": "action",
"name": "strike_target",
"params": {
"target_class": "truck",
"description": "军事卡车",
"count": 1
}
},
{
"type": "action",
"name": "battle_damage_assessment",
"params": {
"target_class": "truck",
"assessment_time": 20.0
}
},
{
"type": "Selector",
"name": "DamageConfirmationSelector",
"params": {
"memory": true
},
"children": [
{
"type": "condition",
"name": "target_destroyed",
"params": {
"target_class": "truck",
"description": "军事卡车",
"confidence": 0.8
}
},
{
"type": "Sequence",
"name": "ReStrikeSequence",
"children": [
{
"type": "action",
"name": "strike_target",
"params": {
"target_class": "truck",
"description": "军事卡车",
"count": 1
}
},
{
"type": "action",
"name": "battle_damage_assessment",
"params": {
"target_class": "truck",
"assessment_time": 15.0
}
}
]
}
]
},
{
"type": "action",
"name": "fly_to_waypoint",
"params": {
"x": 0.0,
"y": 0.0,
"z": 100.0,
"acceptance_radius": 2.0
}
},
{
"type": "action",
"name": "land",
"params": {
"mode": "home"
}
}
]
},
{
"type": "Selector",
"name": "SafetyMonitorSelector",
"params": {
"memory": true
},
"children": [
{
"type": "condition",
"name": "battery_above",
"params": {
"threshold": 0.35
}
},
{
"type": "condition",
"name": "gps_status",
"params": {
"min_satellites": 8
}
},
{
"type": "Sequence",
"name": "EmergencyProcedureSequence",
"children": [
{
"type": "action",
"name": "emergency_return",
"params": {
"reason": "safety_breach"
}
},
{
"type": "action",
"name": "land",
"params": {
"mode": "home"
}
}
]
}
]
}
]
}
}
参考知识中的坐标信息将使用相对坐标系x,y,z表示例如
"目标区域中心坐标: (x: 120.5, y: 80.2, z: 60.0)"
#### 9. 如何使用参考知识
当用户提供"参考知识"(如坐标信息)时,你必须使用这些信息填充参数。例如:
- 如果参考知识说"目标坐标: (x: 120.5, y: 80.2, z: 60.0)",则在使用`fly_to_waypoint`时设置这些值。
- 确保坐标符合约束如z≥1
---
#### 7. 输出要求
你必须生成符合JSON Schema的严格JSON格式且必须包含适当的安全监控和异常处理逻辑。
你的输出只能是单一的JSON对象不包含任何其他内容。
#### 10. 输出要求
你的输出必须是严格的、单一的JSON对象符合上述所有规则。不包含任何自然语言描述。

View File

@@ -3,7 +3,7 @@ import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set
from typing import Dict, Any, Optional, Set, List
import chromadb
import openai
from openai import OpenAIError
@@ -144,6 +144,45 @@ def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
return set(), set()
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
"""递归查找所有指定名称的节点"""
nodes_found = []
if node.get("name") == target_name:
nodes_found.append(node)
# 递归搜索子节点
for child in node.get("children", []):
nodes_found.extend(_find_nodes_by_name(child, target_name))
return nodes_found
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
"""验证行为树是否包含必要的安全监控"""
root_node = pytree_instance.get("root", {})
# 查找所有电池监控节点
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
# 检查是否包含安全监控结构
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
if not battery_nodes and not safety_monitors:
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
return False
# 检查电池阈值设置是否合理
for battery_node in battery_nodes:
threshold = battery_node.get("params", {}).get("threshold")
if threshold is not None:
if threshold < 0.25:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold})建议不低于0.25")
elif threshold > 0.5:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
logging.info("✅ 安全监控验证通过")
return True
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
@@ -234,6 +273,48 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
}
}
}
},
# 电池监控节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "battery_above"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
},
"required": ["threshold"],
"additionalProperties": False
}
}
}
},
# GPS状态节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "gps_status"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
},
"required": ["min_satellites"],
"additionalProperties": False
}
}
}
}
]
}
@@ -260,14 +341,26 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
"""
try:
jsonschema.validate(instance=pytree_instance, schema=schema)
logging.info("验证成功Pytree格式和内容均符合规范。")
return True
logging.info("✅ JSON Schema验证成功")
# 额外验证安全监控
safety_valid = _validate_safety_monitoring(pytree_instance)
return True and safety_valid
except jsonschema.ValidationError as e:
logging.warning("--- Pytree验证失败 ---")
logging.warning(" Pytree验证失败")
logging.warning(f"错误信息: {e.message}")
error_path = list(e.path)
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
logging.warning(f"出错的实例部分: {e.instance}")
# 提供更具体的错误信息
if "object_detect" in str(e.message) or "object_detected" in str(e.message):
logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
elif "battery_above" in str(e.message):
logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
elif "gps_status" in str(e.message):
logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
return False
except Exception as e:
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
@@ -316,10 +409,10 @@ def _visualize_pytree(node: Dict, file_path: str):
# 保存为 .png 文件,并自动删除源码 .gv 文件
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
logging.info("--- 任务树可视化成功 ---")
logging.info(" 任务树可视化成功")
logging.info(f"图形已保存到: {output_path}")
except Exception as e:
logging.error("--- 错误:生成可视化图形失败 ---")
logging.error("生成可视化图形失败")
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
logging.error(f"错误详情: {e}")
@@ -382,6 +475,11 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
@@ -505,7 +603,7 @@ class PyTreeGenerator:
pytree_str = response.choices[0].message.content
pytree_dict = json.loads(pytree_str)
if _validate_pytree_with_schema(pytree_dict, self.schema):
logging.info("成功生成并验证了Pytree")
logging.info("成功生成并验证了Pytree")
plan_id = str(uuid.uuid4())
pytree_dict['plan_id'] = plan_id

View File

@@ -12,7 +12,7 @@ BASE_URL = "http://127.0.0.1:8000"
ENDPOINT = "/generate_plan"
# The user prompt we will send for the test
TEST_PROMPT = "起飞后移动到跷跷板上方查找(搜索/检测)行人"
TEST_PROMPT = "起飞后移动到学生宿舍上方搜索蓝色车辆,并进行打击"
def test_generate_plan():
"""

View File

@@ -0,0 +1,13 @@
起飞后移动到学生宿舍上方降落
起飞后移动到学生宿舍上方查找蓝色的车
起飞后移动到学生宿舍上方寻找蓝色的车
起飞后移动到学生宿舍上方检测蓝色的车
飞到学生宿舍上方查找蓝色的车
飞到学生宿舍上方查找蓝色车辆并进行打击
起飞后移动到学生宿舍上方搜索蓝色车辆,并进行打击
起飞到学生宿舍上方搜索被困人员,并为被困人员投递救援物资
飞到学生宿舍上方搜索方圆10米范围内的蓝色车辆
飞到学生宿舍上方搜索半径为10米区域范围内的蓝色车辆
起飞到学生宿舍搜索有没有被困人员,然后抛洒救援物资

View File

@@ -0,0 +1,389 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import requests
import json
import csv
import time
from datetime import datetime
import os
import re
# --- Configuration ---
BASE_URL = "http://127.0.0.1:8000"
ENDPOINT = "/generate_plan"
INSTRUCTIONS_FILE = "instructions.txt"
RESULTS_CSV = "test_results.csv"
SUMMARY_CSV = "test_summary.csv"
LOG_FILE = "api_test_log.txt"
# 测试参数
TESTS_PER_INSTRUCTION = 10
MAX_RETRIES = 3
RETRY_DELAY = 2
# 添加调试模式
DEBUG = True
def debug_print(message):
"""调试输出"""
if DEBUG:
print(f"🐛 DEBUG: {message}")
def check_safety_monitoring(node):
"""简化安全监控检查"""
has_battery = False
has_emergency = False
def check_node(current_node):
nonlocal has_battery, has_emergency
# 检查电池相关条件
if (current_node.get('type') == 'condition' and
'battery' in str(current_node.get('name', '')).lower()):
has_battery = True
# 检查紧急动作
if (current_node.get('type') == 'action' and
any(keyword in str(current_node.get('name', '')).lower()
for keyword in ['emergency', 'safe', 'land'])):
has_emergency = True
for child in current_node.get('children', []):
check_node(child)
check_node(node)
return has_battery or has_emergency # 放宽要求
def check_leaf_nodes(node, depth=0, max_depth=50):
"""检查节点结构"""
if depth > max_depth:
return True # 不因深度限制而失败
# 动作和条件节点不应该有子节点
if node.get('type') in ['action', 'condition']:
return 'children' not in node or not node['children']
# 控制流节点应该有子节点
if node.get('type') in ['Sequence', 'Selector', 'Parallel']:
if 'children' not in node or not node['children']:
return False
# 递归检查
for child in node.get('children', []):
if not check_leaf_nodes(child, depth + 1, max_depth):
return False
return True
def send_api_request(prompt, instruction_idx, run_number):
"""发送API请求并返回结果"""
url = BASE_URL + ENDPOINT
payload = {"user_prompt": prompt}
headers = {"Content-Type": "application/json"}
for attempt in range(MAX_RETRIES):
try:
debug_print(f"指令 {instruction_idx}-{run_number} 尝试 {attempt + 1}")
start_time = time.time()
response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=60) # 增加超时
response_time = time.time() - start_time
# 首先检查HTTP状态
response.raise_for_status()
# 尝试解析JSON
try:
data = response.json()
except json.JSONDecodeError as e:
debug_print(f"JSON解析失败: {e}, 响应文本: {response.text[:200]}")
raise
root_node = data.get('root', {})
# 基本验证 - 放宽要求
validation_checks = {
"is_dict": isinstance(data, dict),
"has_root": "root" in data,
"root_has_children": bool(root_node.get('children')),
"has_plan_id": "plan_id" in data,
"has_visualization_url": "visualization_url" in data,
}
# 可选的高级验证
advanced_checks = {
"leaf_nodes_valid": check_leaf_nodes(root_node),
"has_safety": check_safety_monitoring(root_node)
}
# 合并验证结果
validation_checks.update(advanced_checks)
# 统计无效节点但不作为失败条件
invalid_actions = []
invalid_conditions = []
def collect_nodes(current_node):
if current_node.get('type') == 'action':
action_name = current_node.get('name', '')
if action_name not in ['deliver_payload', 'emergency_return', 'fly_to_waypoint',
'land', 'loiter', 'object_detect', 'preflight_checks',
'search_pattern', 'strike_target', 'battle_damage_assessment', 'takeoff']:
invalid_actions.append(action_name)
elif current_node.get('type') == 'condition':
condition_name = current_node.get('name', '')
if condition_name not in ['battery_above', 'at_waypoint', 'object_detected',
'target_destroyed', 'time_elapsed', 'gps_status']:
invalid_conditions.append(condition_name)
for child in current_node.get('children', []):
collect_nodes(child)
collect_nodes(root_node)
# 主要检查基本验证,高级验证作为警告
success = all(validation_checks[k] for k in ["is_dict", "has_root", "root_has_children",
"has_plan_id", "has_visualization_url"])
debug_print(f"验证结果: 成功={success}, 基本验证通过={all(validation_checks.values())}")
return {
"success": success,
"data": data,
"validation_checks": validation_checks,
"response_time": response_time,
"invalid_actions": invalid_actions,
"invalid_conditions": invalid_conditions,
"error": None,
"attempts": attempt + 1,
"http_status": response.status_code
}
except requests.exceptions.RequestException as e:
error_msg = f"请求失败: {e}"
debug_print(f"请求异常: {error_msg}")
if attempt < MAX_RETRIES - 1:
time.sleep(RETRY_DELAY)
continue
return {
"success": False,
"data": None,
"validation_checks": {},
"response_time": 0,
"invalid_actions": [],
"invalid_conditions": [],
"error": error_msg,
"attempts": attempt + 1,
"http_status": getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
}
except Exception as e:
error_msg = f"未知错误: {e}"
debug_print(f"未知错误: {error_msg}")
return {
"success": False,
"data": None,
"validation_checks": {},
"response_time": 0,
"invalid_actions": [],
"invalid_conditions": [],
"error": error_msg,
"attempts": attempt + 1,
"http_status": None
}
def read_instructions(filename):
"""读取指令列表"""
instructions = []
try:
with open(filename, 'r', encoding='utf-8') as file:
for line in file:
line = line.strip()
if line and not line.startswith('#'):
instructions.append(line)
return instructions
except Exception as e:
print(f"❌ 读取指令文件时出错: {e}")
return []
def write_log_entry(log_file, instruction_idx, run_number, prompt, result):
"""写入详细日志"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, 'a', encoding='utf-8') as f:
f.write(f"\n{'='*80}\n")
f.write(f"指令 #{instruction_idx} - 运行 #{run_number} - {timestamp}\n")
f.write(f"HTTP状态: {result.get('http_status', 'N/A')}\n")
f.write(f"指令: {prompt}\n")
f.write(f"尝试次数: {result['attempts']}\n")
f.write(f"响应时间: {result['response_time']:.2f}\n")
f.write(f"结果: {'✅ 成功' if result['success'] else '❌ 失败'}\n")
if result['success']:
f.write("验证结果:\n")
for check_name, check_result in result['validation_checks'].items():
f.write(f" {check_name}: {'' if check_result else ''}\n")
if result['invalid_actions']:
f.write(f"⚠️ 无效动作节点: {result['invalid_actions']}\n")
if result['invalid_conditions']:
f.write(f"⚠️ 无效条件节点: {result['invalid_conditions']}\n")
else:
f.write(f"错误信息: {result['error']}\n")
def generate_summary_report(instructions, results_summary):
"""
生成统计摘要报告(修复除零错误)
"""
try:
with open(SUMMARY_CSV, 'w', newline='', encoding='utf-8') as csvfile:
fieldnames = ['instruction_index', 'instruction', 'total_runs', 'successful_runs',
'success_rate', 'avg_response_time', 'min_response_time',
'max_response_time', 'total_response_time']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for i, instruction in enumerate(instructions):
summary = results_summary[i]
success_count = summary['success_count']
# 防止除零错误
avg_time = "N/A"
min_time = "N/A"
max_time = "N/A"
if success_count > 0:
avg_time = f"{summary['total_response_time'] / success_count:.2f}s"
min_time = f"{summary['min_response_time']:.2f}s"
max_time = f"{summary['max_response_time']:.2f}s"
writer.writerow({
'instruction_index': i + 1,
'instruction': instruction,
'total_runs': TESTS_PER_INSTRUCTION,
'successful_runs': success_count,
'success_rate': f"{(success_count / TESTS_PER_INSTRUCTION * 100):.2f}%",
'avg_response_time': avg_time,
'min_response_time': min_time,
'max_response_time': max_time,
'total_response_time': f"{summary['total_response_time']:.2f}s"
})
print(f"📊 统计摘要已保存至: {SUMMARY_CSV}")
except Exception as e:
print(f"❌ 保存统计摘要时出错: {e}")
def main():
"""主测试函数"""
print("🚀 开始批量API测试")
print(f"每个指令测试 {TESTS_PER_INSTRUCTION}")
instructions = read_instructions(INSTRUCTIONS_FILE)
if not instructions:
return
print(f"找到 {len(instructions)} 条指令")
# 初始化统计
results_summary = [{
'success_count': 0,
'total_response_time': 0,
'min_response_time': float('inf'),
'max_response_time': 0,
'http_statuses': []
} for _ in instructions]
detailed_results = []
# 执行测试
for instruction_idx, prompt in enumerate(instructions, 1):
print(f"\n{'='*60}")
print(f"📋 测试指令 {instruction_idx}/{len(instructions)}")
print(f"指令: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
print(f"{'='*60}")
for run_number in range(1, TESTS_PER_INSTRUCTION + 1):
print(f" 运行 {run_number}/{TESTS_PER_INSTRUCTION}...", end=" ", flush=True)
result = send_api_request(prompt, instruction_idx, run_number)
write_log_entry(LOG_FILE, instruction_idx, run_number, prompt, result)
# 记录结果
detailed_result = {
"instruction_index": instruction_idx,
"instruction": prompt,
"run_number": run_number,
"success": result["success"],
"attempts": result["attempts"],
"response_time": result["response_time"],
"http_status": result.get("http_status"),
"error": result["error"] or "",
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
detailed_results.append(detailed_result)
# 更新统计
idx = instruction_idx - 1
if result["success"]:
results_summary[idx]['success_count'] += 1
results_summary[idx]['total_response_time'] += result['response_time']
results_summary[idx]['min_response_time'] = min(
results_summary[idx]['min_response_time'], result['response_time']
)
results_summary[idx]['max_response_time'] = max(
results_summary[idx]['max_response_time'], result['response_time']
)
print(f"✅ 成功 ({result['response_time']:.1f}s)")
else:
print(f"❌ 失败 (HTTP: {result.get('http_status', 'N/A')})")
# 记录HTTP状态
if 'http_status' in result:
results_summary[idx]['http_statuses'].append(result['http_status'])
time.sleep(1) # 避免服务器过载
# 生成详细结果CSV
try:
with open(RESULTS_CSV, 'w', newline='', encoding='utf-8') as csvfile:
fieldnames = ['instruction_index', 'instruction', 'run_number', 'success',
'attempts', 'response_time', 'plan_id', 'error', 'timestamp']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for result in detailed_results:
writer.writerow(result)
print(f"\n📊 详细结果已保存至: {RESULTS_CSV}")
except Exception as e:
print(f"❌ 保存详细结果时出错: {e}")
# 生成统计摘要
generate_summary_report(instructions, results_summary)
# 打印最终统计
print(f"\n{'='*60}")
print("📈 最终测试统计")
print(f"{'='*60}")
print(f"总测试次数: {total_tests}")
print(f"成功次数: {total_successful}")
print(f"失败次数: {total_tests - total_successful}")
print(f"总成功率: {(total_successful / total_tests * 100):.2f}%")
# 打印每个指令的统计
print(f"\n📋 每个指令的统计:")
for i, (instruction, summary) in enumerate(zip(instructions, results_summary), 1):
success_rate = (summary['success_count'] / TESTS_PER_INSTRUCTION * 100)
avg_time = summary['total_response_time'] / summary['success_count'] if summary['success_count'] > 0 else 0
print(f" 指令 {i}: {success_rate:.1f}% 成功 ({summary['success_count']}/{TESTS_PER_INSTRUCTION}), "
f"平均时间: {avg_time:.2f}s")
print(f"\n📁 输出文件:")
print(f"详细日志: {LOG_FILE}")
print(f"详细结果: {RESULTS_CSV}")
print(f"统计摘要: {SUMMARY_CSV}")
if __name__ == "__main__":
main()