代码内容转移
This commit is contained in:
330
backend_service/src/py_tree_generator.py
Normal file
330
backend_service/src/py_tree_generator.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Set
|
||||
|
||||
import chromadb
|
||||
import openai
|
||||
from openai import OpenAIError
|
||||
import jsonschema
|
||||
import requests
|
||||
|
||||
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
def __init__(self, api_url: str):
|
||||
self._api_url = api_url
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
|
||||
return []
|
||||
try:
|
||||
response = requests.post(
|
||||
self._api_url,
|
||||
json={"input": input},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json().get("data", [])
|
||||
if not data:
|
||||
return []
|
||||
return [item['embedding'] for item in data]
|
||||
except Exception as e:
|
||||
logging.error(f"调用远程嵌入API时出错: {e}")
|
||||
return []
|
||||
|
||||
# --- 日志记录设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# VALIDATION LOGIC (from utils/validation.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
"""
|
||||
从系统提示词中解析出允许的行动和条件节点。
|
||||
"""
|
||||
try:
|
||||
# 使用正则表达式查找JSON代码块
|
||||
match = re.search(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
|
||||
if not match:
|
||||
logging.error("在系统提示词中未找到可用节点的JSON定义块。")
|
||||
return set(), set()
|
||||
|
||||
json_str = match.group(1)
|
||||
allowed_nodes = json.loads(json_str)
|
||||
|
||||
# 从对象列表中提取节点名称
|
||||
actions = {action['name'] for action in allowed_nodes.get("actions", []) if 'name' in action}
|
||||
conditions = {condition['name'] for condition in allowed_nodes.get("conditions", []) if 'name' in condition}
|
||||
|
||||
if not actions:
|
||||
logging.warning("关键错误:从提示词解析出的行动节点列表为空,无法生成任何有效任务。")
|
||||
|
||||
return actions, conditions
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logging.error("解析系统提示词中的JSON时失败。请检查格式。")
|
||||
return set(), set()
|
||||
except Exception as e:
|
||||
logging.error(f"解析可用节点时发生未知错误: {e}")
|
||||
return set(), set()
|
||||
|
||||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||||
"""
|
||||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||||
"""
|
||||
# 递归节点定义
|
||||
node_definition = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string", "enum": ["action", "condition", "Sequence", "Selector"]},
|
||||
"name": {"type": "string"},
|
||||
"params": {"type": "object"},
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/node"}
|
||||
}
|
||||
},
|
||||
"required": ["type", "name"],
|
||||
# 使用 allOf 和 if/then 来实现基于'type'字段的条件性'name'验证
|
||||
"allOf": [
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "action"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
|
||||
},
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "condition"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 完整的Schema结构
|
||||
schema = {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "Pytree",
|
||||
"definitions": {
|
||||
"node": node_definition
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"root": { "$ref": "#/definitions/node" }
|
||||
},
|
||||
"required": ["root"]
|
||||
}
|
||||
return schema
|
||||
|
||||
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
"""
|
||||
使用JSON Schema验证给定的Pytree实例。
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_instance, schema=schema)
|
||||
logging.info("验证成功!Pytree格式和内容均符合规范。")
|
||||
return True
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning("--- Pytree验证失败 ---")
|
||||
logging.warning(f"错误信息: {e.message}")
|
||||
error_path = list(e.path)
|
||||
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
|
||||
logging.warning(f"出错的实例部分: {e.instance}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
# ==============================================================================
|
||||
# VISUALIZATION LOGIC (from utils/visualization.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
output_path = dot.render(file_path, format='png', cleanup=True, view=False)
|
||||
logging.info("--- 任务树可视化成功 ---")
|
||||
logging.info(f"图形已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
logging.error("--- 错误:生成可视化图形失败 ---")
|
||||
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
|
||||
logging.error(f"错误详情: {e}")
|
||||
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
|
||||
# 为每个节点创建一个唯一的ID
|
||||
current_id = str(id(node))
|
||||
|
||||
# 准备节点标签
|
||||
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
|
||||
if node.get('params'):
|
||||
params_str = json.dumps(node.get('params'))
|
||||
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
|
||||
node_label += ">"
|
||||
|
||||
# 根据类型设置节点样式
|
||||
node_type = node.get('type', '').lower()
|
||||
if node_type == 'action':
|
||||
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
|
||||
elif node_type == 'condition':
|
||||
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
|
||||
else: # Sequence, Selector, etc.
|
||||
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
last_child_id = current_id
|
||||
for child in node.get("children", []):
|
||||
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
|
||||
if node_type in ['sequence']:
|
||||
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
|
||||
else: # Selector, Parallel
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
return current_id
|
||||
|
||||
# ==============================================================================
|
||||
# CORE PYTREE GENERATOR CLASS
|
||||
# ==============================================================================
|
||||
|
||||
class PyTreeGenerator:
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.prompts_dir = os.path.join(self.base_dir, 'prompts')
|
||||
|
||||
# Updated output directory for visualizations
|
||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||
os.makedirs(self.vis_dir, exist_ok=True)
|
||||
|
||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
||||
|
||||
self.orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
|
||||
self.llm_client = openai.OpenAI(
|
||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
||||
)
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
|
||||
# Explicitly use the remote embedding function for queries
|
||||
embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
|
||||
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
|
||||
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name="drone_docs",
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||
|
||||
def _load_prompt(self, file_name: str) -> str:
|
||||
try:
|
||||
with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logging.error(f"提示词文件未找到 -> {file_name}")
|
||||
return ""
|
||||
|
||||
def _retrieve_context(self, query: str) -> Optional[str]:
|
||||
logging.info("--- 开始从向量数据库检索上下文 ---")
|
||||
try:
|
||||
results = self.collection.query(query_texts=[query], n_results=5)
|
||||
retrieved_docs = results.get("documents", [[]])[0]
|
||||
if not retrieved_docs:
|
||||
logging.warning("在向量数据库中没有找到相关的上下文信息。")
|
||||
return None
|
||||
context_str = "\n\n".join(retrieved_docs)
|
||||
logging.info("--- 成功检索到上下文信息 ---")
|
||||
return context_str
|
||||
except Exception as e:
|
||||
logging.error(f"从向量数据库检索时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def generate(self, user_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates a py_tree.json structure based on the user's prompt.
|
||||
"""
|
||||
logging.info(f"接收到用户请求: {user_prompt}")
|
||||
|
||||
retrieved_context = self._retrieve_context(user_prompt)
|
||||
|
||||
final_user_prompt = user_prompt
|
||||
if retrieved_context:
|
||||
augmentation = (
|
||||
"\n\n---\n"
|
||||
"参考知识:\n"
|
||||
"以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成行为树:\n"
|
||||
f"{retrieved_context}"
|
||||
"\n---"
|
||||
)
|
||||
final_user_prompt += augmentation
|
||||
else:
|
||||
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
|
||||
|
||||
for attempt in range(3):
|
||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||
try:
|
||||
response = self.llm_client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": final_user_prompt}
|
||||
],
|
||||
temperature=0.1,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("成功生成并验证了Pytree。")
|
||||
plan_id = str(uuid.uuid4())
|
||||
pytree_dict['plan_id'] = plan_id
|
||||
|
||||
# Generate visualization to a static path
|
||||
vis_filename = "py_tree.png"
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
|
||||
return pytree_dict
|
||||
else:
|
||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||
|
||||
except (OpenAIError, json.JSONDecodeError) as e:
|
||||
logging.error(f"生成Pytree时发生错误: {e}")
|
||||
|
||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||
|
||||
# Create a single instance for the application
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
Reference in New Issue
Block a user