优化交互式测试验证脚本,针对场景4修改提示词以及代码
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -18,7 +18,11 @@
|
||||
{"name": "takeoff", "description": "无人机从当前位置垂直起飞到指定的海拔高度。", "params": {"altitude": "float, 目标海拔高度(米),范围[1, 100],默认为2"}},
|
||||
{"name": "land", "description": "降落无人机。可选择当前位置或返航点降落。", "params": {"mode": "string, 可选值: 'current'(当前位置), 'home'(返航点)"}},
|
||||
{"name": "fly_to_waypoint", "description": "导航至一个指定坐标点。使用相对坐标系(x,y,z),单位为米。", "params": {"x": "float", "y": "float", "z": "float", "acceptance_radius": "float, 可选,默认2.0"}},
|
||||
{"name": "move_direction", "description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。", "params": {"direction": "string: north|south|east|west|forward|backward|left|right", "distance": "float[1,10000], 可选, 不指定则持续移动"}},
|
||||
{"name": "move_direction", "description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。", "params": {"direction": "string: north|south|east|west|forward|backward|left|right", "distance": "float[1,10000], 可选, 不指定则持续移动", "speed": "float, 可选"}},
|
||||
{"name": "approach_target", "description": "快速趋近目标至固定距离。", "params": {"target_class": "string, 要趋近的目标类别", "description": "string, 可选", "stop_distance": "float, 期望的最终停止距离", "speed": "float, 可选"}},
|
||||
{"name": "rotate", "description": "旋转固定角度。", "params": {"angle": "float, 旋转角度(正数逆时针, 负数顺时针)", "angular_velocity": "rad/s, 旋转角速度"}},
|
||||
{"name": "rotate_search", "description": "原地旋转搜索目标。", "params": {"target_class": "string, 要搜寻的目标类别", "description": "string, 可选", "step_angle": "float, 可选, 每一步旋转的角度", "total_rotation": "float, 可选, 总共旋转搜索的角度"}},
|
||||
{"name": "manual_confirmation", "description": "前端弹窗是否继续执行后续任务。", "params": {}},
|
||||
{"name": "orbit_around_point", "description": "以给定中心点为中心,等速圆周飞行指定圈数。", "params": {"center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
|
||||
{"name": "orbit_around_target", "description": "以目标为中心,等速圆周飞行指定圈数(需已有目标)。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
|
||||
{"name": "loiter", "description": "在当前位置上空悬停一段时间或直到条件触发。", "params": {"duration": "float, 可选[1,600]", "until_condition": "string, 可选"}},
|
||||
@@ -32,7 +36,6 @@
|
||||
{"name": "emergency_return", "description": "执行紧急返航程序。", "params": {"reason": "string"}}
|
||||
],
|
||||
"conditions": [
|
||||
{"name": "battery_above", "description": "电池电量高于阈值。", "params": {"threshold": "float[0.0,1.0]"}},
|
||||
{"name": "at_waypoint", "description": "在指定坐标容差范围内。", "params": {"x": "float", "y": "float", "z": "float", "tolerance": "float, 可选, 默认3.0"}},
|
||||
{"name": "object_detected", "description": "检测到特定目标。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||
{"name": "target_destroyed", "description": "目标已被摧毁。", "params": {"target_class": "string", "description": "string, 可选", "confidence": "float[0.5-1.0], 默认0.8"}},
|
||||
@@ -58,4 +61,4 @@
|
||||
- “环绕X米Y圈” → 若有目标上下文则使用 `orbit_around_target`,否则根据是否给出中心坐标选择 `orbit_around_point`;`radius=X`,`laps=Y`,默认 `clockwise=true`,`gimbal_lock=true`
|
||||
- “顺时针/逆时针” → `clockwise=true/false`
|
||||
- “等速” → 若未给速度则 `speed_mps` 采用默认值(例如3.0);若口令指明速度,裁剪到[0.5,15]
|
||||
- “以(x,y,z)为中心”/“当前位置为中心” → 选择 `orbit_around_point` 并填充 `center_x/center_y/center_z`
|
||||
- “以(x,y,z)为中心”/“当前位置为中心” → 选择 `orbit_around_point` 并填充 `center_x/center_y/center_z`
|
||||
|
||||
@@ -9,7 +9,11 @@
|
||||
{"name":"takeoff","params":{"altitude":"float[1,100],默认2"}},
|
||||
{"name":"land","params":{"mode":"'current'/'home'"}},
|
||||
{"name":"fly_to_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","acceptance_radius":"默认2.0"}},
|
||||
{"name":"move_direction","params":{"direction":"north/south/east/west/forward/backward/left/right","distance":"[1,10000],缺省持续移动"}},
|
||||
{"name":"move_direction","params":{"direction":"north/south/east/west/forward/backward/left/right","distance":"[1,10000],缺省持续移动","speed":"float,可选"}},
|
||||
{"name":"approach_target","params":{"target_class":"string,要趋近的目标类别","description":"string,可选,目标属性描述","stop_distance":"float,期望的最终停止距离","speed":"float,可选,期望的逼近速度"}},
|
||||
{"name":"rotate","params":{"angle":"float,旋转角度(正数逆时针,负数顺时针)","angular_velocity":"rad/s,旋转角速度"}},
|
||||
{"name":"rotate_search","params":{"target_class":"string,要搜寻的目标类别","description":"string,可选,目标属性描述","step_angle":"float,可选,每一步旋转的角度","total_rotation":"float,可选,总共旋转搜索的角度"}},
|
||||
{"name":"manual_confirmation","params":{}},
|
||||
{"name":"orbit_around_point","params":{"center_x":"±10000","center_y":"±10000","center_z":"[1,5000]","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}},
|
||||
{"name":"orbit_around_target","params":{"target_class":"见object_detect列表","description":"可选,目标属性","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}},
|
||||
{"name":"loiter","params":{"duration":"[1,600]秒/until_condition:可选"}},
|
||||
@@ -20,10 +24,10 @@
|
||||
{"name":"track_object","params":{"target_class":"同object_detect","description":"可选,目标属性","track_time":"[1,600]秒(必传,不可用'duration')","min_confidence":"[0.5,1.0]默认0.7","safe_distance":"[2,50]默认10"}},
|
||||
{"name":"deliver_payload","params":{"payload_type":"string","release_altitude":"[2,100]默认5"}},
|
||||
{"name":"preflight_checks","params":{"check_level":"basic/comprehensive"}},
|
||||
{"name":"emergency_return","params":{"reason":"string"}}
|
||||
{"name":"emergency_return","params":{"reason":"string"}},
|
||||
{"name":"take_photos","params":{"target_class":"同object_detect","description":"可选,目标属性","track_time":"[1,600]秒(必传,不可用'duration')","min_confidence":"[0.5,1.0]默认0.7","safe_distance":"[2,50]默认10"}}
|
||||
],
|
||||
"conditions": [
|
||||
{"name":"battery_above","params":{"threshold":"[0.0,1.0],必传"}},
|
||||
{"name":"at_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","tolerance":"默认3.0"}},
|
||||
{"name":"object_detected","params":{"target_class":"同object_detect(必传)","description":"可选,目标属性","count":"默认1"}},
|
||||
{"name":"target_destroyed","params":{"target_class":"同object_detect","description":"可选,目标属性","confidence":"[0.5,1.0]默认0.8"}},
|
||||
@@ -34,6 +38,9 @@
|
||||
{"name":"Sequence","params":{},"children":"子节点数组(按序执行,全成功则成功)"},
|
||||
{"name":"Selector","params":{"memory":"默认true"},"children":"子节点数组(执行到成功为止)"},
|
||||
{"name":"Parallel","params":{"policy":"all_success"},"children":"子节点数组(同时执行,严禁用'one_success')"}
|
||||
],
|
||||
"decorators": [
|
||||
{"name":"SuccessIsFailure","params":{},"child":"单一子节点(将子节点的成功结果反转为失败)"}
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -42,53 +49,91 @@
|
||||
## 二、节点必填字段(后端Schema强制要求,缺一验证失败)
|
||||
每个节点必须包含以下字段,字段名/类型不可自定义:
|
||||
1. **`type`**:
|
||||
- 动作节点→`"action"`,条件节点→`"condition"`,控制流节点→`"Sequence"`/`"Selector"`/`"Parallel"`(与`name`字段值完全一致);
|
||||
2. **`name`**:必须是上述JSON中`actions`/`conditions`/`control_flow`下的`name`值(如“gps_status”不可错写为“gps_check”);
|
||||
3. **`params`**:严格匹配上述节点的`params`定义,无自定义参数(如优先级排序不可加“priority”字段,仅用`description`);
|
||||
4. **`children`**:仅控制流节点必含(子节点数组),动作/条件节点无此字段。
|
||||
- 动作节点→`"action"`,条件节点→`"condition"`,控制流节点→`"Sequence"`/`"Selector"`/`"Parallel"`,装饰器节点→`"decorator"`;
|
||||
2. **`name`**:必须是上述JSON中定义的`name`值;
|
||||
3. **`params`**:严格匹配上述节点的`params`定义,无自定义参数;
|
||||
4. **`children`**:仅控制流节点必含(子节点数组);
|
||||
5. **`child`**:仅装饰器节点必含(单一子节点对象,非数组)。
|
||||
|
||||
|
||||
## 三、行为树固定结构(通用不变,确保安全验证)
|
||||
根节点必须是`Parallel`,`children`含`MainTask`(Sequence)和`SafetyMonitor`(Selector),结构不随任务类型(含优先级排序)修改:
|
||||
## 三、标准任务结构模板(单次起降流程)
|
||||
大多数任务应遵循“起飞 -> 接近 -> 执行 -> 返航/降落”的单次闭环流程,参考结构如下:
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"type": "Sequence",
|
||||
"name": "MainTask",
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "MainTask",
|
||||
"params": {},
|
||||
"children": [
|
||||
// 通用主任务步骤(含优先级排序任务示例,需按用户指令替换):
|
||||
{"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}},
|
||||
{"type":"action","name":"takeoff","params":{"altitude":10.0}},
|
||||
{"type":"action","name":"fly_to_waypoint","params":{"x":200.0,"y":150.0,"z":10.0}}, // 搜索区坐标(用户未给时填合理值)
|
||||
{"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200.0,"center_y":150.0,"center_z":10.0,"radius":50.0,"target_class":"balloon","description":"红色"}},
|
||||
{"type":"condition","name":"object_detected","params":{"target_class":"balloon","description":"红色"}}, // 确认高优先级目标
|
||||
{"type":"action","name":"track_object","params":{"target_class":"balloon","description":"红色","track_time":30.0}},
|
||||
{"type":"action","name":"strike_target","params":{"target_class":"balloon","description":"红色"}},
|
||||
{"type":"action","name":"land","params":{"mode":"home"}}
|
||||
]
|
||||
},
|
||||
{"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}},
|
||||
{"type":"action","name":"takeoff","params":{"altitude":10.0}},
|
||||
{"type":"action","name":"fly_to_waypoint","params":{"x":100.0,"y":50.0,"z":10.0}}, // 接近目标区域
|
||||
// --- 核心任务区 (根据指令替换) ---
|
||||
{"type":"action","name":"rotate_search","params":{"target_class":"person","description":"目标描述"}},
|
||||
{"type":"action","name":"object_detect","params":{"target_class":"person","description":"目标描述"}},
|
||||
// -------------------------------
|
||||
{"type":"action","name":"land","params":{"mode":"home"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 四、场景示例(请灵活参考)
|
||||
|
||||
#### 场景 1:线性搜索任务(Sequence + Selector)
|
||||
**指令**:“去研究所正大门,搜索扎辫子女子并拍照。”
|
||||
**结构**:Sequence (按顺序执行)
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "MainSearchTask",
|
||||
"children": [
|
||||
{"type":"action","name":"takeoff","params":{"altitude":10.0}},
|
||||
{"type":"action","name":"fly_to_waypoint","params":{"x":100.0,"y":50.0,"z":10.0}},
|
||||
{"type":"action","name":"rotate_search","params":{"target_class":"person","description":"扎辫子女子"}},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitor",
|
||||
"params": {"memory": true},
|
||||
"name": "CheckAndPhoto",
|
||||
"children": [
|
||||
{"type":"condition","name":"battery_above","params":{"threshold":0.3}},
|
||||
{"type":"condition","name":"gps_status","params":{"min_satellites":8}},
|
||||
{
|
||||
"type":"Sequence",
|
||||
"name":"EmergencyHandler",
|
||||
"params": {},
|
||||
"type": "Sequence",
|
||||
"name": "PhotoIfFound",
|
||||
"children": [
|
||||
{"type":"action","name":"emergency_return","params":{"reason":"safety_breach"}},
|
||||
{"type":"action","name":"land","params":{"mode":"home"}}
|
||||
{"type":"condition","name":"object_detected","params":{"target_class":"person","description":"扎辫子女子"}},
|
||||
{"type":"action","name":"take_photos","params":{"target_class":"person","description":"扎辫子女子","track_time":10.0}}
|
||||
]
|
||||
}
|
||||
},
|
||||
{"type":"action","name":"loiter","params":{"duration":5.0}} // 未发现时的备选动作
|
||||
]
|
||||
},
|
||||
{"type":"action","name":"land","params":{"mode":"home"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 场景 2:带中断逻辑的巡逻(Selector 示例)
|
||||
**指令**:“飞往航点A。如果途中发现可疑人员,则悬停。”
|
||||
**结构**:
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"children": [
|
||||
{"type":"action","name":"takeoff","params":{"altitude":10.0}},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "FlyOrDetect",
|
||||
"children": [
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "InterruptionLogic",
|
||||
"children": [
|
||||
{"type":"action","name":"object_detect","params":{"target_class":"person"}},
|
||||
{"type":"action","name":"loiter","params":{"duration":5.0}}
|
||||
]
|
||||
},
|
||||
{"type":"action","name":"fly_to_waypoint","params":{"x":100.0,"y":50.0,"z":10.0}}
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -96,22 +141,20 @@
|
||||
}
|
||||
```
|
||||
|
||||
## 五、优先级排序任务通用示例
|
||||
当用户指令中明确提出有多个待考察且具有优先级关系的物体时,节点描述须为优先级关系。
|
||||
| 用户指令场景 | `target_class` | `description` |
|
||||
|-----------------------------|-----------------|-------------------------|
|
||||
| 红气球>蓝气球>绿气球 | `balloon` | `(红>蓝>绿)` |
|
||||
| 军用卡车>民用卡车>面包车 | `truck` | `(军用卡车>民用卡车>面包车)` |
|
||||
|
||||
## 四、优先级排序任务通用示例
|
||||
当用户指令中明确提出有多个待考察且具有优先级关系的物体时,节点描述须为优先级关系。比如当指令为已知有三个气球,危险级关系为红色气球大于蓝色气球大于绿色气球,要求优先跟踪最危险的气球时,节点的描述参考下表情形。
|
||||
| 用户指令场景 | `target_class` | `description` | 核心节点示例(search_pattern) |
|
||||
|-----------------------------|-----------------|-------------------------|------------------------------------------------------------------------------------------------|
|
||||
| 红气球>蓝气球>绿气球 | `balloon` | `(红>蓝>绿)` | `{"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200,"center_y":150,"center_z":10,"radius":50,"target_class":"balloon","description":"(红>蓝>绿)"}}` |
|
||||
| 军用卡车>民用卡车>面包车 | `truck` | `(军用卡车>民用卡车>面包车)` | `{"type":"action","name":"object_detect","params":{"target_class":"truck","description":"(军用卡车>民用卡车>面包车)"}}` |
|
||||
## 六、高频错误规避
|
||||
1. 优先级排序不可修改`target_class`,仅用`description`填排序规则;
|
||||
2. `track_object`必传`track_time`;
|
||||
3. `gps_status`的`min_satellites`必须在6-15之间;
|
||||
4. 严禁输出 markdown 代码块标记,直接输出 JSON 纯文本;
|
||||
5. 控制流节点的 `type` 必须是 `"Sequence"`, `"Selector"` 或 `"Parallel"`;
|
||||
6. 当用户指令中要求执行动作前增加人工确认时,比如“我确认后拍照”,则必须在拍照动作前增加manual_confirmation节点
|
||||
|
||||
|
||||
## 五、高频错误规避(确保验证通过)
|
||||
1. 优先级排序不可修改`target_class`:如“民用卡车、面包车与军用卡车中,军用卡车优先”,`target_class`仍为`truck`,仅用`description`填排序规则;
|
||||
2. 在没有明确指出物体之间的优先级关系情况下,`description`字段只描述物体属性本身,严禁与用户指令中不存在的物体进行排序;
|
||||
3. `track_object`必传`track_time`:不可用`duration`替代(如跟踪30秒填`"track_time":30.0`);
|
||||
4. `gps_status`的`min_satellites`必须在6-15之间(如8,不可缺省);
|
||||
5. 无自定义节点:“锁定高优先级目标”需通过`object_detect`+`object_detected`实现,不可用“lock_high_risk_target”。
|
||||
|
||||
|
||||
## 六、输出要求
|
||||
仅输出1个严格符合上述所有规则的JSON对象,**确保:1. 优先级排序逻辑正确填入`description`;2. `target_class`匹配预定义列表;3. 行为树结构不变;4. 后端解析与Schema验证无错误**,无任何冗余内容。
|
||||
## 七、输出要求
|
||||
仅输出1个严格符合上述所有规则的JSON对象。
|
||||
|
||||
@@ -52,7 +52,7 @@ def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[st
|
||||
"""
|
||||
try:
|
||||
# 使用更精确的正则表达式匹配节点定义部分
|
||||
node_section_pattern = r"#### 2\. 可用节点定义.*?```json\s*({.*?})\s*```"
|
||||
node_section_pattern = r"#### 1\. 可用节点定义.*?```json\s*({.*?})\s*```"
|
||||
match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
|
||||
|
||||
if not match:
|
||||
@@ -144,51 +144,12 @@ def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
|
||||
return set(), set()
|
||||
|
||||
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
|
||||
"""递归查找所有指定名称的节点"""
|
||||
nodes_found = []
|
||||
|
||||
if node.get("name") == target_name:
|
||||
nodes_found.append(node)
|
||||
|
||||
# 递归搜索子节点
|
||||
for child in node.get("children", []):
|
||||
nodes_found.extend(_find_nodes_by_name(child, target_name))
|
||||
|
||||
return nodes_found
|
||||
|
||||
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
|
||||
"""验证行为树是否包含必要的安全监控"""
|
||||
root_node = pytree_instance.get("root", {})
|
||||
|
||||
# 查找所有电池监控节点
|
||||
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
|
||||
|
||||
# 检查是否包含安全监控结构
|
||||
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
|
||||
|
||||
if not battery_nodes and not safety_monitors:
|
||||
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
|
||||
return False
|
||||
|
||||
# 检查电池阈值设置是否合理
|
||||
for battery_node in battery_nodes:
|
||||
threshold = battery_node.get("params", {}).get("threshold")
|
||||
if threshold is not None:
|
||||
if threshold < 0.25:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold}),建议不低于0.25")
|
||||
elif threshold > 0.5:
|
||||
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
|
||||
|
||||
logging.info("✅ 安全监控验证通过")
|
||||
return True
|
||||
|
||||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||||
"""
|
||||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||||
"""
|
||||
# 所有可能的节点类型
|
||||
node_types = ["action", "condition", "Sequence", "Selector", "Parallel"]
|
||||
node_types = ["action", "condition", "Sequence", "Selector", "Parallel", "decorator"]
|
||||
|
||||
# 目标检测相关的类别枚举
|
||||
target_classes = [
|
||||
@@ -201,38 +162,44 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
|
||||
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
|
||||
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon"
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon","trash","window"
|
||||
]
|
||||
|
||||
# 递归节点定义
|
||||
node_definition = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string", "enum": node_types},
|
||||
# 修改:手动构造不区分大小写的正则,避免使用不支持的 (?i) 标志
|
||||
# 匹配: action, condition, sequence, selector, parallel, decorator (忽略大小写)
|
||||
"type": {
|
||||
"type": "string",
|
||||
"pattern": "^([Aa][Cc][Tt][Ii][Oo][Nn]|[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]|[Ss][Ee][Qq][Uu][Ee][Nn][Cc][Ee]|[Ss][Ee][Ll][Ee][Cc][Tt][Oo][Rr]|[Pp][Aa][Rr][Aa][Ll][Ll][Ee][Ll]|[Dd][Ee][Cc][Oo][Rr][Aa][Tt][Oo][Rr])$"
|
||||
},
|
||||
"name": {"type": "string"},
|
||||
"params": {"type": "object"},
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/node"}
|
||||
}
|
||||
},
|
||||
"child": {"$ref": "#/definitions/node"}
|
||||
},
|
||||
"required": ["type", "name"],
|
||||
"allOf": [
|
||||
# 动作节点验证
|
||||
# 动作节点验证 (忽略大小写)
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "action"}}},
|
||||
"if": {"properties": {"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
|
||||
},
|
||||
# 条件节点验证
|
||||
# 条件节点验证 (忽略大小写)
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "condition"}}},
|
||||
"if": {"properties": {"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
|
||||
},
|
||||
# 目标检测动作节点的参数验证
|
||||
# 目标检测动作节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "action"},
|
||||
"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "object_detect"}
|
||||
}
|
||||
},
|
||||
@@ -251,11 +218,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# 目标检测条件节点的参数验证
|
||||
# 目标检测条件节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "object_detected"}
|
||||
}
|
||||
},
|
||||
@@ -274,11 +241,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# 电池监控节点的参数验证
|
||||
# 电池监控节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "battery_above"}
|
||||
}
|
||||
},
|
||||
@@ -295,11 +262,11 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
}
|
||||
}
|
||||
},
|
||||
# GPS状态节点的参数验证
|
||||
# GPS状态节点的参数验证 (忽略大小写)
|
||||
{
|
||||
"if": {
|
||||
"properties": {
|
||||
"type": {"const": "condition"},
|
||||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||||
"name": {"const": "gps_status"}
|
||||
}
|
||||
},
|
||||
@@ -372,10 +339,7 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
jsonschema.validate(instance=pytree_instance, schema=schema)
|
||||
logging.info("✅ JSON Schema验证成功")
|
||||
|
||||
# 额外验证安全监控
|
||||
safety_valid = _validate_safety_monitoring(pytree_instance)
|
||||
|
||||
return True and safety_valid
|
||||
return True
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning("❌ Pytree验证失败")
|
||||
logging.warning(f"错误信息: {e.message}")
|
||||
@@ -503,6 +467,10 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
elif node_type == 'decorator':
|
||||
shape = 'doubleoctagon'
|
||||
style = 'filled'
|
||||
fillcolor = '#f8cecc' # 浅红
|
||||
|
||||
# 特别标记安全相关节点
|
||||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||||
@@ -515,28 +483,28 @@ def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
# 递归处理子节点 (Sequence, Selector, Parallel 等)
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
|
||||
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
|
||||
if children:
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 递归处理单子节点 (Decorator)
|
||||
child = node.get("child")
|
||||
if child:
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
return current_id
|
||||
|
||||
@@ -588,7 +556,7 @@ class PyTreeGenerator:
|
||||
self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'rag','vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
|
||||
# Explicitly use the remote embedding function for queries
|
||||
@@ -725,7 +693,7 @@ class PyTreeGenerator:
|
||||
pytree_str = combined_text if combined_text else (msg_content or "")
|
||||
raw_full_text_for_logging = pytree_str # 保存完整原文(含 <think>)以便失败时完整打印
|
||||
|
||||
# 提取 <think> 推理链内容(若存在)
|
||||
# 提取 <think> 推理链内容(若有)
|
||||
reasoning_text = None
|
||||
try:
|
||||
think_match = re.search(r"<think>([\s\S]*?)</think>", pytree_str)
|
||||
@@ -827,49 +795,7 @@ class PyTreeGenerator:
|
||||
pytree_dict['final_prompt'] = final_prompt
|
||||
return pytree_dict
|
||||
|
||||
# 复杂模式回退:若模型误返回简单结构(root是单个action),则自动包装为含安全监控的行为树
|
||||
if mode == "complex" and isinstance(pytree_dict, dict) and 'root' in pytree_dict:
|
||||
root_node = pytree_dict.get('root', {})
|
||||
# 检查是否是简单结构(root是单个action节点,没有children)
|
||||
if (root_node.get('type') == 'action' and
|
||||
('children' not in root_node or not root_node.get('children'))):
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||
logging.warning("⚠️ 复杂模式生成了简单结构(单个action),触发自动包装为完整行为树的回退逻辑。")
|
||||
action_name = root_node.get('name')
|
||||
action_params = root_node.get('params') if isinstance(root_node.get('params'), dict) else {}
|
||||
|
||||
safety_selector = {
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitor",
|
||||
"params": {"memory": True},
|
||||
"children": [
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
|
||||
{"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
|
||||
{"type": "Sequence", "name": "EmergencyHandler", "children": [
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]}
|
||||
]
|
||||
}
|
||||
|
||||
main_children = [{"type": "action", "name": action_name, "params": action_params}]
|
||||
if action_name != "land":
|
||||
main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
|
||||
|
||||
root_parallel = {
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{"type": "Sequence", "name": "MainTask", "children": main_children},
|
||||
safety_selector
|
||||
]
|
||||
}
|
||||
pytree_dict = {"root": root_parallel}
|
||||
except jsonschema.ValidationError:
|
||||
# 不符合简单结构,按正常复杂验证继续
|
||||
pass
|
||||
# 验证生成的复杂行为树
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("✅ 成功生成并验证了Pytree")
|
||||
plan_id = str(uuid.uuid4())
|
||||
|
||||
39
tools/rag/README.md
Normal file
39
tools/rag/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# RAG & Map Tools
|
||||
|
||||
该目录包含了地图构建、知识库生成和向量数据库管理的相关工具。
|
||||
|
||||
## 目录结构
|
||||
|
||||
- **knowledge_base/**: 存放源文档数据。
|
||||
- 支持格式: `.txt`, `.md`, `.pdf`
|
||||
- 生成格式: `.json`, `.ndjson` (由 `build_knowledge_base.py` 生成)
|
||||
|
||||
- **map/**: 存放地图原始数据。
|
||||
- `.osm` (OpenStreetMap 数据)
|
||||
- `.world` (Gazebo 仿真环境数据)
|
||||
|
||||
- **vector_store/**: ChromaDB 向量数据库的持久化存储目录。
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 1. `build_knowledge_base.py`
|
||||
**功能**: 处理 `map/` 目录下的地图文件,提取地理信息和语义描述,生成知识库文件到 `knowledge_base/` 目录。
|
||||
**使用方法**:
|
||||
```bash
|
||||
python build_knowledge_base.py
|
||||
```
|
||||
|
||||
### 2. `ingest.py`
|
||||
**功能**: 读取 `knowledge_base/` 中的所有文档,调用嵌入模型(Embedding Model)将其向量化,并存入 `vector_store/` 中的 ChromaDB 数据库。
|
||||
**使用方法**:
|
||||
```bash
|
||||
python ingest.py
|
||||
```
|
||||
**依赖**: 需要确保后端嵌入服务(如 `llama-server`)已启动,或者配置正确的 `ORIN_IP` 环境变量。
|
||||
|
||||
## 工作流
|
||||
1. 将地图文件放入 `map/`。
|
||||
2. 运行 `build_knowledge_base.py` 生成文本描述。
|
||||
3. 将其他补充文档放入 `knowledge_base/`。
|
||||
4. 运行 `ingest.py` 构建向量索引。
|
||||
|
||||
157
tools/rag/build_knowledge_base.py
Normal file
157
tools/rag/build_knowledge_base.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
|
||||
# --- 配置日志 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def process_osm_json(input_path: Path) -> list[str]:
|
||||
"""
|
||||
处理OpenStreetMap的JSON文件,返回描述性句子列表。
|
||||
"""
|
||||
logging.info(f"正在以OSM JSON格式处理文件: {input_path.name}")
|
||||
descriptions = []
|
||||
try:
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logging.error(f"读取或解析 {input_path.name} 时出错: {e}")
|
||||
return []
|
||||
|
||||
elements = data.get('elements', [])
|
||||
if not elements:
|
||||
return []
|
||||
|
||||
nodes_map = {node['id']: node for node in elements if node.get('type') == 'node'}
|
||||
ways = [elem for elem in elements if elem.get('type') == 'way']
|
||||
|
||||
for way in ways:
|
||||
tags = way.get('tags', {})
|
||||
if 'name' not in tags:
|
||||
continue
|
||||
|
||||
way_name = tags.get('name')
|
||||
way_nodes_ids = way.get('nodes', [])
|
||||
if not way_nodes_ids:
|
||||
continue
|
||||
|
||||
total_lat, total_lon, node_count = 0, 0, 0
|
||||
for node_id in way_nodes_ids:
|
||||
node_info = nodes_map.get(node_id)
|
||||
if node_info:
|
||||
total_lat += node_info.get('lat', 0)
|
||||
total_lon += node_info.get('lon', 0)
|
||||
node_count += 1
|
||||
|
||||
if node_count == 0:
|
||||
continue
|
||||
|
||||
center_lat = total_lat / node_count
|
||||
center_lon = total_lon / node_count
|
||||
|
||||
sentence = f"在地图上有一个名为 '{way_name}' 的地点或区域"
|
||||
other_tags = {k: v for k, v in tags.items() if k != 'name'}
|
||||
if other_tags:
|
||||
tag_descs = [f"{key}是'{value}'" for key, value in other_tags.items()]
|
||||
sentence += f",它的{ '、'.join(tag_descs) }"
|
||||
sentence += f",其中心位置坐标大约在 ({center_lat:.6f}, {center_lon:.6f})。"
|
||||
descriptions.append(sentence)
|
||||
|
||||
logging.info(f"从 {input_path.name} 提取了 {len(descriptions)} 条位置描述。")
|
||||
return descriptions
|
||||
|
||||
|
||||
def process_gazebo_world(input_path: Path) -> list[str]:
|
||||
"""
|
||||
处理Gazebo的.world文件,返回描述性句子列表。
|
||||
"""
|
||||
logging.info(f"正在以Gazebo World格式处理文件: {input_path.name}")
|
||||
descriptions = []
|
||||
try:
|
||||
tree = ET.parse(input_path)
|
||||
root = tree.getroot()
|
||||
except ET.ParseError as e:
|
||||
logging.error(f"解析XML文件 {input_path.name} 失败: {e}")
|
||||
return []
|
||||
|
||||
models = root.findall('.//model')
|
||||
for model in models:
|
||||
model_name = model.get('name')
|
||||
pose_element = model.find('pose')
|
||||
|
||||
if model_name and pose_element is not None and pose_element.text:
|
||||
try:
|
||||
pose_values = [float(p) for p in pose_element.text.strip().split()]
|
||||
sentence = (
|
||||
f"仿真环境中有一个名为 '{model_name}' 的物体,"
|
||||
f"其位置和姿态(x, y, z, roll, pitch, yaw)为: {pose_values}。"
|
||||
)
|
||||
descriptions.append(sentence)
|
||||
except (ValueError, IndexError):
|
||||
logging.warning(f"跳过模型 '{model_name}',因其pose格式不正确。")
|
||||
|
||||
logging.info(f"从 {input_path.name} 提取了 {len(descriptions)} 个物体信息。")
|
||||
return descriptions
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
主函数,扫描源数据目录,为每个文件生成独立的NDJSON知识库。
|
||||
"""
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
# 输入源: tools/map/
|
||||
source_data_dir = script_dir / 'map'
|
||||
# 输出目录: tools/knowledge_base/
|
||||
output_knowledge_base_dir = script_dir / 'knowledge_base'
|
||||
|
||||
if not source_data_dir.exists():
|
||||
logging.error(f"源数据目录不存在: {source_data_dir}")
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_files_processed = 0
|
||||
logging.info(f"--- 开始扫描源数据目录: {source_data_dir} ---")
|
||||
|
||||
for file_path in source_data_dir.iterdir():
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
descriptions = []
|
||||
if file_path.suffix == '.json':
|
||||
descriptions = process_osm_json(file_path)
|
||||
elif file_path.suffix == '.world':
|
||||
descriptions = process_gazebo_world(file_path)
|
||||
else:
|
||||
logging.warning(f"跳过不支持的文件类型: {file_path.name}")
|
||||
continue
|
||||
|
||||
if not descriptions:
|
||||
logging.warning(f"未能从 {file_path.name} 提取有效信息,跳过生成文件。")
|
||||
continue
|
||||
|
||||
output_filename = file_path.stem + '_knowledge.ndjson'
|
||||
output_path = output_knowledge_base_dir / output_filename
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for sentence in descriptions:
|
||||
json_record = {"text": sentence}
|
||||
f.write(json.dumps(json_record, ensure_ascii=False) + '\n')
|
||||
logging.info(f"成功为 '{file_path.name}' 生成知识库文件: {output_path.name}")
|
||||
total_files_processed += 1
|
||||
except IOError as e:
|
||||
logging.error(f"写入输出文件 '{output_path.name}' 失败: {e}")
|
||||
|
||||
logging.info("--- 数据处理完成 ---")
|
||||
if total_files_processed > 0:
|
||||
logging.info(f"共为 {total_files_processed} 个源文件生成了知识库。")
|
||||
else:
|
||||
logging.warning("未生成任何知识库文件。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
194
tools/rag/ingest.py
Normal file
194
tools/rag/ingest.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# 该代码用于将本地知识库中的文档导入到ChromaDB中,并使用远程嵌入模型进行向量化
|
||||
import os
|
||||
from pathlib import Path
|
||||
import chromadb
|
||||
# from chromadb.utils import embedding_functions - 不再需要
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
from unstructured.partition.auto import partition
|
||||
from rich.progress import track
|
||||
import logging
|
||||
import requests # 导入requests
|
||||
import json # 导入json模块
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 配置 ---
|
||||
# 获取脚本所在目录,确保路径的正确性
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"
|
||||
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"
|
||||
COLLECTION_NAME = "drone_docs"
|
||||
# EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5" # 不再需要,模型名在函数内部处理
|
||||
|
||||
# --- 自定义远程嵌入函数 ---
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
"""
|
||||
一个使用远程、兼容OpenAI API的嵌入服务的嵌入函数。
|
||||
"""
|
||||
def __init__(self, api_url: str):
|
||||
self._api_url = api_url
|
||||
logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}")
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
"""
|
||||
对输入的文档进行嵌入。
|
||||
"""
|
||||
# 我们的服务只能处理文本,所以检查输入是否为字符串列表
|
||||
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
|
||||
logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 移除 "model" 参数,因为embedding服务可能不需要它
|
||||
response = requests.post(
|
||||
self._api_url,
|
||||
json={"input": input},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status() # 如果请求失败则抛出HTTPError
|
||||
|
||||
# 按照OpenAI API的格式解析返回的嵌入向量
|
||||
data = response.json().get("data", [])
|
||||
if not data:
|
||||
raise ValueError("API响应中没有找到'data'字段或'data'为空")
|
||||
|
||||
embeddings = [item['embedding'] for item in data]
|
||||
return embeddings
|
||||
|
||||
except requests.RequestException as e:
|
||||
logging.error(f"调用嵌入API失败: {e}")
|
||||
# 返回一个空列表或根据需要处理错误
|
||||
return []
|
||||
except (ValueError, KeyError) as e:
|
||||
logging.error(f"解析API响应失败: {e}")
|
||||
logging.error(f"收到的响应内容: {response.text}")
|
||||
return []
|
||||
|
||||
|
||||
def get_documents(directory: Path):
|
||||
"""从知识库目录加载所有文档并进行切分"""
|
||||
documents = []
|
||||
logging.info(f"从 '{directory}' 加载文档...")
|
||||
for file_path in directory.rglob("*"):
|
||||
if file_path.is_file() and not file_path.name.startswith('.'):
|
||||
try:
|
||||
# 对简单文本文件直接读取
|
||||
if file_path.suffix in ['.txt', '.md']:
|
||||
text = file_path.read_text(encoding='utf-8')
|
||||
documents.append({
|
||||
"text": text,
|
||||
"metadata": {"source": str(file_path.name)}
|
||||
})
|
||||
logging.info(f"成功处理文本文件: {file_path.name}")
|
||||
# 特别处理常规的JSON文件
|
||||
elif file_path.suffix == '.json':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if 'elements' in data and isinstance(data['elements'], list):
|
||||
for element in data['elements']:
|
||||
documents.append({
|
||||
"text": json.dumps(element, ensure_ascii=False),
|
||||
"metadata": {"source": str(file_path.name)}
|
||||
})
|
||||
logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(data['elements'])} 个元素。")
|
||||
else:
|
||||
documents.append({
|
||||
"text": json.dumps(data, ensure_ascii=False),
|
||||
"metadata": {"source": str(file_path.name)}
|
||||
})
|
||||
logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)")
|
||||
# 新增:专门处理我们生成的 NDJSON 文件
|
||||
elif file_path.suffix == '.ndjson':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
count = 0
|
||||
for line in f:
|
||||
try:
|
||||
record = json.loads(line)
|
||||
if 'text' in record and isinstance(record['text'], str):
|
||||
documents.append({
|
||||
"text": record['text'],
|
||||
"metadata": {"source": str(file_path.name)}
|
||||
})
|
||||
count += 1
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(f"跳过无效的JSON行: {line.strip()}")
|
||||
if count > 0:
|
||||
logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {count} 个文档。")
|
||||
# 对其他所有文件类型,使用unstructured
|
||||
else:
|
||||
elements = partition(filename=str(file_path))
|
||||
for element in elements:
|
||||
documents.append({
|
||||
"text": element.text,
|
||||
"metadata": {"source": str(file_path.name)}
|
||||
})
|
||||
logging.info(f"成功处理文件: {file_path.name} (使用unstructured)")
|
||||
except Exception as e:
|
||||
logging.error(f"处理文件 {file_path.name} 失败: {e}")
|
||||
return documents
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,执行文档入库流程"""
|
||||
if not KNOWLEDGE_BASE_DIR.exists():
|
||||
KNOWLEDGE_BASE_DIR.mkdir(parents=True)
|
||||
logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}")
|
||||
logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。")
|
||||
return
|
||||
|
||||
# 1. 加载并切分文档
|
||||
docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR)
|
||||
if not docs_to_ingest:
|
||||
logging.warning("在知识库中未找到可处理的文档。")
|
||||
return
|
||||
|
||||
# 2. 初始化ChromaDB客户端和远程嵌入函数
|
||||
orin_ip = os.getenv("ORIN_IP", "localhost")
|
||||
embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"
|
||||
|
||||
logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}")
|
||||
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
|
||||
|
||||
client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))
|
||||
|
||||
# 3. 创建或获取集合
|
||||
logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}")
|
||||
collection = client.get_or_create_collection(
|
||||
name=COLLECTION_NAME,
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
# 4. 将文档向量化并存入数据库
|
||||
logging.info(f"开始将 {len(docs_to_ingest)} 个文档块入库...")
|
||||
|
||||
# 为了避免重复添加,可以先检查
|
||||
# (这里为了简单,我们每次都重新添加,生产环境需要更复杂的逻辑)
|
||||
|
||||
doc_texts = [doc['text'] for doc in docs_to_ingest]
|
||||
metadatas = [doc['metadata'] for doc in docs_to_ingest]
|
||||
ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))]
|
||||
|
||||
try:
|
||||
# ChromaDB的add方法会自动处理嵌入
|
||||
collection.add(
|
||||
documents=doc_texts,
|
||||
metadatas=metadatas,
|
||||
ids=ids
|
||||
)
|
||||
logging.info("所有文档块已成功入库!")
|
||||
except Exception as e:
|
||||
logging.error(f"向ChromaDB添加文档时出错: {e}")
|
||||
|
||||
|
||||
# 验证一下
|
||||
count = collection.count()
|
||||
logging.info(f"数据库中现在有 {count} 个条目。")
|
||||
|
||||
print("\n✅ 数据入库完成!")
|
||||
print(f"知识库位于: {KNOWLEDGE_BASE_DIR}")
|
||||
print(f"向量数据库位于: {VECTOR_STORE_DIR}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
tools/rag/knowledge_base/export_knowledge.ndjson
Normal file
6
tools/rag/knowledge_base/export_knowledge.ndjson
Normal file
@@ -0,0 +1,6 @@
|
||||
{"text": "在地图上有一个名为 '跷跷板' 的地点或区域,它的leisure是'playground',其中心位置坐标大约在 (x:15, y:-8.5, z:1.2)。"}
|
||||
{"text": "在地图上有一个名为 'A地' 的地点或区域,它的building是'commercial',其中心位置坐标大约在 (x:10, y:-10, z:2)。"}
|
||||
{"text": "在地图上有一个名为 '学生宿舍' 的地点或区域,它的building是'dormitory',其中心位置坐标大约在 (x:5, y:3, z:2)。"}
|
||||
{"text": "地点:'研究所正大门'。别名:'大门'、'入口'。坐标:(x:-23.8, y:292.8, z:14)。建议悬停高度:14米。适合任务:定点侦察、拍照。"}
|
||||
{"text": "地点:'研究所广场'。属性:开阔区域。坐标:(x:-24.0, y:241.8, z:14)。建议搜索半径:30米。适合任务:寻找人员、旋转搜索。"}
|
||||
{"text": "路线:'研究所外围巡逻'。关键航点序列:[(x:-24.0, y:241.8), (x:-107.8, y:289.8), (x:-106.5, y:241.3), (x:-23.80, y:292.80)]。高度:14米。适合环绕侦察任务。"}
|
||||
70349
tools/rag/map/export.json
Normal file
70349
tools/rag/map/export.json
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tools/rag/vector_store/chroma.sqlite3
Normal file
BIN
tools/rag/vector_store/chroma.sqlite3
Normal file
Binary file not shown.
57
tools/test_validate/README.md
Normal file
57
tools/test_validate/README.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Test & Validation Tools (Unified)
|
||||
|
||||
该目录包含用于测试无人机规划系统、API 接口及 LLM 服务的集成验证工具集。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
使用统一入口脚本启动交互式菜单:
|
||||
|
||||
```bash
|
||||
python run_tests.py
|
||||
```
|
||||
|
||||
## 🛠️ 测试模式
|
||||
|
||||
### 1. 交互式单次测试 (Mode 1)
|
||||
- **场景**: 快速验证单条指令,调试 Prompt。
|
||||
- **操作**: 在终端输入自然语言指令,即时获取结果。
|
||||
- **输出**: `validation/temporary/{指令名}/`
|
||||
- `response.json`: 完整 API 响应
|
||||
- `plan.png`: 可视化任务树
|
||||
- `process.log`: 请求与响应日志
|
||||
|
||||
### 2. 批量/场景测试 (Mode 2)
|
||||
- **场景**:
|
||||
- **场景测试**: 验证一组预定义指令的正确性(默认运行 1 次)。
|
||||
- **稳定性测试**: 对同一组指令进行高频重复测试(如运行 10 次),检测成功率和延迟抖动。
|
||||
- **操作**:
|
||||
1. 选择指令文件(位于 `instructions/` 目录)。
|
||||
2. 输入每条指令的运行次数(默认 1)。
|
||||
- **输出**: `validation/{时间戳}/`
|
||||
- `test_summary.csv`: 统计摘要(成功率、平均耗时)
|
||||
- `test_details.csv`: 每次运行的详细记录
|
||||
- `instructions_backup.txt`: 本次测试使用的指令备份
|
||||
- `{指令名}/`: 包含所有运行的 `.json` 和 `.png` 产物
|
||||
|
||||
## 📂 目录结构
|
||||
|
||||
```text
|
||||
tools/test_validate/
|
||||
├── instructions/ # 指令集文件 (.txt)
|
||||
├── modules/ # 功能模块
|
||||
│ ├── api_client.py # API 客户端核心
|
||||
│ ├── interactive_test.py # 交互式测试逻辑
|
||||
│ ├── batch_runner.py # 批量测试逻辑
|
||||
│ ├── visualizer.py # 可视化工具库
|
||||
│ ├── llm_tester.py # LLM 连接测试
|
||||
│ └── drone_uploader.py # 任务上传工具
|
||||
├── validation/ # 测试产物输出
|
||||
│ ├── temporary/ # 交互式测试结果
|
||||
│ └── {时间戳}/ # 批量测试结果
|
||||
└── run_tests.py # 主程序入口
|
||||
```
|
||||
|
||||
## 📄 配置文件
|
||||
|
||||
- **instructions/validate_instructions.txt**: 默认的预定义场景指令集。
|
||||
- 您可以在 `instructions/` 下添加任意 `.txt` 文件,测试时会在菜单中自动列出供选择。
|
||||
@@ -0,0 +1,5 @@
|
||||
去研究所正大门,搜索扎辫子女子并拍照。
|
||||
查找戴帽子的女子,找到后近距离拍照。
|
||||
到研究所广场,对长发可疑男子进行拍照。
|
||||
到研究所广场,寻找黄色衣服男子,我确认后再对其拍照。
|
||||
立即返航。
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
60
tools/test_validate/modules/api_client.py
Normal file
60
tools/test_validate/modules/api_client.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url="http://127.0.0.1:8000"):
|
||||
self.base_url = base_url
|
||||
self.endpoint = "/generate_plan"
|
||||
|
||||
def send_request(self, prompt, timeout=60):
|
||||
"""
|
||||
Sends a request to the API and returns a structured result.
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool,
|
||||
"data": dict or None,
|
||||
"latency": float (seconds),
|
||||
"error": str or None,
|
||||
"http_status": int or None
|
||||
}
|
||||
"""
|
||||
url = f"{self.base_url}{self.endpoint}"
|
||||
payload = {"user_prompt": prompt}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=timeout)
|
||||
latency = time.time() - start_time
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"latency": latency,
|
||||
"error": None,
|
||||
"http_status": response.status_code
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {
|
||||
"success": False,
|
||||
"data": None,
|
||||
"latency": latency,
|
||||
"error": f"Invalid JSON response: {response.text[:200]}",
|
||||
"http_status": response.status_code
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
latency = time.time() - start_time
|
||||
return {
|
||||
"success": False,
|
||||
"data": None,
|
||||
"latency": latency,
|
||||
"error": str(e),
|
||||
"http_status": getattr(e.response, 'status_code', None)
|
||||
}
|
||||
|
||||
129
tools/test_validate/modules/batch_runner.py
Normal file
129
tools/test_validate/modules/batch_runner.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import csv
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from .api_client import APIClient
|
||||
from .visualizer import generate_visualization, sanitize_filename
|
||||
|
||||
def run_batch_test():
|
||||
# 1. 选择指令文件
|
||||
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
instr_dir = os.path.join(base_dir, "instructions")
|
||||
files = [f for f in os.listdir(instr_dir) if f.endswith('.txt')]
|
||||
|
||||
if not files:
|
||||
print("❌ 未在 instructions 目录下找到 .txt 文件")
|
||||
return
|
||||
|
||||
print("\n请选择测试指令文件:")
|
||||
for i, f in enumerate(files):
|
||||
print(f"{i+1}. {f}")
|
||||
|
||||
try:
|
||||
idx = int(input("请输入序号: ").strip()) - 1
|
||||
if idx < 0 or idx >= len(files):
|
||||
print("❌ 无效序号")
|
||||
return
|
||||
selected_file = os.path.join(instr_dir, files[idx])
|
||||
except ValueError:
|
||||
print("❌ 输入无效")
|
||||
return
|
||||
|
||||
# 2. 配置参数
|
||||
try:
|
||||
iterations = int(input("请输入每条指令的测试次数 (默认1): ").strip() or "1")
|
||||
except ValueError:
|
||||
iterations = 1
|
||||
|
||||
# 3. 准备输出目录
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
output_dir = os.path.join(base_dir, "validation", timestamp)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 备份指令文件
|
||||
shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))
|
||||
|
||||
# 读取指令
|
||||
with open(selected_file, 'r', encoding='utf-8') as f:
|
||||
instructions = [line.strip() for line in f if line.strip() and not line.startswith('#')]
|
||||
|
||||
print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
|
||||
print(f"📂 输出目录: {output_dir}\n")
|
||||
|
||||
client = APIClient()
|
||||
detailed_results = []
|
||||
summary_stats = {}
|
||||
|
||||
for i, prompt in enumerate(instructions, 1):
|
||||
print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")
|
||||
safe_name = sanitize_filename(prompt)
|
||||
instr_out_dir = os.path.join(output_dir, safe_name)
|
||||
os.makedirs(instr_out_dir, exist_ok=True)
|
||||
|
||||
success_count = 0
|
||||
total_latency = 0
|
||||
|
||||
for k in range(1, iterations + 1):
|
||||
print(f" - 第 {k} 次...", end="", flush=True)
|
||||
result = client.send_request(prompt)
|
||||
|
||||
# 记录详情
|
||||
detailed_results.append({
|
||||
"instruction": prompt,
|
||||
"run_id": k,
|
||||
"success": result['success'],
|
||||
"latency": result['latency'],
|
||||
"error": result['error'] or ""
|
||||
})
|
||||
|
||||
# 保存产物
|
||||
if result['success']:
|
||||
print(f" ✅ ({result['latency']:.2f}s)")
|
||||
success_count += 1
|
||||
total_latency += result['latency']
|
||||
|
||||
# 保存JSON
|
||||
with open(os.path.join(instr_out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
|
||||
json.dump(result['data'], f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 保存图片
|
||||
if result['data'] and 'root' in result['data']:
|
||||
generate_visualization(result['data']['root'], os.path.join(instr_out_dir, f"{k}.png"))
|
||||
else:
|
||||
print(f" ❌ {result['error']}")
|
||||
|
||||
time.sleep(0.5) # 避免过快请求
|
||||
|
||||
# 统计单条指令
|
||||
avg_lat = total_latency / success_count if success_count > 0 else 0
|
||||
summary_stats[prompt] = {
|
||||
"total_runs": iterations,
|
||||
"success_runs": success_count,
|
||||
"success_rate": f"{(success_count/iterations)*100:.1f}%",
|
||||
"avg_latency": f"{avg_lat:.2f}s"
|
||||
}
|
||||
|
||||
# 4. 生成报告
|
||||
# 详细报告
|
||||
with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
|
||||
writer.writeheader()
|
||||
writer.writerows(detailed_results)
|
||||
|
||||
# 摘要报告
|
||||
with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
|
||||
for prompt, stats in summary_stats.items():
|
||||
writer.writerow([
|
||||
prompt,
|
||||
stats["total_runs"],
|
||||
stats["success_runs"],
|
||||
stats["success_rate"],
|
||||
stats["avg_latency"]
|
||||
])
|
||||
|
||||
print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")
|
||||
|
||||
37
tools/test_validate/modules/drone_uploader.py
Normal file
37
tools/test_validate/modules/drone_uploader.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import requests
|
||||
import os
|
||||
import sys
|
||||
|
||||
def upload_mission(drone_ip, file_path):
|
||||
"""上传一个JSON任务文件到无人机"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Error: File not found at {file_path}")
|
||||
return
|
||||
|
||||
url = f"http://{drone_ip}:5000/missions"
|
||||
print(f"正在上传 {file_path} 到 {url} ...")
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
files = {'file': (os.path.basename(file_path), f, 'application/json')}
|
||||
response = requests.post(url, files=files, timeout=10)
|
||||
|
||||
# 检查HTTP响应状态码
|
||||
response.raise_for_status()
|
||||
|
||||
print("上传成功!")
|
||||
print("无人机端响应:", response.json())
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"上传过程中发生错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) < 3:
|
||||
print("用法: python ground_station_client.py [无人机IP地址] [JSON文件路径]")
|
||||
print("示例: python ground_station_client.py 192.168.1.10 ./missions/rescue_mission.json")
|
||||
sys.exit(1)
|
||||
|
||||
drone_ip_address = sys.argv[1]
|
||||
mission_file_path = sys.argv[2]
|
||||
|
||||
upload_mission(drone_ip_address, mission_file_path)
|
||||
87
tools/test_validate/modules/interactive_test.py
Normal file
87
tools/test_validate/modules/interactive_test.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from .api_client import APIClient
|
||||
from .visualizer import generate_visualization, sanitize_filename
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
|
||||
def run_interactive_test():
|
||||
client = APIClient()
|
||||
print("\n🚀 进入交互式测试模式 (输入 'exit' 或 'q' 退出)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input("\n请输入测试指令: ").strip()
|
||||
if prompt.lower() in ['exit', 'q']:
|
||||
break
|
||||
if not prompt:
|
||||
continue
|
||||
|
||||
print("⏳ 正在请求后端 API...")
|
||||
result = client.send_request(prompt)
|
||||
|
||||
if result['success']:
|
||||
print(f"✅ 请求成功 (耗时: {result['latency']:.2f}s)")
|
||||
|
||||
# 创建输出目录
|
||||
sanitized_name = sanitize_filename(prompt)
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"validation",
|
||||
"temporary",
|
||||
sanitized_name
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存 JSON
|
||||
json_path = os.path.join(output_dir, "response.json")
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result['data'], f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 保存日志
|
||||
log_path = os.path.join(output_dir, "process.log")
|
||||
with open(log_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"Prompt: {prompt}\n")
|
||||
f.write(f"Status: {result['http_status']}\n")
|
||||
f.write(f"Latency: {result['latency']}\n")
|
||||
f.write(f"Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
|
||||
|
||||
# 生成图片
|
||||
if result['data'] and 'root' in result['data']:
|
||||
png_path = os.path.join(output_dir, "plan.png")
|
||||
if generate_visualization(result['data']['root'], png_path):
|
||||
print(f"🖼️ 可视化图已生成: {png_path}")
|
||||
else:
|
||||
print("⚠️ 可视化生成失败")
|
||||
|
||||
print(f"📂 结果已保存至: {output_dir}")
|
||||
else:
|
||||
print(f"❌ 请求失败: {result['error']}")
|
||||
|
||||
# 即使失败也保存日志,以便排查
|
||||
sanitized_name = sanitize_filename(prompt)
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"validation",
|
||||
"temporary",
|
||||
sanitized_name
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
log_path = os.path.join(output_dir, "process.log")
|
||||
with open(log_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"Prompt: {prompt}\n")
|
||||
f.write(f"Status: {result['http_status']}\n")
|
||||
f.write(f"Latency: {result['latency']}\n")
|
||||
f.write(f"Error: {result['error']}\n")
|
||||
# 如果有部分数据,也记录下来
|
||||
if result['data']:
|
||||
f.write(f"Partial Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
|
||||
print(f"⚠️ 错误日志已保存至: {log_path}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n已取消")
|
||||
break
|
||||
|
||||
174
tools/test_validate/modules/llm_tester.py
Normal file
174
tools/test_validate/modules/llm_tester.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def build_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="调用本地 llama-server (OpenAI兼容) 进行推理,支持自定义系统/用户提示词"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
|
||||
help="llama-server 的基础URL(默认: http://127.0.0.1:8081/v1,或环境变量 SIMPLE_BASE_URL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=os.getenv("SIMPLE_MODEL", "local-model"),
|
||||
help="模型名称(默认: local-model,或环境变量 SIMPLE_MODEL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system",
|
||||
default="You are a helpful assistant.",
|
||||
help="系统提示词(system role)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system-file",
|
||||
default=None,
|
||||
help="系统提示词文件路径(txt);若提供,则覆盖 --system 的字符串",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
default=None,
|
||||
help="用户提示词(user role);若不传则从交互式输入读取",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="采样温度(默认: 0.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="最大生成Token数(默认: 4096)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=120.0,
|
||||
help="HTTP超时时间秒(默认: 120)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="打印完整返回JSON",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def call_llama_server(
|
||||
base_url: str,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
endpoint = base_url.rstrip("/") + "/chat/completions"
|
||||
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
||||
|
||||
# 兼容需要API Key的代理/服务(llama-server通常不强制)
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = build_args()
|
||||
|
||||
user_prompt = args.user
|
||||
if not user_prompt:
|
||||
try:
|
||||
user_prompt = input("请输入用户提示词: ")
|
||||
except KeyboardInterrupt:
|
||||
print("\n已取消。")
|
||||
sys.exit(1)
|
||||
|
||||
# 解析系统提示词:优先使用 --system-file
|
||||
system_prompt = args.system
|
||||
if args.system_file:
|
||||
try:
|
||||
with open(args.system_file, "r", encoding="utf-8") as f:
|
||||
system_prompt = f.read()
|
||||
except Exception as e:
|
||||
print("\n❌ 读取系统提示词文件失败:")
|
||||
print(str(e))
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
print("--- llama-server 推理 ---")
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Model: {args.model}")
|
||||
if args.system_file:
|
||||
print(f"System(from file): {args.system_file}")
|
||||
else:
|
||||
print(f"System: {system_prompt}")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
data = call_llama_server(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
timeout=args.timeout,
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
print("\n完整返回JSON:")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||||
|
||||
# 尝试按OpenAI兼容格式提取assistant内容
|
||||
content = None
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if content is not None:
|
||||
print("\n模型输出:")
|
||||
print(content)
|
||||
else:
|
||||
# 兜底打印
|
||||
print("\n无法按OpenAI兼容字段解析内容,原始返回如下:")
|
||||
print(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("\n❌ 请求失败:请确认 llama-server 已在 8081 端口启动并可访问。")
|
||||
print(f"详情: {e}")
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
print("\n❌ 发生未预期的错误:")
|
||||
print(str(e))
|
||||
sys.exit(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
287
tools/test_validate/modules/visualizer.py
Normal file
287
tools/test_validate/modules/visualizer.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
从API测试日志中提取JSON响应并批量可视化
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import platform
|
||||
import random
|
||||
import html
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def sanitize_filename(text: str) -> str:
|
||||
"""将文本转换为安全的文件名"""
|
||||
# 移除或替换不安全的字符
|
||||
text = re.sub(r'[<>:"/\\|?*]', '_', text)
|
||||
# 限制长度
|
||||
if len(text) > 100:
|
||||
text = text[:100]
|
||||
return text
|
||||
|
||||
def _pick_zh_font():
|
||||
"""选择合适的中文字体"""
|
||||
sys = platform.system()
|
||||
if sys == "Windows":
|
||||
return "Microsoft YaHei"
|
||||
elif sys == "Darwin":
|
||||
return "PingFang SC"
|
||||
else:
|
||||
return "Noto Sans CJK SC"
|
||||
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return ""
|
||||
|
||||
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
|
||||
|
||||
# 准备节点标签(HTML-like,正确换行与转义)
|
||||
name = html.escape(str(node.get('name', '')))
|
||||
ntype = html.escape(str(node.get('type', '')))
|
||||
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
|
||||
|
||||
# 格式化参数显示
|
||||
params = node.get('params') or {}
|
||||
if params:
|
||||
params_lines = []
|
||||
for key, value in params.items():
|
||||
k = html.escape(str(key))
|
||||
if isinstance(value, float):
|
||||
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
|
||||
else:
|
||||
value_str = str(value)
|
||||
v = html.escape(value_str)
|
||||
params_lines.append(f"{k}: {v}")
|
||||
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
|
||||
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
|
||||
|
||||
node_label = f"<{'<BR/>'.join(label_parts)}>"
|
||||
|
||||
# 根据类型设置节点样式和颜色
|
||||
node_type = (node.get('type') or '').lower()
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e6e6e6' # 默认灰色填充
|
||||
border_color = '#666666' # 默认描边色
|
||||
|
||||
if node_type == 'action':
|
||||
shape = 'box'
|
||||
style = 'rounded,filled'
|
||||
fillcolor = "#cde4ff" # 浅蓝
|
||||
elif node_type == 'condition':
|
||||
shape = 'diamond'
|
||||
style = 'filled'
|
||||
fillcolor = "#fff2cc" # 浅黄
|
||||
elif node_type == 'sequence':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#d5e8d4' # 绿色
|
||||
elif node_type == 'selector':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#ffe6cc' # 橙色
|
||||
elif node_type == 'parallel':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
|
||||
# 特别标记安全相关节点
|
||||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||||
border_color = '#ff0000' # 红色边框突出显示安全节点
|
||||
style = 'filled,bold' # 加粗
|
||||
|
||||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
return current_id
|
||||
|
||||
def generate_visualization(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return False
|
||||
|
||||
fontname = _pick_zh_font()
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
|
||||
dot.attr('edge', fontname=fontname)
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
base_path, ext = os.path.splitext(file_path)
|
||||
render_path = base_path if ext.lower() == '.png' else file_path
|
||||
|
||||
out_dir = os.path.dirname(render_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
dot.render(render_path, format='png', cleanup=True, view=False)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"❌ 生成可视化图形失败: {e}")
|
||||
return False
|
||||
|
||||
# 保留旧的函数以兼容(如果有其他脚本引用)
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
return generate_visualization(node, file_path)
|
||||
|
||||
def parse_log_file(log_file_path: str) -> Dict[str, List[Dict]]:
|
||||
"""
|
||||
解析日志文件,提取原始指令和完整API响应JSON
|
||||
返回: {原始指令: [JSON响应列表]}
|
||||
"""
|
||||
with open(log_file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 按分隔符分割条目
|
||||
entries = re.split(r'={80,}', content)
|
||||
|
||||
results = defaultdict(list)
|
||||
|
||||
for entry in entries:
|
||||
if not entry.strip():
|
||||
continue
|
||||
|
||||
# 提取原始指令
|
||||
instruction_match = re.search(r'原始指令:\s*(.+)', entry)
|
||||
if not instruction_match:
|
||||
continue
|
||||
|
||||
original_instruction = instruction_match.group(1).strip()
|
||||
|
||||
# 提取完整API响应JSON
|
||||
json_match = re.search(r'完整API响应:\s*\n(\{.*\})', entry, re.DOTALL)
|
||||
if not json_match:
|
||||
logging.warning(f"未找到指令 '{original_instruction}' 的JSON响应")
|
||||
continue
|
||||
|
||||
json_str = json_match.group(1).strip()
|
||||
|
||||
try:
|
||||
json_obj = json.loads(json_str)
|
||||
results[original_instruction].append(json_obj)
|
||||
logging.info(f"成功提取指令 '{original_instruction}' 的JSON响应")
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"解析指令 '{original_instruction}' 的JSON失败: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def process_and_visualize(log_file_path: str, output_dir: str):
|
||||
"""
|
||||
处理日志文件并批量可视化
|
||||
"""
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 解析日志文件
|
||||
logging.info(f"开始解析日志文件: {log_file_path}")
|
||||
instruction_responses = parse_log_file(log_file_path)
|
||||
|
||||
logging.info(f"共找到 {len(instruction_responses)} 个不同的原始指令")
|
||||
|
||||
# 处理每个指令的所有响应
|
||||
for instruction, responses in instruction_responses.items():
|
||||
logging.info(f"\n处理指令: {instruction} (共 {len(responses)} 个响应)")
|
||||
|
||||
# 创建指令目录(使用安全的文件名)
|
||||
safe_instruction_name = sanitize_filename(instruction)
|
||||
instruction_dir = os.path.join(output_dir, safe_instruction_name)
|
||||
os.makedirs(instruction_dir, exist_ok=True)
|
||||
|
||||
# 处理每个响应
|
||||
for idx, response in enumerate(responses, 1):
|
||||
try:
|
||||
# 提取root节点
|
||||
root_node = response.get('root')
|
||||
if not root_node:
|
||||
logging.warning(f"响应 #{idx} 没有root节点,跳过")
|
||||
continue
|
||||
|
||||
# 生成文件名
|
||||
json_filename = f"response_{idx}.json"
|
||||
png_filename = f"response_{idx}.png"
|
||||
|
||||
json_path = os.path.join(instruction_dir, json_filename)
|
||||
png_path = os.path.join(instruction_dir, png_filename)
|
||||
|
||||
# 保存JSON文件
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logging.info(f" 保存JSON: {json_filename}")
|
||||
|
||||
# 生成可视化
|
||||
generate_visualization(root_node, png_path)
|
||||
logging.info(f" 生成可视化: {png_filename}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理响应 #{idx} 时出错: {e}")
|
||||
continue
|
||||
|
||||
logging.info(f"\n✅ 所有处理完成!结果保存在: {output_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="批量可视化API测试日志")
|
||||
parser.add_argument("--log", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_test_log.txt"), help="日志文件路径")
|
||||
parser.add_argument("--out", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "validation"), help="输出目录")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
log_file = args.log
|
||||
output_directory = args.out
|
||||
|
||||
print(f"日志文件: {log_file}")
|
||||
print(f"输出目录: {output_directory}")
|
||||
|
||||
if os.path.exists(log_file):
|
||||
process_and_visualize(log_file, output_directory)
|
||||
else:
|
||||
print(f"错误: 找不到日志文件 {log_file}")
|
||||
118
tools/test_validate/run_tests.py
Executable file
118
tools/test_validate/run_tests.py
Executable file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import subprocess
|
||||
|
||||
# Add current directory to path so modules can be imported
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from modules.interactive_test import run_interactive_test
|
||||
from modules.batch_runner import run_batch_test
|
||||
|
||||
def clear_screen():
|
||||
os.system('cls' if os.name == 'nt' else 'clear')
|
||||
|
||||
def print_header():
|
||||
clear_screen()
|
||||
print("=" * 60)
|
||||
print(" Drone Planning 系统测试工具箱 (Unified)")
|
||||
print("=" * 60)
|
||||
print(f"当前工作目录: {os.getcwd()}")
|
||||
print("-" * 60)
|
||||
|
||||
def run_legacy_module(module_name, args=None):
|
||||
"""运行旧的独立 Python 脚本 (用于 LLM 测试和上传工具)"""
|
||||
script_path = os.path.join(os.path.dirname(__file__), "modules", module_name)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
print(f"❌ 错误: 找不到脚本 {script_path}")
|
||||
input("\n按回车键继续...")
|
||||
return
|
||||
|
||||
cmd = [sys.executable, script_path]
|
||||
if args:
|
||||
cmd.extend(args)
|
||||
|
||||
print(f"\n🚀 正在启动: {module_name} ...\n")
|
||||
try:
|
||||
# 保持当前环境变量
|
||||
env = os.environ.copy()
|
||||
# 确保 PYTHONPATH 包含当前目录
|
||||
env["PYTHONPATH"] = os.path.dirname(__file__) + os.pathsep + env.get("PYTHONPATH", "")
|
||||
subprocess.run(cmd, env=env, check=False)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ 操作已取消")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 运行出错: {e}")
|
||||
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
def menu_drone_upload():
|
||||
print("\n[3] 上传任务到无人机")
|
||||
print("说明: 将生成的任务文件上传到无人机 (Ground Station Client)。")
|
||||
|
||||
ip = input("\n请输入无人机 IP 地址 (默认: 127.0.0.1): ").strip() or "127.0.0.1"
|
||||
file_path = input("请输入任务文件路径 (.json): ").strip()
|
||||
|
||||
if not file_path:
|
||||
print("❌ 必须提供文件路径!")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
run_legacy_module("drone_uploader.py", [ip, file_path])
|
||||
|
||||
def main():
|
||||
# 切换到脚本所在目录
|
||||
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
while True:
|
||||
print_header()
|
||||
print("请选择测试模式:")
|
||||
print("1. 交互式单次测试 (Interactive Single Test)")
|
||||
print(" - 手动输入指令,即时查看结果和可视化")
|
||||
print("2. 批量/场景测试 (Batch/Scenario Test)")
|
||||
print(" - 读取指令文件,支持多轮压力测试,生成统计报告")
|
||||
print("3. 上传任务到无人机 (Drone Uploader)")
|
||||
print("4. LLM 服务连通性测试 (LLM Tester)")
|
||||
print("0. 退出")
|
||||
print("-" * 60)
|
||||
|
||||
choice = input("请输入选项 [0-4]: ").strip()
|
||||
|
||||
if choice == '1':
|
||||
try:
|
||||
run_interactive_test()
|
||||
except Exception as e:
|
||||
print(f"❌ 运行出错: {e}")
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
elif choice == '2':
|
||||
try:
|
||||
run_batch_test()
|
||||
except Exception as e:
|
||||
print(f"❌ 运行出错: {e}")
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
elif choice == '3':
|
||||
menu_drone_upload()
|
||||
|
||||
elif choice == '4':
|
||||
run_legacy_module("llm_tester.py")
|
||||
|
||||
elif choice == '0':
|
||||
print("\n👋 再见!")
|
||||
break
|
||||
else:
|
||||
print("\n❌ 无效选项,请重试")
|
||||
time.sleep(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user