优化简单模式支持
This commit is contained in:
52
README.md
52
README.md
@@ -13,6 +13,10 @@
|
|||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ ├── main.py # 应用主入口,提供Web API
|
│ │ ├── main.py # 应用主入口,提供Web API
|
||||||
│ │ ├── py_tree_generator.py # RAG与LLM集成,生成py_tree
|
│ │ ├── py_tree_generator.py # RAG与LLM集成,生成py_tree
|
||||||
|
│ │ ├── prompts/ # LLM 提示词
|
||||||
|
│ │ │ ├── system_prompt.txt # 复杂模式提示词(行为树与安全监控)
|
||||||
|
│ │ │ ├── simple_mode_prompt.txt # 简单模式提示词(单一原子动作JSON)
|
||||||
|
│ │ │ └── classifier_prompt.txt # 指令简单/复杂分类提示词
|
||||||
│ │ ├── ...
|
│ │ ├── ...
|
||||||
│ ├── generated_visualizations/ # 存放最新生成的py_tree可视化图像
|
│ ├── generated_visualizations/ # 存放最新生成的py_tree可视化图像
|
||||||
│ └── requirements.txt # 后端服务的Python依赖
|
│ └── requirements.txt # 后端服务的Python依赖
|
||||||
@@ -67,6 +71,54 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 指令分类与分流
|
||||||
|
|
||||||
|
后端在生成任务前会先对用户指令进行“简单/复杂”分类,并分流到不同提示词与模型:
|
||||||
|
|
||||||
|
- 分类提示词:`backend_service/src/prompts/classifier_prompt.txt`
|
||||||
|
- 简单模式提示词:`backend_service/src/prompts/simple_mode_prompt.txt`
|
||||||
|
- 复杂模式提示词:`backend_service/src/prompts/system_prompt.txt`
|
||||||
|
|
||||||
|
分类仅输出如下JSON之一:`{"mode":"simple"}` 或 `{"mode":"complex"}`。
|
||||||
|
|
||||||
|
当为简单模式时,LLM仅输出:
|
||||||
|
`{"mode":"simple","action":{"name":"<action>","params":{...}}}`。
|
||||||
|
生成端会自动将该动作封装为带安全监控的最小行为树(根 `Parallel` 并行安全监控),以保持与现有Schema和可视化兼容。
|
||||||
|
|
||||||
|
### 环境变量(可选)
|
||||||
|
|
||||||
|
支持为“分类/简单/复杂”三类调用分别配置模型与Base URL(未设置时回退到默认本地配置):
|
||||||
|
|
||||||
|
- `CLASSIFIER_MODEL`, `CLASSIFIER_BASE_URL`
|
||||||
|
- `SIMPLE_MODEL`, `SIMPLE_BASE_URL`
|
||||||
|
- `COMPLEX_MODEL`, `COMPLEX_BASE_URL`
|
||||||
|
|
||||||
|
通用API Key:`OPENAI_API_KEY`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
```bash
|
||||||
|
export CLASSIFIER_MODEL="qwen2.5-1.8b-instruct"
|
||||||
|
export SIMPLE_MODEL="qwen2.5-1.8b-instruct"
|
||||||
|
export COMPLEX_MODEL="qwen2.5-7b-instruct"
|
||||||
|
export CLASSIFIER_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||||
|
export SIMPLE_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||||
|
export COMPLEX_BASE_URL="http://$ORIN_IP:8081/v1"
|
||||||
|
export OPENAI_API_KEY="sk-no-key-required"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试简单模式
|
||||||
|
|
||||||
|
启动服务后,运行内置测试脚本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd tools
|
||||||
|
python test_api.py
|
||||||
|
```
|
||||||
|
|
||||||
|
示例输入:“简单模式,起飞” 或 “起飞到10米”。返回结果将是完整的带安全并行监控的行为树(`root` + `plan_id` + `visualization_url`)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 工作流程
|
## 工作流程
|
||||||
|
|
||||||
整个系统的工作流程分为两个主要阶段:
|
整个系统的工作流程分为两个主要阶段:
|
||||||
|
|||||||
24
backend_service/src/prompts/classifier_prompt.txt
Normal file
24
backend_service/src/prompts/classifier_prompt.txt
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
你是一个严格的任务分类器。只输出一个JSON对象,不要输出解释或多余文本。
|
||||||
|
根据用户指令与下述可用节点定义,判断其为“简单”或“复杂”。
|
||||||
|
|
||||||
|
- 简单:单一原子动作即可完成(例如“起飞”“飞机自检”“移动到某地(已给定坐标)”等),且无需行为树与安全并行监控。
|
||||||
|
- 复杂:需要多步流程、搜索/检测/跟踪/评估、战损确认、或需要模板化任务结构与安全并行监控。
|
||||||
|
|
||||||
|
输出格式(严格遵守):
|
||||||
|
{"mode":"simple"} 或 {"mode":"complex"}
|
||||||
|
|
||||||
|
—— 可用节点定义(与复杂模式保持一致,供分类参考)——
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
{"name": "takeoff"}, {"name": "land"}, {"name": "fly_to_waypoint"}, {"name": "loiter"},
|
||||||
|
{"name": "object_detect"}, {"name": "strike_target"}, {"name": "battle_damage_assessment"},
|
||||||
|
{"name": "search_pattern"}, {"name": "track_object"}, {"name": "deliver_payload"},
|
||||||
|
{"name": "preflight_checks"}, {"name": "emergency_return"}
|
||||||
|
],
|
||||||
|
"conditions": [
|
||||||
|
{"name": "battery_above"}, {"name": "at_waypoint"}, {"name": "object_detected"},
|
||||||
|
{"name": "target_destroyed"}, {"name": "time_elapsed"}, {"name": "gps_status"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
49
backend_service/src/prompts/simple_mode_prompt.txt
Normal file
49
backend_service/src/prompts/simple_mode_prompt.txt
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
你是一个无人机简单指令执行规划器。你的任务:当用户给出“简单指令”(单一原子动作即可完成)时,输出一个严格的JSON对象。
|
||||||
|
|
||||||
|
输出要求(必须遵守):
|
||||||
|
- 只输出一个JSON对象,不要任何解释或多余文本。
|
||||||
|
- JSON结构:
|
||||||
|
{"mode":"simple","action":{"name":"<action_name>","params":{...}}}
|
||||||
|
- <action_name> 与参数定义、取值范围,必须与“复杂模式”提示词(system_prompt.txt)中的定义完全一致。
|
||||||
|
- 简单模式下不包含任何行为树结构与安全监控并行,仅输出单一原子动作。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
- “起飞到10米” → {"mode":"simple","action":{"name":"takeoff","params":{"altitude":10.0}}}
|
||||||
|
- “移动到(120,80,20)” → {"mode":"simple","action":{"name":"fly_to_waypoint","params":{"x":120.0,"y":80.0,"z":20.0,"acceptance_radius":2.0}}}
|
||||||
|
- “飞机自检” → {"mode":"simple","action":{"name":"preflight_checks","params":{"check_level":"comprehensive"}}}
|
||||||
|
|
||||||
|
—— 可用节点定义(与复杂模式保持一致,逐字遵守)——
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
{"name": "takeoff", "description": "无人机从当前位置垂直起飞到指定的海拔高度。", "params": {"altitude": "float, 目标海拔高度(米),范围[1, 100],默认为2"}},
|
||||||
|
{"name": "land", "description": "降落无人机。可选择当前位置或返航点降落。", "params": {"mode": "string, 可选值: 'current'(当前位置), 'home'(返航点)"}},
|
||||||
|
{"name": "fly_to_waypoint", "description": "导航至一个指定坐标点。使用相对坐标系(x,y,z),单位为米。", "params": {"x": "float", "y": "float", "z": "float", "acceptance_radius": "float, 可选,默认2.0"}},
|
||||||
|
{"name": "loiter", "description": "在当前位置上空悬停一段时间或直到条件触发。", "params": {"duration": "float, 可选[1,600]", "until_condition": "string, 可选"}},
|
||||||
|
{"name": "object_detect", "description": "识别特定目标对象。", "params": {"target_class": "string, 见复杂模式定义列表", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||||
|
{"name": "strike_target", "description": "对已识别目标进行打击。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||||
|
{"name": "battle_damage_assessment", "description": "战损评估。", "params": {"target_class": "string", "assessment_time": "float[5-60], 默认15.0"}},
|
||||||
|
{"name": "search_pattern", "description": "按模式搜索。", "params": {"pattern_type": "string: spiral|grid", "center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||||
|
{"name": "track_object", "description": "持续跟踪目标。", "params": {"target_class": "string, 见复杂模式定义列表", "description": "string, 可选", "track_time": "float[1,600], 默认30.0", "min_confidence": "float[0.5-1.0], 默认0.7", "safe_distance": "float[2-50], 默认10.0"}},
|
||||||
|
{"name": "deliver_payload", "description": "投放物资。", "params": {"payload_type": "string", "release_altitude": "float[2,100], 默认5.0"}},
|
||||||
|
{"name": "preflight_checks", "description": "飞行前系统自检。", "params": {"check_level": "string: basic|comprehensive"}},
|
||||||
|
{"name": "emergency_return", "description": "执行紧急返航程序。", "params": {"reason": "string"}}
|
||||||
|
],
|
||||||
|
"conditions": [
|
||||||
|
{"name": "battery_above", "description": "电池电量高于阈值。", "params": {"threshold": "float[0.0,1.0]"}},
|
||||||
|
{"name": "at_waypoint", "description": "在指定坐标容差范围内。", "params": {"x": "float", "y": "float", "z": "float", "tolerance": "float, 可选, 默认3.0"}},
|
||||||
|
{"name": "object_detected", "description": "检测到特定目标。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
|
||||||
|
{"name": "target_destroyed", "description": "目标已被摧毁。", "params": {"target_class": "string", "description": "string, 可选", "confidence": "float[0.5-1.0], 默认0.8"}},
|
||||||
|
{"name": "time_elapsed", "description": "时间经过。", "params": {"duration": "float[1,2700]"}},
|
||||||
|
{"name": "gps_status", "description": "GPS状态良好。", "params": {"min_satellites": "int[6,15], 默认10"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
—— 参数约束(与复杂模式保持一致,必须遵守)——
|
||||||
|
- takeoff.altitude: [1, 100]
|
||||||
|
- fly_to_waypoint.z: [1, 5000]
|
||||||
|
- fly_to_waypoint.x,y: [-10000, 10000]
|
||||||
|
- search_pattern.radius: [5, 1000]
|
||||||
|
- 电池阈值等同复杂模式(如需涉及)
|
||||||
|
- 若参考知识提供坐标,必须使用并裁剪到约束范围内
|
||||||
@@ -531,7 +531,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Selector",
|
"type": "Selector",
|
||||||
"name": "SafetyMonitorSelector",
|
"name": "SafetyMonitor",
|
||||||
"params": {
|
"params": {
|
||||||
"memory": true
|
"memory": true
|
||||||
},
|
},
|
||||||
@@ -552,7 +552,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "Sequence",
|
"type": "Sequence",
|
||||||
"name": "EmergencyProcedureSequence",
|
"name": "EmergencyHandler",
|
||||||
"children": [
|
"children": [
|
||||||
{
|
{
|
||||||
"type": "action",
|
"type": "action",
|
||||||
|
|||||||
@@ -335,6 +335,32 @@ def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> di
|
|||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
|
||||||
|
"""
|
||||||
|
生成简单模式JSON Schema:{"mode":"simple","action":{...}}
|
||||||
|
仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。
|
||||||
|
"""
|
||||||
|
schema = {
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"title": "SimpleMode",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"mode": {"type": "string", "const": "simple"},
|
||||||
|
"action": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string", "enum": sorted(list(allowed_actions))},
|
||||||
|
"params": {"type": "object"}
|
||||||
|
},
|
||||||
|
"required": ["name"],
|
||||||
|
"additionalProperties": True
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["mode", "action"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
return schema
|
||||||
|
|
||||||
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||||
"""
|
"""
|
||||||
使用JSON Schema验证给定的Pytree实例。
|
使用JSON Schema验证给定的Pytree实例。
|
||||||
@@ -522,13 +548,27 @@ class PyTreeGenerator:
|
|||||||
# Updated output directory for visualizations
|
# Updated output directory for visualizations
|
||||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||||
os.makedirs(self.vis_dir, exist_ok=True)
|
os.makedirs(self.vis_dir, exist_ok=True)
|
||||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
# 加载提示词:复杂模式复用现有 system_prompt.txt;简单模式与分类器独立提示词
|
||||||
|
self.complex_prompt = self._load_prompt("system_prompt.txt")
|
||||||
|
self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
|
||||||
|
self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
|
||||||
|
# 兼容旧变量名
|
||||||
|
self.system_prompt = self.complex_prompt
|
||||||
|
|
||||||
self.orin_ip = os.getenv("ORIN_IP", "localhost")
|
self.orin_ip = os.getenv("ORIN_IP", "localhost")
|
||||||
self.llm_client = openai.OpenAI(
|
# 三类模型的可配置项:基于不同模型与Base URL分流
|
||||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
self.classifier_model = os.getenv("CLASSIFIER_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
self.simple_model = os.getenv("SIMPLE_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||||
)
|
self.complex_model = os.getenv("COMPLEX_MODEL", os.getenv("OPENAI_MODEL", "local-model"))
|
||||||
|
self.classifier_base_url = os.getenv("CLASSIFIER_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||||
|
self.simple_base_url = os.getenv("SIMPLE_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||||
|
self.complex_base_url = os.getenv("COMPLEX_BASE_URL", f"http://{self.orin_ip}:8081/v1")
|
||||||
|
self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
|
||||||
|
|
||||||
|
# 为不同用途分别创建客户端
|
||||||
|
self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
|
||||||
|
self.simple_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.simple_base_url)
|
||||||
|
self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)
|
||||||
|
|
||||||
# --- ChromaDB Client Setup ---
|
# --- ChromaDB Client Setup ---
|
||||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||||
@@ -542,8 +582,10 @@ class PyTreeGenerator:
|
|||||||
embedding_function=embedding_func
|
embedding_function=embedding_func
|
||||||
)
|
)
|
||||||
|
|
||||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
# 使用复杂模式提示词作为节点来源,确保Schema稳定
|
||||||
|
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
|
||||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||||
|
self.simple_schema = _generate_simple_mode_schema(allowed_actions)
|
||||||
|
|
||||||
def _load_prompt(self, file_name: str) -> str:
|
def _load_prompt(self, file_name: str) -> str:
|
||||||
try:
|
try:
|
||||||
@@ -574,9 +616,36 @@ class PyTreeGenerator:
|
|||||||
"""
|
"""
|
||||||
logging.info(f"接收到用户请求: {user_prompt}")
|
logging.info(f"接收到用户请求: {user_prompt}")
|
||||||
|
|
||||||
retrieved_context = self._retrieve_context(user_prompt)
|
# 第一步:分类(简单/复杂)
|
||||||
|
mode = "complex"
|
||||||
|
try:
|
||||||
|
classifier_resp = self.classifier_client.chat.completions.create(
|
||||||
|
model=self.classifier_model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": self.classifier_prompt or "你是一个分类器,只输出JSON。"},
|
||||||
|
{"role": "user", "content": user_prompt}
|
||||||
|
],
|
||||||
|
temperature=0.0,
|
||||||
|
response_format={"type": "json_object"}
|
||||||
|
)
|
||||||
|
class_str = classifier_resp.choices[0].message.content
|
||||||
|
class_obj = json.loads(class_str)
|
||||||
|
if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"):
|
||||||
|
mode = class_obj.get("mode")
|
||||||
|
logging.info(f"分类结果: {mode}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"分类失败,默认按复杂指令处理: {e}")
|
||||||
|
|
||||||
|
# 第二步:根据模式准备提示词与上下文
|
||||||
|
# 基于模式选择提示词;复杂模式追加一条强制规则,避免模型误输出简单结构
|
||||||
|
use_prompt = self.simple_prompt if mode == "simple" else (
|
||||||
|
(self.complex_prompt or "") +
|
||||||
|
"\n\n【强制规则】仅生成包含root的复杂行为树JSON,不得输出简单模式(不得包含mode字段或仅有action节点)。"
|
||||||
|
)
|
||||||
final_user_prompt = user_prompt
|
final_user_prompt = user_prompt
|
||||||
|
retrieved_context = None
|
||||||
|
if mode == "complex":
|
||||||
|
retrieved_context = self._retrieve_context(user_prompt)
|
||||||
if retrieved_context:
|
if retrieved_context:
|
||||||
augmentation = (
|
augmentation = (
|
||||||
"\n\n---\n"
|
"\n\n---\n"
|
||||||
@@ -591,17 +660,92 @@ class PyTreeGenerator:
|
|||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||||
try:
|
try:
|
||||||
response = self.llm_client.chat.completions.create(
|
# 简单/复杂分流到不同模型与提示词
|
||||||
model="local-model",
|
client = self.simple_llm_client if mode == "simple" else self.complex_llm_client
|
||||||
|
model_name = self.simple_model if mode == "simple" else self.complex_model
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": self.system_prompt},
|
{"role": "system", "content": use_prompt},
|
||||||
{"role": "user", "content": final_user_prompt}
|
{"role": "user", "content": final_user_prompt}
|
||||||
],
|
],
|
||||||
temperature=0.1,
|
temperature=0.1 if mode == "complex" else 0.0,
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
)
|
)
|
||||||
pytree_str = response.choices[0].message.content
|
pytree_str = response.choices[0].message.content
|
||||||
|
# 单独捕获JSON解析错误并打印原始响应
|
||||||
|
try:
|
||||||
pytree_dict = json.loads(pytree_str)
|
pytree_dict = json.loads(pytree_str)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logging.error(f"❌ JSON解析失败(第 {attempt + 1}/3 次)。原始响应如下:\n{pytree_str}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 简单/复杂分别验证与返回
|
||||||
|
if mode == "simple":
|
||||||
|
try:
|
||||||
|
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||||
|
logging.info("✅ 简单模式JSON Schema验证成功")
|
||||||
|
except jsonschema.ValidationError as e:
|
||||||
|
logging.warning(f"❌ 简单模式验证失败: {e.message}")
|
||||||
|
continue
|
||||||
|
# 附加元信息并生成简单可视化(单动作)
|
||||||
|
plan_id = str(uuid.uuid4())
|
||||||
|
pytree_dict['plan_id'] = plan_id
|
||||||
|
# 简单模式可视化:构造一个简化节点图
|
||||||
|
try:
|
||||||
|
vis_filename = "py_tree.png"
|
||||||
|
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||||
|
simple_node = {
|
||||||
|
"type": "action",
|
||||||
|
"name": pytree_dict.get('action', {}).get('name', 'action'),
|
||||||
|
"params": pytree_dict.get('action', {}).get('params', {})
|
||||||
|
}
|
||||||
|
_visualize_pytree(simple_node, os.path.splitext(vis_path)[0])
|
||||||
|
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"简单模式可视化失败: {e}")
|
||||||
|
return pytree_dict
|
||||||
|
|
||||||
|
# 复杂模式回退:若模型误返回简单结构,则自动包装为含安全监控的行为树
|
||||||
|
if mode == "complex" and isinstance(pytree_dict, dict) and 'root' not in pytree_dict:
|
||||||
|
try:
|
||||||
|
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
|
||||||
|
logging.warning("⚠️ 复杂模式生成了简单结构,触发自动包装为完整行为树的回退逻辑。")
|
||||||
|
simple_action_obj = pytree_dict.get('action') or {}
|
||||||
|
action_name = simple_action_obj.get('name')
|
||||||
|
action_params = simple_action_obj.get('params') if isinstance(simple_action_obj.get('params'), dict) else {}
|
||||||
|
|
||||||
|
safety_selector = {
|
||||||
|
"type": "Selector",
|
||||||
|
"name": "SafetyMonitor",
|
||||||
|
"params": {"memory": True},
|
||||||
|
"children": [
|
||||||
|
{"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
|
||||||
|
{"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
|
||||||
|
{"type": "Sequence", "name": "EmergencyHandler", "children": [
|
||||||
|
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
|
||||||
|
{"type": "action", "name": "land", "params": {"mode": "home"}}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
main_children = [{"type": "action", "name": action_name, "params": action_params}]
|
||||||
|
if action_name != "land":
|
||||||
|
main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
|
||||||
|
|
||||||
|
root_parallel = {
|
||||||
|
"type": "Parallel",
|
||||||
|
"name": "MissionWithSafety",
|
||||||
|
"params": {"policy": "all_success"},
|
||||||
|
"children": [
|
||||||
|
{"type": "Sequence", "name": "MainTask", "children": main_children},
|
||||||
|
safety_selector
|
||||||
|
]
|
||||||
|
}
|
||||||
|
pytree_dict = {"root": root_parallel}
|
||||||
|
except jsonschema.ValidationError:
|
||||||
|
# 不符合简单结构,按正常复杂验证继续
|
||||||
|
pass
|
||||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||||
logging.info("✅ 成功生成并验证了Pytree")
|
logging.info("✅ 成功生成并验证了Pytree")
|
||||||
plan_id = str(uuid.uuid4())
|
plan_id = str(uuid.uuid4())
|
||||||
@@ -614,8 +758,11 @@ class PyTreeGenerator:
|
|||||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||||
return pytree_dict
|
return pytree_dict
|
||||||
else:
|
else:
|
||||||
|
# 打印未通过验证的Pytree以便排查
|
||||||
|
preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
|
||||||
|
logging.warning(f"❌ 未通过验证的Pytree(第 {attempt + 1}/3 次尝试):\n{preview}")
|
||||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||||
except (OpenAIError, json.JSONDecodeError) as e:
|
except OpenAIError as e:
|
||||||
logging.error(f"生成Pytree时发生错误: {e}")
|
logging.error(f"生成Pytree时发生错误: {e}")
|
||||||
|
|
||||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ BASE_URL = "http://127.0.0.1:8000"
|
|||||||
ENDPOINT = "/generate_plan"
|
ENDPOINT = "/generate_plan"
|
||||||
|
|
||||||
# The user prompt we will send for the test
|
# The user prompt we will send for the test
|
||||||
TEST_PROMPT = "起飞"
|
TEST_PROMPT = "飞到学生宿舍"
|
||||||
|
|
||||||
def test_generate_plan():
|
def test_generate_plan():
|
||||||
"""
|
"""
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user