Compare commits
10 Commits
main
...
d32520d83f
| Author | SHA1 | Date | |
|---|---|---|---|
| d32520d83f | |||
| afd170c451 | |||
| fd89745950 | |||
| 8e333ac03f | |||
| 7b9d05b306 | |||
| 781b490cdc | |||
| ce963ed7d6 | |||
| 9703f7cc10 | |||
| 7bf8210b80 | |||
| 3adf3985cb |
91
README.md
91
README.md
@@ -13,8 +13,13 @@
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── main.py # 应用主入口,提供Web API
|
||||
│ │ ├── py_tree_generator.py # RAG与LLM集成,生成py_tree
|
||||
│ │ ├── prompts/ # LLM 提示词
|
||||
│ │ │ ├── system_prompt.txt # 复杂模式提示词(行为树与安全监控)
|
||||
│ │ │ ├── simple_mode_prompt.txt # 简单模式提示词(单一原子动作JSON)
|
||||
│ │ │ └── classifier_prompt.txt # 指令简单/复杂分类提示词
|
||||
│ │ ├── ...
|
||||
│ ├── generated_visualizations/ # 存放最新生成的py_tree可视化图像
|
||||
│ ├── generated_reasoning_content/ # 存放最新推理链Markdown(<plan_id>.md)
|
||||
│ └── requirements.txt # 后端服务的Python依赖
|
||||
│
|
||||
├── tools/
|
||||
@@ -22,7 +27,8 @@
|
||||
│ ├── knowledge_base/ # 【处理后】存放build_knowledge_base.py生成的.ndjson文件
|
||||
│ ├── vector_store/ # 【数据库】存放最终的ChromaDB向量数据库
|
||||
│ ├── build_knowledge_base.py # 【步骤1】用于将原始数据转换为自然语言知识
|
||||
│ └── ingest.py # 【步骤2】用于将自然语言知识摄入向量数据库
|
||||
│ ├── ingest.py # 【步骤2】用于将自然语言知识摄入向量数据库
|
||||
│ └── test_llama_server.py # 直接调用本地8081端口llama-server,支持 --system / --system-file
|
||||
│
|
||||
├── / # ROS2接口定义 (保持不变)
|
||||
└── docs/
|
||||
@@ -67,6 +73,78 @@
|
||||
|
||||
---
|
||||
|
||||
## 指令分类与分流
|
||||
|
||||
后端在生成任务前会先对用户指令进行“简单/复杂”分类,并分流到不同提示词与模型:
|
||||
|
||||
- 分类提示词:`backend_service/src/prompts/classifier_prompt.txt`
|
||||
- 简单模式提示词:`backend_service/src/prompts/simple_mode_prompt.txt`
|
||||
- 复杂模式提示词:`backend_service/src/prompts/system_prompt.txt`
|
||||
|
||||
分类仅输出如下JSON之一:`{"mode":"simple"}` 或 `{"mode":"complex"}`。两种模式都会执行检索增强(RAG),将参考知识拼接到用户指令后再进行推理。
|
||||
|
||||
当为简单模式时,LLM仅输出:
|
||||
`{"mode":"simple","action":{"name":"<action>","params":{...}}}`。
|
||||
后端不会再自动封装为复杂行为树;将直接返回简单JSON,并附加 `plan_id` 与 `visualization_url`(单动作可视化)。
|
||||
|
||||
### 环境变量(可选)
|
||||
|
||||
支持为“分类/简单/复杂”三类调用分别配置模型与Base URL(未设置时回退到默认本地配置):
|
||||
|
||||
- `CLASSIFIER_MODEL`, `CLASSIFIER_BASE_URL`
|
||||
- `SIMPLE_MODEL`, `SIMPLE_BASE_URL`
|
||||
- `COMPLEX_MODEL`, `COMPLEX_BASE_URL`
|
||||
|
||||
通用API Key:`OPENAI_API_KEY`
|
||||
|
||||
推理链捕获相关:
|
||||
- `ENABLE_REASONING_CAPTURE`:是否允许模型返回含有 <think> 的原文以便捕获推理链;默认 true。
|
||||
- `REASONING_PREVIEW_LINES`:在后端日志中打印推理链预览的行数;默认 20。
|
||||
|
||||
示例:
|
||||
```bash
|
||||
export CLASSIFIER_MODEL="qwen2.5-1.8b-instruct"
|
||||
export SIMPLE_MODEL="qwen2.5-1.8b-instruct"
|
||||
export COMPLEX_MODEL="qwen2.5-7b-instruct"
|
||||
export CLASSIFIER_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||
export SIMPLE_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||
export COMPLEX_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||
export OPENAI_API_KEY="sk-no-key-required"
|
||||
|
||||
# 推理链捕获(可选)
|
||||
export ENABLE_REASONING_CAPTURE=true # 默认已为true;如需关闭,设置为 false
|
||||
export REASONING_PREVIEW_LINES=30 # 调整日志预览行数
|
||||
```
|
||||
|
||||
### 测试简单模式
|
||||
|
||||
启动服务后,运行内置测试脚本:
|
||||
|
||||
```bash
|
||||
cd tools
|
||||
python test_api.py
|
||||
```
|
||||
|
||||
示例输入:“简单模式,起飞” 或 “起飞到10米”。返回结果为简单JSON(无 `root`):包含 `mode`、`action`、`plan_id`、`visualization_url`。
|
||||
|
||||
### 直接调用 llama-server(绕过后端)
|
||||
|
||||
当仅需测试本地 8081 端口的推理服务(OpenAI 兼容接口)时,可使用内置脚本:
|
||||
|
||||
```bash
|
||||
python tools/test_llama_server.py \
|
||||
--system-file backend_service/src/prompts/system_prompt.txt \
|
||||
--user "起飞到10米然后降落" \
|
||||
--base-url "http://127.0.0.1:8081/v1" \
|
||||
--verbose
|
||||
```
|
||||
|
||||
说明:
|
||||
- 支持 `--system` 或 `--system-file` 自定义提示词文件;`--system-file` 优先。
|
||||
- 默认解析 OpenAI 风格返回,若包含 `<think>` 推理内容会显示在输出中(具体取决于模型和服务配置)。
|
||||
|
||||
---
|
||||
|
||||
## 工作流程
|
||||
|
||||
整个系统的工作流程分为两个主要阶段:
|
||||
@@ -235,7 +313,7 @@ python test_api.py
|
||||
"user_prompt": "无人机起飞到10米,然后前往机库,最后降落。"
|
||||
}
|
||||
```
|
||||
- **Success Response**:
|
||||
- **Success Response(复杂模式)**:
|
||||
```json
|
||||
{
|
||||
"root": { ... },
|
||||
@@ -243,6 +321,15 @@ python test_api.py
|
||||
"visualization_url": "/static/py_tree.png"
|
||||
}
|
||||
```
|
||||
- **Success Response(简单模式)**:
|
||||
```json
|
||||
{
|
||||
"mode": "simple",
|
||||
"action": { "name": "takeoff", "params": { "altitude": 10.0 } },
|
||||
"plan_id": "some-unique-id",
|
||||
"visualization_url": "/static/py_tree.png"
|
||||
}
|
||||
```
|
||||
|
||||
##### **B. 查看任务可视化**
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
好的,我现在需要处理用户的任务指令,生成一个符合要求的行为树JSON。首先,我得仔细阅读用户的指令和参考知识,确保理解所有要求。
|
||||
|
||||
用户的目标是搜索并锁定危险性最高的气球(红色>蓝色>绿色),然后跟踪30秒后进行打击。参考知识中提到了三个地点,但用户可能希望搜索的区域是这些地点附近,或者需要根据这些坐标来规划路径。
|
||||
|
||||
首先,我需要确定无人机的起降点。参考知识中的地点坐标可能作为搜索区域的中心。比如,用户提到的“学生宿舍”坐标是(5,3,2),但可能需要选择一个合适的中心点进行搜索。不过用户没有明确指定搜索区域,可能需要使用search_pattern来覆盖这些区域,或者直接在某个中心点周围搜索。
|
||||
|
||||
接下来,任务流程应该是:起飞→飞往搜索区域→搜索目标→检测到目标后跟踪→打击。同时必须包含安全监控。
|
||||
|
||||
根据用户提供的参考知识,可能需要将搜索区域设置为某个中心点。比如,用户提到的“跷跷板”在(x:15, y:-8.5, z:1.2),但z坐标需要调整到至少1米,这里已经是1.2,没问题。或者可能选择其他地点作为搜索中心。但用户没有明确说明,可能需要假设搜索区域是这些地点附近,或者使用一个综合的中心点。
|
||||
|
||||
不过用户可能希望无人机先飞往某个特定的搜索区。比如,参考知识中的“学生宿舍”可能是一个可能的区域,但需要确认。或者用户可能希望无人机在某个中心点周围进行搜索。例如,使用search_pattern的中心点可能选在某个已知地点,比如“跷跷板”附近,或者综合多个点。
|
||||
|
||||
但用户没有明确指定搜索区域,所以可能需要使用search_pattern的中心点为某个已知地点,比如“跷跷板”的坐标,或者选择一个合理的中心点。例如,假设搜索区域是“跷跷板”所在的位置,那么中心坐标为(15, -8.5, 1.2)。或者可能需要将搜索区域设置为多个地点的组合,但用户没有说明,所以可能需要选择一个中心点。
|
||||
|
||||
接下来,考虑使用search_pattern来搜索,因为目标位置未知。参数中需要指定pattern_type,比如spiral或grid。假设选择spiral模式,覆盖更大的区域。中心点可能选在某个已知地点,比如“跷跷板”的坐标,或者综合多个点。但用户没有明确,所以可能需要选择一个合理的中心点,比如“跷跷板”的坐标。
|
||||
|
||||
然后,检测到目标后,需要跟踪30秒,再打击。所以流程是:起飞→飞往搜索区→搜索→检测→跟踪→打击→降落。
|
||||
|
||||
安全监控必须包含在Parallel节点中,监控电池和GPS,并在紧急情况下返航。
|
||||
|
||||
现在,构建行为树的结构:
|
||||
|
||||
根节点是Parallel,包含主任务和安全监控。主任务是Sequence,包含起飞、飞往搜索区、搜索、检测、跟踪、打击、降落。安全监控是Selector,包含电池和GPS条件,以及紧急处理Sequence。
|
||||
|
||||
在搜索部分,使用search_pattern,参数包括pattern_type为spiral,中心点选在某个已知地点,比如“跷跷板”的坐标(15, -8.5, 1.2),半径可能设为50米,覆盖较大区域。目标类别是balloon,描述为“红色气球危险性高于蓝色气球高于绿色气球”,但参数中description需要描述目标属性,比如“红>蓝>绿”。
|
||||
|
||||
检测到目标后,使用track_object,参数target_class为balloon,description为“红>蓝>绿”,track_time为30秒。然后strike_target,最后评估战损。
|
||||
|
||||
需要确保所有参数符合约束,比如z坐标在1-5000米之间,半径在5-1000米之间,速度等参数在合理范围。
|
||||
|
||||
现在,将各个步骤的参数填入JSON中。例如,起飞到2米,飞往搜索区的中心点可能选在“跷跷板”的坐标(15, -8.5, 1.2),但z需要调整到至少1米,这里已经是1.2。所以fly_to_waypoint的参数x=15,y=-8.5,z=1.2,接受半径2米。
|
||||
|
||||
搜索使用search_pattern,pattern_type为spiral,中心点为(15, -8.5, 1.2),半径设为50米,覆盖区域。目标类别是balloon,description为“红>蓝>绿”。
|
||||
|
||||
检测到目标后,使用track_object,参数target_class为balloon,description为“红>蓝>绿”,track_time为30秒,min_confidence设为0.7,safe_distance设为15米。
|
||||
|
||||
然后strike_target,参数target_class为balloon,description为“红>蓝>绿”,count为1。
|
||||
|
||||
最后,降落。
|
||||
|
||||
安全监控部分,电池阈值设为0.35,GPS的min_satellites设为8,紧急处理Sequence包括emergency_return和land。
|
||||
|
||||
现在,将所有步骤的参数填入JSON中,确保符合所有规则。
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 200 KiB After Width: | Height: | Size: 233 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
24
backend_service/src/prompts/classifier_prompt.txt
Normal file
24
backend_service/src/prompts/classifier_prompt.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
你是一个严格的任务分类器。只输出一个JSON对象,不要输出解释或多余文本。
|
||||
根据用户指令与下述可用节点定义,判断其为“简单”或“复杂”。
|
||||
|
||||
- 简单:单一原子动作即可完成(例如“起飞”“飞机自检”“移动到某地(已给定坐标)”“对着某点环绕XY圈(如‘对着学生宿舍环绕三十两圈’)”等),且无需行为树与安全并行监控。
|
||||
- 复杂:需要多步流程、搜索/检测/跟踪/评估、战损确认、或需要模板化任务结构与安全并行监控。
|
||||
|
||||
输出格式(严格遵守):
|
||||
{"mode":"simple"} 或 {"mode":"complex"}
|
||||
|
||||
—— 可用节点定义——
|
||||
```json
|
||||
{
|
||||
"actions": [
|
||||
{"name": "takeoff"}, {"name": "land"}, {"name": "fly_to_waypoint"}, {"name": "move_direction"}, {"name": "orbit_around_point"}, {"name": "orbit_around_target"}, {"name": "loiter"},
|
||||
{"name": "object_detect"}, {"name": "strike_target"}, {"name": "battle_damage_assessment"},
|
||||
{"name": "search_pattern"}, {"name": "track_object"}, {"name": "deliver_payload"},
|
||||
{"name": "preflight_checks"}, {"name": "emergency_return"}
|
||||
],
|
||||
"conditions": [
|
||||
{"name": "battery_above"}, {"name": "at_waypoint"}, {"name": "object_detected"},
|
||||
{"name": "target_destroyed"}, {"name": "time_elapsed"}, {"name": "gps_status"}
|
||||
]
|
||||
}
|
||||
```
|
||||
61
backend_service/src/prompts/simple_mode_prompt.txt
Normal file
61
backend_service/src/prompts/simple_mode_prompt.txt
Normal file
@@ -0,0 +1,61 @@
|
||||
你是一个无人机简单指令执行规划器。你的任务:输出一个严格的JSON对象。
|
||||
|
||||
输出要求(必须遵守):
|
||||
- 只输出一个JSON对象,不要任何解释或多余文本。
|
||||
- JSON结构:
|
||||
{"mode":"simple","action":{"name":"<action_name>","params":{...}}}
|
||||
- 不包含任何行为树结构与安全监控并行,仅输出单一原子动作。
|
||||
|
||||
示例:
|
||||
- “起飞到10米” → {"mode":"simple","action":{"name":"takeoff","params":{"altitude":10.0}}}
|
||||
- “移动到(120,80,20)” → {"mode":"simple","action":{"name":"fly_to_waypoint","params":{"x":120.0,"y":80.0,"z":20.0,"acceptance_radius":2.0}}}
|
||||
- “飞机自检” → {"mode":"simple","action":{"name":"preflight_checks","params":{"check_level":"comprehensive"}}}
|
||||
|
||||
—— 可用节点定义——
|
||||
```json
|
||||
{
|
||||
"actions": [
|
||||
{"name": "takeoff", "description": "无人机从当前位置垂直起飞到指定的海拔高度。", "params": {"altitude": "float, 目标海拔高度(米),范围[1, 100],默认为2"}},
|
||||
{"name": "land", "description": "降落无人机。可选择当前位置或返航点降落。", "params": {"mode": "string, 可选值: 'current'(当前位置), 'home'(返航点)"}},
|
||||
{"name": "fly_to_waypoint", "description": "导航至一个指定坐标点。使用相对坐标系(x,y,z),单位为米。", "params": {"x": "float", "y": "float", "z": "float", "acceptance_radius": "float, 可选,默认2.0"}},
|
||||
{"name": "move_direction", "description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。", "params": {"direction": "string: north|south|east|west|forward|backward|left|right", "distance": "float[1,10000], 可选, 不指定则持续移动"}},
|
||||
{"name": "orbit_around_point", "description": "以给定中心点为中心,等速圆周飞行指定圈数。", "params": {"center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
|
||||
{"name": "orbit_around_target", "description": "以目标为中心,等速圆周飞行指定圈数(需已有目标)。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
|
||||
{"name": "loiter", "description": "在当前位置上空悬停一段时间或直到条件触发。", "params": {"duration": "float, 可选[1,600]", "until_condition": "string, 可选"}},
|
||||
{"name": "object_detect", "description": "识别特定目标对象。一般是用户提到的需要检测的目标;如果用户给出了需要探索的目标的优先级,比如蓝色球危险性大于红色球大于绿色球,需要检测最危险的球,此处应给出检测优先级,描述应当为 '蓝>红>绿'", "params": {"target_class": "string, 要识别的目标类别,必须为以下值之一: balloon,person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||
{"name": "strike_target", "description": "对已识别目标进行打击。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||
{"name": "battle_damage_assessment", "description": "战损评估。", "params": {"target_class": "string", "assessment_time": "float[5-60], 默认15.0"}},
|
||||
{"name": "search_pattern", "description": "按模式搜索。", "params": {"pattern_type": "string: spiral|grid", "center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||
{"name": "track_object", "description": "持续跟踪目标。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "track_time": "float[1,600], 默认30.0", "min_confidence": "float[0.5-1.0], 默认0.7", "safe_distance": "float[2-50], 默认10.0"}},
|
||||
{"name": "deliver_payload", "description": "投放物资。", "params": {"payload_type": "string", "release_altitude": "float[2,100], 默认5.0"}},
|
||||
{"name": "preflight_checks", "description": "飞行前系统自检。", "params": {"check_level": "string: basic|comprehensive"}},
|
||||
{"name": "emergency_return", "description": "执行紧急返航程序。", "params": {"reason": "string"}}
|
||||
],
|
||||
"conditions": [
|
||||
{"name": "battery_above", "description": "电池电量高于阈值。", "params": {"threshold": "float[0.0,1.0]"}},
|
||||
{"name": "at_waypoint", "description": "在指定坐标容差范围内。", "params": {"x": "float", "y": "float", "z": "float", "tolerance": "float, 可选, 默认3.0"}},
|
||||
{"name": "object_detected", "description": "检测到特定目标。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||
{"name": "target_destroyed", "description": "目标已被摧毁。", "params": {"target_class": "string", "description": "string, 可选", "confidence": "float[0.5-1.0], 默认0.8"}},
|
||||
{"name": "time_elapsed", "description": "时间经过。", "params": {"duration": "float[1,2700]"}},
|
||||
{"name": "gps_status", "description": "GPS状态良好。", "params": {"min_satellites": "int[6,15], 默认10"}}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
—— 参数约束——
|
||||
- takeoff.altitude: [1, 100]
|
||||
- fly_to_waypoint.z: [1, 5000]
|
||||
- fly_to_waypoint.x,y: [-10000, 10000]
|
||||
- search_pattern.radius: [5, 1000]
|
||||
- move_direction.distance: [1, 10000]
|
||||
- orbit_around_point.radius: [5, 1000]
|
||||
- orbit_around_target.radius: [5, 1000]
|
||||
- orbit_around_point/target.laps: [1, 20]
|
||||
- orbit_around_point/target.speed_mps: [0.5, 15]
|
||||
- 若参考知识提供坐标,必须使用并裁剪到约束范围内
|
||||
|
||||
—— 口令转化规则(环绕类)——
|
||||
- “环绕X米Y圈” → 若有目标上下文则使用 `orbit_around_target`,否则根据是否给出中心坐标选择 `orbit_around_point`;`radius=X`,`laps=Y`,默认 `clockwise=true`,`gimbal_lock=true`
|
||||
- “顺时针/逆时针” → `clockwise=true/false`
|
||||
- “等速” → 若未给速度则 `speed_mps` 采用默认值(例如3.0);若口令指明速度,裁剪到[0.5,15]
|
||||
- “以(x,y,z)为中心”/“当前位置为中心” → 选择 `orbit_around_point` 并填充 `center_x/center_y/center_z`
|
||||
@@ -38,6 +38,41 @@
|
||||
"acceptance_radius": "float, 可选,到达容差半径(米),默认2.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "move_direction",
|
||||
"description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。",
|
||||
"params": {
|
||||
"direction": "string, 取值: 'north'(北), 'south'(南), 'east'(东), 'west'(西), 'forward'(前), 'backward'(后), 'left'(左), 'right'(右)",
|
||||
"distance": "float, 可选,移动距离(米)[1,10000];缺省表示持续移动,直至外部条件停止"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "orbit_around_point",
|
||||
"description": "以给定中心点为中心,等速圆周飞行指定圈数。",
|
||||
"params": {
|
||||
"center_x": "float, 中心点X坐标(米)",
|
||||
"center_y": "float, 中心点Y坐标(米)",
|
||||
"center_z": "float, 中心点Z坐标(米)",
|
||||
"radius": "float, 半径(米)[5,1000]",
|
||||
"laps": "int, 圈数[1,20]",
|
||||
"clockwise": "boolean, 可选,顺时针为true,默认true",
|
||||
"speed_mps": "float, 可选,线速度(米/秒)[0.5,15]",
|
||||
"gimbal_lock": "boolean, 可选,云台持续指向中心,默认true"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "orbit_around_target",
|
||||
"description": "以目标为中心,等速圆周飞行指定圈数。需要已确认目标。",
|
||||
"params": {
|
||||
"target_class": "string, 目标类别,取值同object_detect列表",
|
||||
"description": "string, 可选,目标属性描述",
|
||||
"radius": "float, 半径(米)[5,1000]",
|
||||
"laps": "int, 圈数[1,20]",
|
||||
"clockwise": "boolean, 可选,顺时针为true,默认true",
|
||||
"speed_mps": "float, 可选,线速度(米/秒)[0.5,15]",
|
||||
"gimbal_lock": "boolean, 可选,云台持续指向目标,默认true"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "loiter",
|
||||
"description": "在当前位置上空悬停一段时间或直到条件触发。",
|
||||
@@ -50,8 +85,8 @@
|
||||
"name": "object_detect",
|
||||
"description": "在当前视野范围内识别特定目标对象。适用于定点检测,无人机应在目标大致位置悬停或保持稳定姿态。",
|
||||
"params": {
|
||||
"target_class": "string, 要识别的目标类别,必须为以下值之一: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush",
|
||||
"description": "string, 可选,目标属性描述(如颜色、状态等)",
|
||||
"target_class": "string, 要识别的目标类别,必须为以下值之一: balloon,person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush",
|
||||
"description": "string, 可选,目标属性描述(如颜色、状态等),一般是用户提到的需要检测的目标;如果用户给出了需要探索的目标的优先级,比如蓝色球危险性大于红色球大于绿色球,需要检测最危险的球,此处应给出检测优先级,描述应当为 '蓝>红>绿'",
|
||||
"count": "int, 可选,需要检测的目标个数,默认1"
|
||||
}
|
||||
},
|
||||
@@ -323,6 +358,10 @@
|
||||
- `z` (fly_to_waypoint): [1, 5000]
|
||||
- `x`, `y` (fly_to_waypoint): [-10000, 10000]
|
||||
- `radius` (search_pattern): [5, 1000]
|
||||
- `distance` (move_direction): [1, 10000]
|
||||
- `radius` (orbit_around_point/orbit_around_target): [5, 1000]
|
||||
- `laps` (orbit_around_point/orbit_around_target): [1, 20]
|
||||
- `speed_mps` (orbit_around_point/orbit_around_target): [0.5, 15]
|
||||
- 电池阈值: [0.0, 1.0]
|
||||
- 等等其他参数范围。
|
||||
|
||||
@@ -531,7 +570,7 @@
|
||||
},
|
||||
{
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitorSelector",
|
||||
"name": "SafetyMonitor",
|
||||
"params": {
|
||||
"memory": true
|
||||
},
|
||||
@@ -552,7 +591,7 @@
|
||||
},
|
||||
{
|
||||
"type": "Sequence",
|
||||
"name": "EmergencyProcedureSequence",
|
||||
"name": "EmergencyHandler",
|
||||
"children": [
|
||||
{
|
||||
"type": "action",
|
||||
@@ -711,5 +750,12 @@
|
||||
- 如果参考知识说"目标坐标: (x: 120.5, y: 80.2, z: 60.0)",则在使用`fly_to_waypoint`时设置这些值。
|
||||
- 确保坐标符合约束(如z≥1)。
|
||||
|
||||
环绕口令到参数的映射规则(当口令涉及“环绕/绕圈”等):
|
||||
- “环绕XY圈” → `radius=X`, `laps=Y`,默认 `clockwise=true`, `gimbal_lock=true`,比如环绕三十两圈,意思就是以目标点为圆心,30米为半径绕2圈
|
||||
- 明确“顺时针/逆时针”时 → 设置 `clockwise=true/false`
|
||||
- 出现“等速”时 → 若未给速度则 `speed_mps` 使用默认值(如3.0);若口令给出速度,裁剪到[0.5,15]
|
||||
- “以(中心坐标)为中心/当前位置为中心” → 使用 `orbit_around_point` 并填写 `center_x/center_y/center_z`
|
||||
- “以目标为中心/围绕目标” → 使用 `orbit_around_target`;若任务未提供目标来源,则需要在主任务中先行确认目标(通过检测/跟踪或参考知识)
|
||||
|
||||
#### 11. 输出要求
|
||||
你的输出必须是严格的、单一的JSON对象,符合上述所有规则。不包含任何自然语言描述。
|
||||
@@ -201,7 +201,7 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
|
||||
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
|
||||
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush"
|
||||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon"
|
||||
]
|
||||
|
||||
# 递归节点定义
|
||||
@@ -335,6 +335,32 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
||||
|
||||
return schema
|
||||
|
||||
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
|
||||
"""
|
||||
生成简单模式JSON Schema:{"mode":"simple","action":{...}}
|
||||
仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。
|
||||
"""
|
||||
schema = {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "SimpleMode",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {"type": "string", "const": "simple"},
|
||||
"action": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "enum": sorted(list(allowed_actions))},
|
||||
"params": {"type": "object"}
|
||||
},
|
||||
"required": ["name"],
|
||||
"additionalProperties": True
|
||||
}
|
||||
},
|
||||
"required": ["mode", "action"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
return schema
|
||||
|
||||
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
"""
|
||||
使用JSON Schema验证给定的Pytree实例。
|
||||
@@ -522,13 +548,41 @@ class PyTreeGenerator:
|
||||
# Updated output directory for visualizations
|
||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||
os.makedirs(self.vis_dir, exist_ok=True)
|
||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
||||
# Reasoning content output directory (Markdown files)
|
||||
self.reasoning_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_reasoning_content'))
|
||||
os.makedirs(self.reasoning_dir, exist_ok=True)
|
||||
# 控制是否允许模型返回含 <think> 的原文(不强制JSON),以便提取推理链
|
||||
self.enable_reasoning_capture = os.getenv("ENABLE_REASONING_CAPTURE", "true").lower() in ("1", "true", "yes")
|
||||
# 终端预览的最大行数
|
||||
try:
|
||||
self.reasoning_preview_lines = int(os.getenv("REASONING_PREVIEW_LINES", "20"))
|
||||
except Exception:
|
||||
self.reasoning_preview_lines = 20
|
||||
# 加载提示词:复杂模式复用现有 system_prompt.txt;简单模式与分类器独立提示词
|
||||
self.complex_prompt = self._load_prompt("system_prompt.txt")
|
||||
self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
|
||||
self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
|
||||
# 兼容旧变量名
|
||||
self.system_prompt = self.complex_prompt
|
||||
|
||||
self.orin_ip = os.getenv("ORIN_IP", "localhost")
|
||||
self.llm_client = openai.OpenAI(
|
||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
||||
)
|
||||
# 三类模型的可配置项:基于不同模型与Base URL分流
|
||||
self.classifier_model = os.getenv("CLASSIFIER_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||
self.simple_model = os.getenv("SIMPLE_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||
self.complex_model = os.getenv("COMPLEX_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||
self.classifier_base_url = os.getenv("CLASSIFIER_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||
self.simple_base_url = os.getenv("SIMPLE_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||
self.complex_base_url = os.getenv("COMPLEX_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||
self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
|
||||
# 直接在代码中指定最大输出token数(不通过环境变量)
|
||||
self.classifier_max_tokens = 512
|
||||
self.simple_max_tokens = 8192
|
||||
self.complex_max_tokens = 8192
|
||||
|
||||
# 为不同用途分别创建客户端
|
||||
self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
|
||||
self.simple_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.simple_base_url)
|
||||
self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
@@ -542,8 +596,10 @@ class PyTreeGenerator:
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
||||
# 使用复杂模式提示词作为节点来源,确保Schema稳定
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
|
||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||
self.simple_schema = _generate_simple_mode_schema(allowed_actions)
|
||||
|
||||
def _load_prompt(self, file_name: str) -> str:
|
||||
try:
|
||||
@@ -574,14 +630,40 @@ class PyTreeGenerator:
|
||||
"""
|
||||
logging.info(f"接收到用户请求: {user_prompt}")
|
||||
|
||||
retrieved_context = self._retrieve_context(user_prompt)
|
||||
|
||||
# 第一步:分类(简单/复杂)
|
||||
mode = "complex"
|
||||
try:
|
||||
classifier_resp = self.classifier_client.chat.completions.create(
|
||||
model=self.classifier_model,
|
||||
messages=[
|
||||
{"role": "system", "content": self.classifier_prompt or "你是一个分类器,只输出JSON。"},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.0,
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=self.classifier_max_tokens
|
||||
)
|
||||
class_str = classifier_resp.choices[0].message.content
|
||||
class_obj = json.loads(class_str)
|
||||
if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"):
|
||||
mode = class_obj.get("mode")
|
||||
logging.info(f"分类结果: {mode}")
|
||||
except Exception as e:
|
||||
logging.warning(f"分类失败,默认按复杂指令处理: {e}")
|
||||
|
||||
# 第二步:根据模式准备提示词与上下文(简单与复杂都执行检索增强)
|
||||
# 基于模式选择提示词;复杂模式追加一条强制规则,避免模型误输出简单结构
|
||||
use_prompt = self.simple_prompt if mode == "simple" else (
|
||||
(self.complex_prompt or "") +
|
||||
"\n\n【强制规则】仅生成包含root的复杂行为树JSON,不得输出简单模式(不得包含mode字段或仅有action节点)。"
|
||||
)
|
||||
final_user_prompt = user_prompt
|
||||
retrieved_context = self._retrieve_context(user_prompt)
|
||||
if retrieved_context:
|
||||
augmentation = (
|
||||
"\n\n---\n"
|
||||
"参考知识:\n"
|
||||
"以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成行为树:\n"
|
||||
"以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n"
|
||||
f"{retrieved_context}"
|
||||
"\n---"
|
||||
)
|
||||
@@ -591,17 +673,178 @@ class PyTreeGenerator:
|
||||
for attempt in range(3):
|
||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||
try:
|
||||
response = self.llm_client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
# 简单/复杂分流到不同模型与提示词
|
||||
client = self.simple_llm_client if mode == "simple" else self.complex_llm_client
|
||||
model_name = self.simple_model if mode == "simple" else self.complex_model
|
||||
# 根据是否捕获推理链来决定是否强制JSON响应
|
||||
response_kwargs = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": use_prompt},
|
||||
{"role": "user", "content": final_user_prompt}
|
||||
],
|
||||
temperature=0.1,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
"temperature": 0.1 if mode == "complex" else 0.0,
|
||||
}
|
||||
if not self.enable_reasoning_capture:
|
||||
response_kwargs["response_format"] = {"type": "json_object"}
|
||||
# 基于模式设定最大输出token数(直接在代码中配置)
|
||||
response_kwargs["max_tokens"] = self.simple_max_tokens if mode == "simple" else self.complex_max_tokens
|
||||
response = client.chat.completions.create(**response_kwargs)
|
||||
# 兼容可能存在的 reasoning_content 字段
|
||||
try:
|
||||
msg = response.choices[0].message
|
||||
msg_content = getattr(msg, "content", None)
|
||||
msg_reasoning = getattr(msg, "reasoning_content", None)
|
||||
except Exception:
|
||||
msg = response.choices[0]["message"] if isinstance(response.choices[0], dict) else None
|
||||
msg_content = (msg or {}).get("content") if isinstance(msg, dict) else None
|
||||
msg_reasoning = (msg or {}).get("reasoning_content") if isinstance(msg, dict) else None
|
||||
|
||||
combined_text = ""
|
||||
if isinstance(msg_reasoning, str) and msg_reasoning.strip():
|
||||
# 将 reasoning_content 包装为 <think>,便于统一解析
|
||||
combined_text += f"<think>\n{msg_reasoning}\n</think>\n"
|
||||
if isinstance(msg_content, str) and msg_content.strip():
|
||||
combined_text += msg_content
|
||||
pytree_str = combined_text if combined_text else (msg_content or "")
|
||||
raw_full_text_for_logging = pytree_str # 保存完整原文(含 <think>)以便失败时完整打印
|
||||
|
||||
# 提取 <think> 推理链内容(若存在)
|
||||
reasoning_text = None
|
||||
try:
|
||||
think_match = re.search(r"<think>([\s\S]*?)</think>", pytree_str)
|
||||
if think_match:
|
||||
reasoning_text = think_match.group(1).strip()
|
||||
# 去除推理文本后再尝试解析JSON
|
||||
pytree_str = re.sub(r"<think>[\s\S]*?</think>", "", pytree_str).strip()
|
||||
except Exception:
|
||||
reasoning_text = None
|
||||
# 单独捕获JSON解析错误并打印原始响应
|
||||
try:
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"❌ JSON解析失败(第 {attempt + 1}/3 次)。\n—— 完整原始文本(含<think>) ——\n{raw_full_text_for_logging}")
|
||||
# 尝试打印响应对象的完整结构
|
||||
try:
|
||||
raw_response_dump = None
|
||||
if hasattr(response, 'model_dump_json'):
|
||||
raw_response_dump = response.model_dump_json(indent=2, exclude_none=False)
|
||||
elif hasattr(response, 'dict'):
|
||||
raw_response_dump = json.dumps(response.dict(), ensure_ascii=False, indent=2, default=str)
|
||||
else:
|
||||
# 兜底:尝试将choices与关键字段展开
|
||||
safe_obj = {
|
||||
"id": getattr(response, 'id', None),
|
||||
"model": getattr(response, 'model', None),
|
||||
"object": getattr(response, 'object', None),
|
||||
"usage": getattr(response, 'usage', None),
|
||||
"choices": [
|
||||
{
|
||||
"index": getattr(c, 'index', None),
|
||||
"finish_reason": getattr(c, 'finish_reason', None),
|
||||
"message": {
|
||||
"role": getattr(getattr(c, 'message', None), 'role', None),
|
||||
"content": getattr(getattr(c, 'message', None), 'content', None),
|
||||
"reasoning_content": getattr(getattr(c, 'message', None), 'reasoning_content', None)
|
||||
} if getattr(c, 'message', None) is not None else None
|
||||
}
|
||||
for c in getattr(response, 'choices', [])
|
||||
] if hasattr(response, 'choices') else None
|
||||
}
|
||||
raw_response_dump = json.dumps(safe_obj, ensure_ascii=False, indent=2, default=str)
|
||||
logging.error(f"—— 完整响应对象 ——\n{raw_response_dump}")
|
||||
except Exception as dump_e:
|
||||
try:
|
||||
logging.error(f"响应对象转储失败,repr如下:\n{repr(response)}")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
# 简单/复杂分别验证与返回
|
||||
if mode == "simple":
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||
logging.info("✅ 简单模式JSON Schema验证成功")
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning(f"❌ 简单模式验证失败: {e.message}")
|
||||
continue
|
||||
# 附加元信息并生成简单可视化(单动作)
|
||||
plan_id = str(uuid.uuid4())
|
||||
pytree_dict['plan_id'] = plan_id
|
||||
# 简单模式可视化:构造一个简化节点图
|
||||
try:
|
||||
vis_filename = "py_tree.png"
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
simple_node = {
|
||||
"type": "action",
|
||||
"name": pytree_dict.get('action', {}).get('name', 'action'),
|
||||
"params": pytree_dict.get('action', {}).get('params', {})
|
||||
}
|
||||
_visualize_pytree(simple_node, os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
except Exception as e:
|
||||
logging.warning(f"简单模式可视化失败: {e}")
|
||||
|
||||
# 保存推理链(若有)
|
||||
try:
|
||||
if reasoning_text:
|
||||
reasoning_path = os.path.join(self.reasoning_dir, "reasoning_content.md")
|
||||
with open(reasoning_path, 'w', encoding='utf-8') as rf:
|
||||
rf.write(reasoning_text)
|
||||
logging.info(f"📝 推理链已保存: {reasoning_path}")
|
||||
# 终端预览(最多N行)
|
||||
try:
|
||||
lines = reasoning_text.splitlines()
|
||||
preview = "\n".join(lines[: self.reasoning_preview_lines])
|
||||
logging.info("🧠 推理链预览(前%d行):\n%s", self.reasoning_preview_lines, preview)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logging.info("未在模型输出中发现 <think> 推理链片段。若需捕获,请设置 ENABLE_REASONING_CAPTURE=true 以放宽JSON强制格式。")
|
||||
except Exception as e:
|
||||
logging.warning(f"保存推理链Markdown失败: {e}")
|
||||
return pytree_dict
|
||||
|
||||
# 复杂模式回退:若模型误返回简单结构,则自动包装为含安全监控的行为树
|
||||
if mode == "complex" and isinstance(pytree_dict, dict) and 'root' not in pytree_dict:
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||
logging.warning("⚠️ 复杂模式生成了简单结构,触发自动包装为完整行为树的回退逻辑。")
|
||||
simple_action_obj = pytree_dict.get('action') or {}
|
||||
action_name = simple_action_obj.get('name')
|
||||
action_params = simple_action_obj.get('params') if isinstance(simple_action_obj.get('params'), dict) else {}
|
||||
|
||||
safety_selector = {
|
||||
"type": "Selector",
|
||||
"name": "SafetyMonitor",
|
||||
"params": {"memory": True},
|
||||
"children": [
|
||||
{"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
|
||||
{"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
|
||||
{"type": "Sequence", "name": "EmergencyHandler", "children": [
|
||||
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
|
||||
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||
]}
|
||||
]
|
||||
}
|
||||
|
||||
main_children = [{"type": "action", "name": action_name, "params": action_params}]
|
||||
if action_name != "land":
|
||||
main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
|
||||
|
||||
root_parallel = {
|
||||
"type": "Parallel",
|
||||
"name": "MissionWithSafety",
|
||||
"params": {"policy": "all_success"},
|
||||
"children": [
|
||||
{"type": "Sequence", "name": "MainTask", "children": main_children},
|
||||
safety_selector
|
||||
]
|
||||
}
|
||||
pytree_dict = {"root": root_parallel}
|
||||
except jsonschema.ValidationError:
|
||||
# 不符合简单结构,按正常复杂验证继续
|
||||
pass
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("✅ 成功生成并验证了Pytree")
|
||||
plan_id = str(uuid.uuid4())
|
||||
@@ -612,10 +855,32 @@ class PyTreeGenerator:
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
|
||||
# 保存推理链(若有)
|
||||
try:
|
||||
if reasoning_text:
|
||||
reasoning_path = os.path.join(self.reasoning_dir, "reasoning_content.md")
|
||||
with open(reasoning_path, 'w', encoding='utf-8') as rf:
|
||||
rf.write(reasoning_text)
|
||||
logging.info(f"📝 推理链已保存: {reasoning_path}")
|
||||
# 终端预览(最多N行)
|
||||
try:
|
||||
lines = reasoning_text.splitlines()
|
||||
preview = "\n".join(lines[: self.reasoning_preview_lines])
|
||||
logging.info("🧠 推理链预览(前%d行):\n%s", self.reasoning_preview_lines, preview)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logging.info("未在模型输出中发现 <think> 推理链片段。若需捕获,请设置 ENABLE_REASONING_CAPTURE=true 以放宽JSON强制格式。")
|
||||
except Exception as e:
|
||||
logging.warning(f"保存推理链Markdown失败: {e}")
|
||||
return pytree_dict
|
||||
else:
|
||||
# 打印未通过验证的Pytree以便排查
|
||||
preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
|
||||
logging.warning(f"❌ 未通过验证的Pytree(第 {attempt + 1}/3 次尝试):\n{preview}")
|
||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||
except (OpenAIError, json.JSONDecodeError) as e:
|
||||
except OpenAIError as e:
|
||||
logging.error(f"生成Pytree时发生错误: {e}")
|
||||
|
||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -12,7 +12,7 @@ BASE_URL = "http://127.0.0.1:8000"
|
||||
ENDPOINT = "/generate_plan"
|
||||
|
||||
# The user prompt we will send for the test
|
||||
TEST_PROMPT = "起飞后移动到学生宿舍上方搜索蓝色车辆,并进行打击"
|
||||
TEST_PROMPT = "已知目标检测红色气球危险性高于蓝色气球高于绿色气球,飞往搜索区搜索并锁定危险性最高的气球,对其跟踪30秒后进行打击操作"
|
||||
|
||||
def test_generate_plan():
|
||||
"""
|
||||
|
||||
174
tools/test_llama_server.py
Normal file
174
tools/test_llama_server.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def build_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="调用本地 llama-server (OpenAI兼容) 进行推理,支持自定义系统/用户提示词"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
|
||||
help="llama-server 的基础URL(默认: http://127.0.0.1:8081/v1,或环境变量 SIMPLE_BASE_URL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=os.getenv("SIMPLE_MODEL", "local-model"),
|
||||
help="模型名称(默认: local-model,或环境变量 SIMPLE_MODEL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system",
|
||||
default="You are a helpful assistant.",
|
||||
help="系统提示词(system role)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system-file",
|
||||
default=None,
|
||||
help="系统提示词文件路径(txt);若提供,则覆盖 --system 的字符串",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
default=None,
|
||||
help="用户提示词(user role);若不传则从交互式输入读取",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="采样温度(默认: 0.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="最大生成Token数(默认: 4096)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=120.0,
|
||||
help="HTTP超时时间秒(默认: 120)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="打印完整返回JSON",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def call_llama_server(
|
||||
base_url: str,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
endpoint = base_url.rstrip("/") + "/chat/completions"
|
||||
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
||||
|
||||
# 兼容需要API Key的代理/服务(llama-server通常不强制)
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = build_args()
|
||||
|
||||
user_prompt = args.user
|
||||
if not user_prompt:
|
||||
try:
|
||||
user_prompt = input("请输入用户提示词: ")
|
||||
except KeyboardInterrupt:
|
||||
print("\n已取消。")
|
||||
sys.exit(1)
|
||||
|
||||
# 解析系统提示词:优先使用 --system-file
|
||||
system_prompt = args.system
|
||||
if args.system_file:
|
||||
try:
|
||||
with open(args.system_file, "r", encoding="utf-8") as f:
|
||||
system_prompt = f.read()
|
||||
except Exception as e:
|
||||
print("\n❌ 读取系统提示词文件失败:")
|
||||
print(str(e))
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
print("--- llama-server 推理 ---")
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Model: {args.model}")
|
||||
if args.system_file:
|
||||
print(f"System(from file): {args.system_file}")
|
||||
else:
|
||||
print(f"System: {system_prompt}")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
data = call_llama_server(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
timeout=args.timeout,
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
print("\n完整返回JSON:")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||||
|
||||
# 尝试按OpenAI兼容格式提取assistant内容
|
||||
content = None
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if content is not None:
|
||||
print("\n模型输出:")
|
||||
print(content)
|
||||
else:
|
||||
# 兜底打印
|
||||
print("\n无法按OpenAI兼容字段解析内容,原始返回如下:")
|
||||
print(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("\n❌ 请求失败:请确认 llama-server 已在 8081 端口启动并可访问。")
|
||||
print(f"详情: {e}")
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
print("\n❌ 发生未预期的错误:")
|
||||
print(str(e))
|
||||
sys.exit(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user