修正了行为树可视化的逻辑,优化了系统提示此
This commit is contained in:
@@ -4,20 +4,18 @@ import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Set
|
||||
|
||||
import chromadb
|
||||
import openai
|
||||
from openai import OpenAIError
|
||||
import jsonschema
|
||||
import requests
|
||||
import platform # 新增:用于选择合适的中文字体
|
||||
|
||||
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
def __init__(self, api_url: str):
|
||||
self._api_url = api_url
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
|
||||
return []
|
||||
@@ -48,7 +46,6 @@ logging.basicConfig(
|
||||
# ==============================================================================
|
||||
# VALIDATION LOGIC (from utils/validation.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
"""
|
||||
从系统提示词中精确解析出允许的行动和条件节点。
|
||||
@@ -279,7 +276,6 @@ def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
# ==============================================================================
|
||||
# VISUALIZATION LOGIC (from utils/visualization.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
@@ -290,15 +286,36 @@ def _visualize_pytree(node: Dict, file_path: str):
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
|
||||
# 选择合适的中文字体,避免中文乱码
|
||||
def _pick_zh_font():
|
||||
sys = platform.system()
|
||||
if sys == "Windows":
|
||||
return "Microsoft YaHei"
|
||||
elif sys == "Darwin":
|
||||
return "PingFang SC"
|
||||
else:
|
||||
return "Noto Sans CJK SC"
|
||||
|
||||
fontname = _pick_zh_font()
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
|
||||
dot.attr('edge', fontname=fontname)
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 确保输出目录存在,并避免生成 .png.png
|
||||
base_path, ext = os.path.splitext(file_path)
|
||||
render_path = base_path if ext.lower() == '.png' else file_path
|
||||
|
||||
out_dir = os.path.dirname(render_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
output_path = dot.render(file_path, format='png', cleanup=True, view=False)
|
||||
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
|
||||
logging.info("--- 任务树可视化成功 ---")
|
||||
logging.info(f"图形已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
@@ -309,44 +326,96 @@ def _visualize_pytree(node: Dict, file_path: str):
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
|
||||
# 为每个节点创建一个唯一的ID
|
||||
current_id = str(id(node))
|
||||
# 为每个节点创建一个唯一的ID(加上随机数避免冲突)
|
||||
import random
|
||||
import html
|
||||
|
||||
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
|
||||
|
||||
# 准备节点标签(HTML-like,正确换行与转义)
|
||||
name = html.escape(str(node.get('name', '')))
|
||||
ntype = html.escape(str(node.get('type', '')))
|
||||
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
|
||||
|
||||
# 格式化参数显示
|
||||
params = node.get('params') or {}
|
||||
if params:
|
||||
params_lines = []
|
||||
for key, value in params.items():
|
||||
k = html.escape(str(key))
|
||||
if isinstance(value, float):
|
||||
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
|
||||
else:
|
||||
value_str = str(value)
|
||||
v = html.escape(value_str)
|
||||
params_lines.append(f"{k}: {v}")
|
||||
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
|
||||
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
|
||||
|
||||
node_label = f"<{'<BR/>'.join(label_parts)}>"
|
||||
|
||||
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
|
||||
node_type = (node.get('type') or '').lower()
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e6e6e6' # 默认灰色填充
|
||||
border_color = '#666666' # 默认描边色
|
||||
|
||||
# 准备节点标签
|
||||
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
|
||||
if node.get('params'):
|
||||
params_str = json.dumps(node.get('params'))
|
||||
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
|
||||
node_label += ">"
|
||||
|
||||
# 根据类型设置节点样式
|
||||
node_type = node.get('type', '').lower()
|
||||
if node_type == 'action':
|
||||
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
|
||||
shape = 'box'
|
||||
style = 'rounded,filled'
|
||||
fillcolor = "#cde4ff" # 浅蓝
|
||||
elif node_type == 'condition':
|
||||
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
|
||||
else: # Sequence, Selector, etc.
|
||||
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
|
||||
|
||||
shape = 'diamond'
|
||||
style = 'filled'
|
||||
fillcolor = "#fff2cc" # 浅黄
|
||||
elif node_type == 'sequence':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#d5e8d4' # 绿色
|
||||
elif node_type == 'selector':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#ffe6cc' # 橙色
|
||||
elif node_type == 'parallel':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
|
||||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
last_child_id = current_id
|
||||
for child in node.get("children", []):
|
||||
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
|
||||
if node_type in ['sequence']:
|
||||
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
|
||||
else: # Selector, Parallel
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
|
||||
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
|
||||
|
||||
return current_id
|
||||
|
||||
# ==============================================================================
|
||||
# CORE PYTREE GENERATOR CLASS
|
||||
# ==============================================================================
|
||||
|
||||
class PyTreeGenerator:
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -355,7 +424,6 @@ class PyTreeGenerator:
|
||||
# Updated output directory for visualizations
|
||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||
os.makedirs(self.vis_dir, exist_ok=True)
|
||||
|
||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
||||
|
||||
self.orin_ip = os.getenv("ORIN_IP", "localhost")
|
||||
@@ -363,7 +431,7 @@ class PyTreeGenerator:
|
||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
||||
)
|
||||
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
@@ -371,12 +439,11 @@ class PyTreeGenerator:
|
||||
# Explicitly use the remote embedding function for queries
|
||||
embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
|
||||
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
|
||||
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name="drone_docs",
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||
|
||||
@@ -423,7 +490,6 @@ class PyTreeGenerator:
|
||||
final_user_prompt += augmentation
|
||||
else:
|
||||
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
|
||||
|
||||
for attempt in range(3):
|
||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||
try:
|
||||
@@ -438,7 +504,6 @@ class PyTreeGenerator:
|
||||
)
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("成功生成并验证了Pytree。")
|
||||
plan_id = str(uuid.uuid4())
|
||||
@@ -449,15 +514,13 @@ class PyTreeGenerator:
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
|
||||
return pytree_dict
|
||||
else:
|
||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||
|
||||
except (OpenAIError, json.JSONDecodeError) as e:
|
||||
logging.error(f"生成Pytree时发生错误: {e}")
|
||||
|
||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||
|
||||
# Create a single instance for the application
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
Reference in New Issue
Block a user