代码内容转移

This commit is contained in:
2025-08-17 22:41:54 +08:00
commit 0b50022af1
38 changed files with 72624 additions and 0 deletions

View File

@@ -0,0 +1,330 @@
import json
import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set
import chromadb
import openai
from openai import OpenAIError
import jsonschema
import requests
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """ChromaDB embedding function that delegates to a remote HTTP embedding API.

    The endpoint receives ``{"input": [...]}`` as a JSON body and is expected
    to answer with a JSON object whose ``data`` list contains one
    ``{"embedding": [...]}`` entry per input document.
    """

    def __init__(self, api_url: str, timeout: float = 30.0):
        """Remember the endpoint and the per-request timeout.

        Args:
            api_url: Full URL of the embeddings endpoint.
            timeout: Seconds to wait for the server before the request is
                abandoned.
        """
        self._api_url = api_url
        self._timeout = timeout

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed a list of strings through the remote API.

        Args:
            input: Documents to embed; anything other than a list of ``str``
                is rejected.

        Returns:
            The embeddings in the order the server returned them, or an empty
            list when the input is not a list of strings, the response carries
            no data, or any error occurs (errors are logged, not raised).
        """
        # Only a list of plain strings is forwarded to the remote API.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            return []
        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                # Without a timeout, requests waits indefinitely on a stalled
                # server and the caller never regains control.
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            if not data:
                return []
            return [item['embedding'] for item in data]
        except Exception as e:
            # Best-effort boundary: any failure degrades to "no embeddings".
            logging.error(f"调用远程嵌入API时出错: {e}")
            return []
# --- Logging setup ---
# Root logger: INFO and above, timestamped records, emitted through a
# default-constructed StreamHandler.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
从系统提示词中解析出允许的行动和条件节点。
"""
try:
# 使用正则表达式查找JSON代码块
match = re.search(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
if not match:
logging.error("在系统提示词中未找到可用节点的JSON定义块。")
return set(), set()
json_str = match.group(1)
allowed_nodes = json.loads(json_str)
# 从对象列表中提取节点名称
actions = {action['name'] for action in allowed_nodes.get("actions", []) if 'name' in action}
conditions = {condition['name'] for condition in allowed_nodes.get("conditions", []) if 'name' in condition}
if not actions:
logging.warning("关键错误:从提示词解析出的行动节点列表为空,无法生成任何有效任务。")
return actions, conditions
except json.JSONDecodeError:
logging.error("解析系统提示词中的JSON时失败。请检查格式。")
return set(), set()
except Exception as e:
logging.error(f"解析可用节点时发生未知错误: {e}")
return set(), set()
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
"""
# 递归节点定义
node_definition = {
"type": "object",
"properties": {
"type": {"type": "string", "enum": ["action", "condition", "Sequence", "Selector"]},
"name": {"type": "string"},
"params": {"type": "object"},
"children": {
"type": "array",
"items": {"$ref": "#/definitions/node"}
}
},
"required": ["type", "name"],
# 使用 allOf 和 if/then 来实现基于'type'字段的条件性'name'验证
"allOf": [
{
"if": {"properties": {"type": {"const": "action"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
},
{
"if": {"properties": {"type": {"const": "condition"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
}
]
}
# 完整的Schema结构
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Pytree",
"definitions": {
"node": node_definition
},
"type": "object",
"properties": {
"root": { "$ref": "#/definitions/node" }
},
"required": ["root"]
}
return schema
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
    """Check a Pytree instance against a JSON Schema.

    Args:
        pytree_instance: The candidate Pytree document.
        schema: The JSON Schema to validate against.

    Returns:
        ``True`` when validation succeeds. ``False`` when the instance
        violates the schema (message, path and offending fragment are logged
        as warnings) or when validation fails for any other reason.
    """
    is_valid = False
    try:
        jsonschema.validate(instance=pytree_instance, schema=schema)
        logging.info("验证成功Pytree格式和内容均符合规范。")
        is_valid = True
    except jsonschema.ValidationError as violation:
        logging.warning("--- Pytree验证失败 ---")
        logging.warning(f"错误信息: {violation.message}")
        # Render the location of the failing element as "a -> b -> c".
        steps = [str(step) for step in violation.path]
        if steps:
            location = ' -> '.join(steps)
        else:
            location = '根节点'
        logging.warning(f"错误路径: {location}")
        logging.warning(f"出错的实例部分: {violation.instance}")
    except Exception as unexpected:
        logging.error(f"进行JSON Schema验证时发生未知错误: {unexpected}")
    return is_valid
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
    """Render a Pytree dictionary as a PNG image using Graphviz.

    Args:
        node: Root node of the tree to draw.
        file_path: Path handed to Graphviz ``render`` as the output file name.

    Returns:
        None. A missing ``graphviz`` package or a rendering failure is logged
        rather than raised.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
        return

    node_defaults = {'shape': 'box', 'style': 'rounded,filled', 'fontname': 'helvetica'}
    graph_defaults = {'rankdir': 'TB', 'label': 'Drone Mission Plan', 'fontsize': '20'}

    graph = Digraph('Pytree', comment='Drone Mission Plan')
    graph.attr('node', **node_defaults)
    graph.attr(**graph_defaults)
    _add_nodes_and_edges(node, graph)

    try:
        # Render to PNG; cleanup=True removes the intermediate source file.
        rendered_path = graph.render(file_path, format='png', cleanup=True, view=False)
        logging.info("--- 任务树可视化成功 ---")
        logging.info(f"图形已保存到: {rendered_path}")
    except Exception as render_error:
        logging.error("--- 错误:生成可视化图形失败 ---")
        logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
        logging.error(f"错误详情: {render_error}")
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID
current_id = str(id(node))
# 准备节点标签
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
if node.get('params'):
params_str = json.dumps(node.get('params'))
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
node_label += ">"
# 根据类型设置节点样式
node_type = node.get('type', '').lower()
if node_type == 'action':
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
elif node_type == 'condition':
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
else: # Sequence, Selector, etc.
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
last_child_id = current_id
for child in node.get("children", []):
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
if node_type in ['sequence']:
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
else: # Selector, Parallel
_add_nodes_and_edges(child, dot, current_id)
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
    """Generate, validate and visualise behaviour-tree ("Pytree") plans with an LLM.

    On construction the generator loads the system prompt, creates an
    OpenAI-compatible client and a ChromaDB collection (both addressed through
    the ``ORIN_IP`` host), and derives the validation schema from the node
    list embedded in the system prompt.
    """

    # Number of LLM generation attempts before ``generate`` gives up.
    MAX_ATTEMPTS = 3

    def __init__(self):
        """Set up paths, the system prompt, the LLM client, the vector store and the schema."""
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.prompts_dir = os.path.join(self.base_dir, 'prompts')
        # Output directory for rendered visualisations (created if missing).
        self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
        os.makedirs(self.vis_dir, exist_ok=True)
        self.system_prompt = self._load_prompt("system_prompt.txt")
        self.orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
        self.llm_client = openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
            base_url=f"http://{self.orin_ip}:8081/v1"
        )
        # --- ChromaDB client setup ---
        vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
        self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
        # Explicitly use the remote embedding function for queries.
        embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
        embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
        self.collection = self.chroma_client.get_collection(
            name="drone_docs",
            embedding_function=embedding_func
        )
        # The allowed node names come from the JSON block inside the system prompt.
        allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
        self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)

    def _load_prompt(self, file_name: str) -> str:
        """Read a prompt file from the prompts directory.

        Args:
            file_name: File name relative to ``self.prompts_dir``.

        Returns:
            The file content, or an empty string when the file does not exist
            (the miss is logged).
        """
        try:
            with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"提示词文件未找到 -> {file_name}")
            return ""

    def _retrieve_context(self, query: str) -> Optional[str]:
        """Fetch up to five documents related to ``query`` from the vector store.

        Args:
            query: Text used as the similarity query.

        Returns:
            The retrieved documents joined by blank lines, or ``None`` when
            nothing was found or the lookup failed (failures are logged).
        """
        logging.info("--- 开始从向量数据库检索上下文 ---")
        try:
            results = self.collection.query(query_texts=[query], n_results=5)
            retrieved_docs = results.get("documents", [[]])[0]
            if not retrieved_docs:
                logging.warning("在向量数据库中没有找到相关的上下文信息。")
                return None
            context_str = "\n\n".join(retrieved_docs)
            logging.info("--- 成功检索到上下文信息 ---")
            return context_str
        except Exception as e:
            logging.error(f"从向量数据库检索时发生错误: {e}")
            return None

    async def generate(self, user_prompt: str) -> Dict[str, Any]:
        """Generate a validated Pytree structure for the user's request.

        The user prompt is augmented with retrieved context when available,
        then the LLM is asked for a JSON object up to ``MAX_ATTEMPTS`` times
        until one passes schema validation. A successful tree is rendered to
        an image and returned with ``plan_id`` and ``visualization_url`` added.

        Args:
            user_prompt: Natural-language task description.

        Returns:
            The validated Pytree dictionary.

        Raises:
            RuntimeError: When no attempt produced a valid Pytree; chained to
                the most recent API/JSON error, if there was one.
        """
        logging.info(f"接收到用户请求: {user_prompt}")
        # NOTE(review): retrieval, the LLM request and rendering below are
        # synchronous calls inside a coroutine, so they block the event loop
        # while they run.
        retrieved_context = self._retrieve_context(user_prompt)
        final_user_prompt = user_prompt
        if retrieved_context:
            augmentation = (
                "\n\n---\n"
                "参考知识:\n"
                "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成行为树:\n"
                f"{retrieved_context}"
                "\n---"
            )
            final_user_prompt += augmentation
        else:
            logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")

        last_error: Optional[Exception] = None
        for attempt in range(1, self.MAX_ATTEMPTS + 1):
            logging.info(f"--- 第 {attempt}/{self.MAX_ATTEMPTS} 次尝试生成Pytree ---")
            try:
                response = self.llm_client.chat.completions.create(
                    model="local-model",
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": final_user_prompt}
                    ],
                    temperature=0.1,
                    response_format={"type": "json_object"}
                )
                # A reply with no choices or with content=None would make the
                # indexing or json.loads raise IndexError/TypeError and abort
                # the retry loop, so treat it as a failed attempt instead.
                choices = response.choices
                pytree_str = choices[0].message.content if choices else None
                if not pytree_str:
                    logging.warning("LLM返回了空响应,正在重试...")
                    continue
                pytree_dict = json.loads(pytree_str)
                if _validate_pytree_with_schema(pytree_dict, self.schema):
                    logging.info("成功生成并验证了Pytree。")
                    plan_id = str(uuid.uuid4())
                    pytree_dict['plan_id'] = plan_id
                    # The visualisation is always written to the same static
                    # file name, so each plan overwrites the previous image.
                    vis_filename = "py_tree.png"
                    vis_path = os.path.join(self.vis_dir, vis_filename)
                    _visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
                    pytree_dict['visualization_url'] = f"/static/{vis_filename}"
                    return pytree_dict
                else:
                    logging.warning("生成的Pytree验证失败正在重试...")
            except (OpenAIError, json.JSONDecodeError) as e:
                last_error = e
                logging.error(f"生成Pytree时发生错误: {e}")
        raise RuntimeError(
            f"在{self.MAX_ATTEMPTS}次尝试后仍未能生成一个有效的Pytree。"
        ) from last_error
# Module-level singleton shared by the application. It is constructed when
# this module is imported, so PyTreeGenerator.__init__ runs at import time.
py_tree_generator = PyTreeGenerator()