修正了行为树可视化的逻辑,优化了系统提示此

This commit is contained in:
2025-08-28 13:30:13 +08:00
parent 5b50fc912f
commit a09ef9aeba
2 changed files with 294 additions and 160 deletions

View File

@@ -4,20 +4,18 @@ import logging
import uuid
import re
from typing import Dict, Any, Optional, Set
import chromadb
import openai
from openai import OpenAIError
import jsonschema
import requests
import platform # 新增:用于选择合适的中文字体
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
def __init__(self, api_url: str):
self._api_url = api_url
def __call__(self, input: Embeddable) -> Embeddings:
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
return []
@@ -48,7 +46,6 @@ logging.basicConfig(
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
从系统提示词中精确解析出允许的行动和条件节点。
@@ -279,7 +276,6 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
"""
使用Graphviz将Pytree字典可视化并保存到指定路径。
@@ -290,15 +286,36 @@ def _visualize_pytree(node: Dict, file_path: str):
logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
return
dot = Digraph('Pytree', comment='Drone Mission Plan')
dot.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
# 选择合适的中文字体,避免中文乱码
def _pick_zh_font():
sys = platform.system()
if sys == "Windows":
return "Microsoft YaHei"
elif sys == "Darwin":
return "PingFang SC"
else:
return "Noto Sans CJK SC"
fontname = _pick_zh_font()
dot = Digraph('Pytree', comment='Drone Mission Plan')
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
dot.attr('edge', fontname=fontname)
_add_nodes_and_edges(node, dot)
try:
# 确保输出目录存在,并避免生成 .png.png
base_path, ext = os.path.splitext(file_path)
render_path = base_path if ext.lower() == '.png' else file_path
out_dir = os.path.dirname(render_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
# 保存为 .png 文件,并自动删除源码 .gv 文件
output_path = dot.render(file_path, format='png', cleanup=True, view=False)
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
logging.info("--- 任务树可视化成功 ---")
logging.info(f"图形已保存到: {output_path}")
except Exception as e:
@@ -309,44 +326,96 @@ def _visualize_pytree(node: Dict, file_path: str):
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID
current_id = str(id(node))
# 为每个节点创建一个唯一的ID(加上随机数避免冲突)
import random
import html
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
# 准备节点标签HTML-like正确换行与转义
name = html.escape(str(node.get('name', '')))
ntype = html.escape(str(node.get('type', '')))
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
# 格式化参数显示
params = node.get('params') or {}
if params:
params_lines = []
for key, value in params.items():
k = html.escape(str(key))
if isinstance(value, float):
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
else:
value_str = str(value)
v = html.escape(value_str)
params_lines.append(f"{k}: {v}")
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
node_label = f"<{'<BR/>'.join(label_parts)}>"
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
node_type = (node.get('type') or '').lower()
shape = 'ellipse'
style = 'filled'
fillcolor = '#e6e6e6' # 默认灰色填充
border_color = '#666666' # 默认描边色
# 准备节点标签
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
if node.get('params'):
params_str = json.dumps(node.get('params'))
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
node_label += ">"
# 根据类型设置节点样式
node_type = node.get('type', '').lower()
if node_type == 'action':
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
shape = 'box'
style = 'rounded,filled'
fillcolor = "#cde4ff" # 浅蓝
elif node_type == 'condition':
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
else: # Sequence, Selector, etc.
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
shape = 'diamond'
style = 'filled'
fillcolor = "#fff2cc" # 浅黄
elif node_type == 'sequence':
shape = 'ellipse'
style = 'filled'
fillcolor = '#d5e8d4' # 绿色
elif node_type == 'selector':
shape = 'ellipse'
style = 'filled'
fillcolor = '#ffe6cc' # 橙色
elif node_type == 'parallel':
shape = 'ellipse'
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
last_child_id = current_id
for child in node.get("children", []):
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
if node_type in ['sequence']:
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
else: # Selector, Parallel
_add_nodes_and_edges(child, dot, current_id)
children = node.get("children", [])
if not children:
return current_id
# 记录所有子节点的ID
child_ids = []
# 正确的递归连接:每个子节点都连接到当前节点
for child in children:
child_id = _add_nodes_and_edges(child, dot, current_id)
child_ids.append(child_id)
# 子节点同级排列(横向排布,更直观地表现同层)
if len(child_ids) > 1:
with dot.subgraph(name=f"rank_{current_id}") as s:
s.attr(rank='same')
for cid in child_ids:
s.node(cid)
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
def __init__(self):
self.base_dir = os.path.dirname(os.path.abspath(__file__))
@@ -355,7 +424,6 @@ class PyTreeGenerator:
# Updated output directory for visualizations
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
os.makedirs(self.vis_dir, exist_ok=True)
self.system_prompt = self._load_prompt("system_prompt.txt")
self.orin_ip = os.getenv("ORIN_IP", "localhost")
@@ -363,7 +431,7 @@ class PyTreeGenerator:
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
base_url=f"http://{self.orin_ip}:8081/v1"
)
# --- ChromaDB Client Setup ---
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
@@ -371,12 +439,11 @@ class PyTreeGenerator:
# Explicitly use the remote embedding function for queries
embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
self.collection = self.chroma_client.get_collection(
name="drone_docs",
embedding_function=embedding_func
)
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
@@ -423,7 +490,6 @@ class PyTreeGenerator:
final_user_prompt += augmentation
else:
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
for attempt in range(3):
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
try:
@@ -438,7 +504,6 @@ class PyTreeGenerator:
)
pytree_str = response.choices[0].message.content
pytree_dict = json.loads(pytree_str)
if _validate_pytree_with_schema(pytree_dict, self.schema):
logging.info("成功生成并验证了Pytree。")
plan_id = str(uuid.uuid4())
@@ -449,15 +514,13 @@ class PyTreeGenerator:
vis_path = os.path.join(self.vis_dir, vis_filename)
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
return pytree_dict
else:
logging.warning("生成的Pytree验证失败正在重试...")
except (OpenAIError, json.JSONDecodeError) as e:
logging.error(f"生成Pytree时发生错误: {e}")
raise RuntimeError("在3次尝试后仍未能生成一个有效的Pytree。")
# Create a single instance for the application
py_tree_generator = PyTreeGenerator()
py_tree_generator = PyTreeGenerator()