# Source listing: 843 lines, 41 KiB, Python
import json
|
||
import os
|
||
import logging
|
||
import uuid
|
||
import re
|
||
from typing import Dict, Any, Optional, Set, List
|
||
import chromadb
|
||
import openai
|
||
from openai import OpenAIError
|
||
import jsonschema
|
||
import requests
|
||
import platform # 新增:用于选择合适的中文字体
|
||
|
||
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
|
||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """Chroma embedding function that delegates to a remote HTTP embedding API.

    The endpoint receives a POST of ``{"input": [...]}`` and is expected to
    answer with a JSON object whose ``"data"`` list holds one item per input
    document, each carrying an ``"embedding"`` entry.
    """

    def __init__(self, api_url: str, timeout: float = 30.0):
        """Remember the endpoint and the per-request timeout.

        Args:
            api_url: Full URL of the embeddings endpoint.
            timeout: Seconds to wait for the server before giving up. Without
                a timeout ``requests`` waits indefinitely, so a stalled
                embedding server would hang every query. Pass ``None`` to
                restore the unbounded wait.
        """
        self._api_url = api_url
        self._timeout = timeout

    # The parameter is named ``input`` (shadowing the builtin) because that is
    # the name used by the embedding-function interface this class implements.
    def __call__(self, input: Embeddable) -> Embeddings:
        """Return one embedding per document, or ``[]`` on any failure.

        Only a list made up exclusively of strings is sent to the API; any
        other input yields ``[]`` without a network call. Errors are logged
        and reported as ``[]`` rather than raised.
        """
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            return []
        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            if not data:
                return []
            return [item['embedding'] for item in data]
        except Exception as e:
            # Deliberate best-effort boundary: network errors, bad JSON and
            # malformed payloads are all reported the same way.
            logging.error(f"调用远程嵌入API时出错: {e}")
            return []
|
||
|
||
# --- Logging setup ---
# Configure the root logger at import time: INFO level, timestamped records,
# written through a StreamHandler.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
|
||
|
||
# ==============================================================================
|
||
# VALIDATION LOGIC (from utils/validation.py)
|
||
# ==============================================================================
|
||
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||
"""
|
||
从系统提示词中精确解析出允许的行动和条件节点。
|
||
"""
|
||
try:
|
||
# 使用更精确的正则表达式匹配节点定义部分
|
||
node_section_pattern = r"#### 1\. 可用节点定义.*?```json\s*({.*?})\s*```"
|
||
match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
|
||
|
||
if not match:
|
||
logging.error("在系统提示词中未找到'可用节点定义'部分的JSON代码块。")
|
||
# 备用方案:尝试查找所有JSON块并识别节点定义
|
||
return _fallback_parse_nodes(prompt_text)
|
||
|
||
json_str = match.group(1)
|
||
logging.info("成功找到节点定义JSON代码块")
|
||
|
||
# 解析JSON
|
||
allowed_nodes = json.loads(json_str)
|
||
|
||
# 从对象列表中提取节点名称
|
||
actions = set()
|
||
conditions = set()
|
||
|
||
# 提取动作节点
|
||
if "actions" in allowed_nodes and isinstance(allowed_nodes["actions"], list):
|
||
for action in allowed_nodes["actions"]:
|
||
if isinstance(action, dict) and "name" in action:
|
||
actions.add(action["name"])
|
||
|
||
# 提取条件节点
|
||
if "conditions" in allowed_nodes and isinstance(allowed_nodes["conditions"], list):
|
||
for condition in allowed_nodes["conditions"]:
|
||
if isinstance(condition, dict) and "name" in condition:
|
||
conditions.add(condition["name"])
|
||
|
||
if not actions:
|
||
logging.warning("关键错误:从提示词解析出的行动节点列表为空。")
|
||
|
||
logging.info(f"成功解析出动作节点: {sorted(actions)}")
|
||
logging.info(f"成功解析出条件节点: {sorted(conditions)}")
|
||
|
||
return actions, conditions
|
||
|
||
except json.JSONDecodeError as e:
|
||
logging.error(f"解析节点定义JSON时失败: {e}")
|
||
return set(), set()
|
||
except Exception as e:
|
||
logging.error(f"解析可用节点时发生未知错误: {e}")
|
||
return set(), set()
|
||
|
||
def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||
"""
|
||
备用解析方案:当精确匹配失败时使用。
|
||
"""
|
||
logging.warning("使用备用方案解析节点定义...")
|
||
|
||
# 查找所有JSON代码块
|
||
matches = re.findall(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
|
||
if not matches:
|
||
logging.error("在系统提示词中未找到任何JSON代码块。")
|
||
return set(), set()
|
||
|
||
# 尝试从每个JSON块中解析节点定义
|
||
for i, json_str in enumerate(matches):
|
||
try:
|
||
data = json.loads(json_str)
|
||
|
||
# 检查是否是节点定义的结构(包含actions、conditions、control_flow)
|
||
if ("actions" in data and isinstance(data["actions"], list) and
|
||
"conditions" in data and isinstance(data["conditions"], list) and
|
||
"control_flow" in data and isinstance(data["control_flow"], list)):
|
||
|
||
actions = set()
|
||
conditions = set()
|
||
|
||
# 提取动作节点
|
||
for action in data["actions"]:
|
||
if isinstance(action, dict) and "name" in action:
|
||
actions.add(action["name"])
|
||
|
||
# 提取条件节点
|
||
for condition in data["conditions"]:
|
||
if isinstance(condition, dict) and "name" in condition:
|
||
conditions.add(condition["name"])
|
||
|
||
if actions:
|
||
logging.info(f"从第{i+1}个JSON块中成功解析出节点定义")
|
||
logging.info(f"动作节点: {sorted(actions)}")
|
||
logging.info(f"条件节点: {sorted(conditions)}")
|
||
return actions, conditions
|
||
|
||
except json.JSONDecodeError:
|
||
continue # 尝试下一个JSON块
|
||
|
||
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
|
||
return set(), set()
|
||
|
||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||
"""
|
||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||
"""
|
||
# 所有可能的节点类型
|
||
node_types = ["action", "condition", "Sequence", "Selector", "Parallel", "decorator"]
|
||
|
||
# 目标检测相关的类别枚举
|
||
target_classes = [
|
||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
||
"traffic_light", "fire_hydrant", "stop_sign", "parking_meter", "bench", "bird", "cat",
|
||
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
|
||
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball",
|
||
"kite", "baseball_bat", "baseball_glove", "skateboard", "surfboard", "tennis_racket",
|
||
"bottle", "wine_glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
|
||
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
|
||
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
|
||
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
||
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon","trash","window"
|
||
]
|
||
|
||
# 递归节点定义
|
||
node_definition = {
|
||
"type": "object",
|
||
"properties": {
|
||
# 修改:手动构造不区分大小写的正则,避免使用不支持的 (?i) 标志
|
||
# 匹配: action, condition, sequence, selector, parallel, decorator (忽略大小写)
|
||
"type": {
|
||
"type": "string",
|
||
"pattern": "^([Aa][Cc][Tt][Ii][Oo][Nn]|[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]|[Ss][Ee][Qq][Uu][Ee][Nn][Cc][Ee]|[Ss][Ee][Ll][Ee][Cc][Tt][Oo][Rr]|[Pp][Aa][Rr][Aa][Ll][Ll][Ee][Ll]|[Dd][Ee][Cc][Oo][Rr][Aa][Tt][Oo][Rr])$"
|
||
},
|
||
"name": {"type": "string"},
|
||
"params": {"type": "object"},
|
||
"children": {
|
||
"type": "array",
|
||
"items": {"$ref": "#/definitions/node"}
|
||
},
|
||
"child": {"$ref": "#/definitions/node"}
|
||
},
|
||
"required": ["type", "name"],
|
||
"allOf": [
|
||
# 动作节点验证 (忽略大小写)
|
||
{
|
||
"if": {"properties": {"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"}}},
|
||
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
|
||
},
|
||
# 条件节点验证 (忽略大小写)
|
||
{
|
||
"if": {"properties": {"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"}}},
|
||
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
|
||
},
|
||
# 目标检测动作节点的参数验证 (忽略大小写)
|
||
{
|
||
"if": {
|
||
"properties": {
|
||
"type": {"pattern": "^[Aa][Cc][Tt][Ii][Oo][Nn]$"},
|
||
"name": {"const": "object_detect"}
|
||
}
|
||
},
|
||
"then": {
|
||
"properties": {
|
||
"params": {
|
||
"type": "object",
|
||
"properties": {
|
||
"target_class": {"type": "string", "enum": target_classes},
|
||
"description": {"type": "string"},
|
||
"count": {"type": "integer", "minimum": 1}
|
||
},
|
||
"required": ["target_class"],
|
||
"additionalProperties": False
|
||
}
|
||
}
|
||
}
|
||
},
|
||
# 目标检测条件节点的参数验证 (忽略大小写)
|
||
{
|
||
"if": {
|
||
"properties": {
|
||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||
"name": {"const": "object_detected"}
|
||
}
|
||
},
|
||
"then": {
|
||
"properties": {
|
||
"params": {
|
||
"type": "object",
|
||
"properties": {
|
||
"target_class": {"type": "string", "enum": target_classes},
|
||
"description": {"type": "string"},
|
||
"count": {"type": "integer", "minimum": 1}
|
||
},
|
||
"required": ["target_class"],
|
||
"additionalProperties": False
|
||
}
|
||
}
|
||
}
|
||
},
|
||
# 电池监控节点的参数验证 (忽略大小写)
|
||
{
|
||
"if": {
|
||
"properties": {
|
||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||
"name": {"const": "battery_above"}
|
||
}
|
||
},
|
||
"then": {
|
||
"properties": {
|
||
"params": {
|
||
"type": "object",
|
||
"properties": {
|
||
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
|
||
},
|
||
"required": ["threshold"],
|
||
"additionalProperties": False
|
||
}
|
||
}
|
||
}
|
||
},
|
||
# GPS状态节点的参数验证 (忽略大小写)
|
||
{
|
||
"if": {
|
||
"properties": {
|
||
"type": {"pattern": "^[Cc][Oo][Nn][Dd][Ii][Tt][Ii][Oo][Nn]$"},
|
||
"name": {"const": "gps_status"}
|
||
}
|
||
},
|
||
"then": {
|
||
"properties": {
|
||
"params": {
|
||
"type": "object",
|
||
"properties": {
|
||
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
|
||
},
|
||
"required": ["min_satellites"],
|
||
"additionalProperties": False
|
||
}
|
||
}
|
||
}
|
||
}
|
||
]
|
||
}
|
||
|
||
# 完整的Schema结构
|
||
schema = {
|
||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||
"title": "Pytree",
|
||
"definitions": {
|
||
"node": node_definition
|
||
},
|
||
"type": "object",
|
||
"properties": {
|
||
"root": { "$ref": "#/definitions/node" }
|
||
},
|
||
"required": ["root"]
|
||
}
|
||
|
||
return schema
|
||
|
||
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
|
||
"""
|
||
生成简单模式JSON Schema:{"root":{"type":"action","name":"...","params":{...}}}
|
||
简单模式与复杂模式使用相同的格式(root字段),但要求root必须是action类型且没有children。
|
||
严格按照提示词要求:root节点必须是action类型节点,不能是控制流节点(即不能有children)。
|
||
仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。
|
||
"""
|
||
schema = {
|
||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||
"title": "SimpleMode",
|
||
"type": "object",
|
||
"properties": {
|
||
"root": {
|
||
"type": "object",
|
||
"properties": {
|
||
"type": {"type": "string", "const": "action"}, # 必须是action类型,不能是控制流节点(Sequence/Selector/Parallel)
|
||
"name": {"type": "string", "enum": sorted(list(allowed_actions))}, # 动作名称必须在允许列表中
|
||
"params": {"type": "object"} # params是对象,具体参数由提示词和运行时约束
|
||
},
|
||
"required": ["type", "name"], # type和name是必需的,params可选
|
||
"additionalProperties": True # 允许root节点有其他属性(如额外的元数据)
|
||
# 注意:children字段的检查在验证后手动进行,因为JSON Schema的not/allOf在检查不存在字段时可能有问题
|
||
}
|
||
},
|
||
"required": ["root"], # 顶层必须有root字段
|
||
"additionalProperties": False # 顶层只能有root字段,不能有其他字段(如mode等)
|
||
}
|
||
return schema
|
||
|
||
def _params_owner_name(instance, path):
    """Return the name of the node whose ``params`` the error path enters.

    Walks *instance* along *path* (a sequence of dict keys / list indexes).
    Returns ``None`` when the path never steps into a ``"params"`` key, cannot
    be followed, or the owning node has no string name.
    """
    current = instance
    for key in path:
        if key == "params" and isinstance(current, dict):
            owner = current.get("name")
            return owner if isinstance(owner, str) else None
        try:
            current = current[key]
        except (KeyError, IndexError, TypeError):
            return None
    return None


def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
    """Validate a Pytree instance against a JSON Schema.

    Args:
        pytree_instance: The candidate tree (normally ``{"root": {...}}``).
        schema: The JSON Schema to validate against.

    Returns:
        ``True`` when validation succeeds. ``False`` on a validation error or
        any other exception raised while validating; both are logged and
        never propagated.
    """
    try:
        jsonschema.validate(instance=pytree_instance, schema=schema)
    except jsonschema.ValidationError as e:
        logging.warning("❌ Pytree验证失败")
        logging.warning(f"错误信息: {e.message}")
        error_path = list(e.path)
        logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")

        # Pick the hint from the node that owns the failing params rather than
        # from the message text: messages for params failures do not contain
        # the node name, so matching on the message picked the wrong hint.
        hints = {
            "object_detect": "💡 提示: 请确保目标类别是预定义列表中的有效值",
            "object_detected": "💡 提示: 请确保目标类别是预定义列表中的有效值",
            "battery_above": "💡 提示: 电池阈值必须在0.0到1.0之间",
            "gps_status": "💡 提示: 最小卫星数量必须在6到15之间",
        }
        hint = hints.get(_params_owner_name(pytree_instance, list(e.absolute_path)))
        if hint is not None:
            logging.warning(hint)
        return False
    except Exception as e:
        # E.g. an invalid schema; treated as a failed validation.
        logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
        return False

    logging.info("✅ JSON Schema验证成功")
    return True
|
||
|
||
# ==============================================================================
|
||
# VISUALIZATION LOGIC (from utils/visualization.py)
|
||
# ==============================================================================
|
||
def _visualize_pytree(node: Dict, file_path: str):
    """Render a Pytree dict to a PNG image with Graphviz.

    Args:
        node: Root node of the tree to draw.
        file_path: Destination path. A trailing ``.png`` is stripped before
            rendering so the output is not named ``*.png.png``.

    Returns nothing. A missing ``graphviz`` package or a rendering failure is
    logged instead of raised.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
        return

    # Pick a font that can draw Chinese text on the current platform.
    zh_fonts = {"Windows": "Microsoft YaHei", "Darwin": "PingFang SC"}
    fontname = zh_fonts.get(platform.system(), "Noto Sans CJK SC")

    graph = Digraph('Pytree', comment='Drone Mission Plan')
    graph.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
    graph.attr('node', shape='box', style='rounded,filled', fontname=fontname)
    graph.attr('edge', fontname=fontname)

    _add_nodes_and_edges(node, graph)

    try:
        stem, suffix = os.path.splitext(file_path)
        target = stem if suffix.lower() == '.png' else file_path

        # Make sure the output directory exists.
        parent_dir = os.path.dirname(target)
        if parent_dir and not os.path.exists(parent_dir):
            os.makedirs(parent_dir, exist_ok=True)

        # Render to PNG and remove the intermediate source file.
        rendered_path = graph.render(target, format='png', cleanup=True, view=False)
        logging.info("✅ 任务树可视化成功")
        logging.info(f"图形已保存到: {rendered_path}")
    except Exception as e:
        logging.error("❌ 生成可视化图形失败")
        logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
        logging.error(f"错误详情: {e}")
|
||
|
||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||
"""递归辅助函数,用于添加节点和边。"""
|
||
|
||
# 为每个节点创建一个唯一的ID(加上随机数避免冲突)
|
||
import random
|
||
import html
|
||
|
||
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
|
||
|
||
# 准备节点标签(HTML-like,正确换行与转义)
|
||
name = html.escape(str(node.get('name', '')))
|
||
ntype = html.escape(str(node.get('type', '')))
|
||
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
|
||
|
||
# 格式化参数显示
|
||
params = node.get('params') or {}
|
||
if params:
|
||
params_lines = []
|
||
for key, value in params.items():
|
||
k = html.escape(str(key))
|
||
if isinstance(value, float):
|
||
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
|
||
else:
|
||
value_str = str(value)
|
||
v = html.escape(value_str)
|
||
params_lines.append(f"{k}: {v}")
|
||
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
|
||
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
|
||
|
||
node_label = f"<{'<BR/>'.join(label_parts)}>"
|
||
|
||
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
|
||
node_type = (node.get('type') or '').lower()
|
||
shape = 'ellipse'
|
||
style = 'filled'
|
||
fillcolor = '#e6e6e6' # 默认灰色填充
|
||
border_color = '#666666' # 默认描边色
|
||
|
||
if node_type == 'action':
|
||
shape = 'box'
|
||
style = 'rounded,filled'
|
||
fillcolor = "#cde4ff" # 浅蓝
|
||
elif node_type == 'condition':
|
||
shape = 'diamond'
|
||
style = 'filled'
|
||
fillcolor = "#fff2cc" # 浅黄
|
||
elif node_type == 'sequence':
|
||
shape = 'ellipse'
|
||
style = 'filled'
|
||
fillcolor = '#d5e8d4' # 绿色
|
||
elif node_type == 'selector':
|
||
shape = 'ellipse'
|
||
style = 'filled'
|
||
fillcolor = '#ffe6cc' # 橙色
|
||
elif node_type == 'parallel':
|
||
shape = 'ellipse'
|
||
style = 'filled'
|
||
fillcolor = '#e1d5e7' # 紫色
|
||
elif node_type == 'decorator':
|
||
shape = 'doubleoctagon'
|
||
style = 'filled'
|
||
fillcolor = '#f8cecc' # 浅红
|
||
|
||
# 特别标记安全相关节点
|
||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||
border_color = '#ff0000' # 红色边框突出显示安全节点
|
||
style = 'filled,bold' # 加粗
|
||
|
||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||
|
||
# 连接父节点
|
||
if parent_id:
|
||
dot.edge(parent_id, current_id)
|
||
|
||
# 递归处理子节点 (Sequence, Selector, Parallel 等)
|
||
children = node.get("children", [])
|
||
if children:
|
||
# 记录所有子节点的ID
|
||
child_ids = []
|
||
|
||
# 正确的递归连接:每个子节点都连接到当前节点
|
||
for child in children:
|
||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||
child_ids.append(child_id)
|
||
|
||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||
if len(child_ids) > 1:
|
||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||
s.attr(rank='same')
|
||
for cid in child_ids:
|
||
s.node(cid)
|
||
|
||
# 递归处理单子节点 (Decorator)
|
||
child = node.get("child")
|
||
if child:
|
||
_add_nodes_and_edges(child, dot, current_id)
|
||
|
||
return current_id
|
||
|
||
# ==============================================================================
|
||
# CORE PYTREE GENERATOR CLASS
|
||
# ==============================================================================
|
||
class PyTreeGenerator:
    """Turn a natural-language request into a validated behavior-tree (Pytree) dict.

    A request is first classified as "simple" or "complex", then enriched
    with context retrieved from the vector store and sent to the matching
    model. The answer is parsed as JSON, validated against the matching
    schema and, on success, visualized and returned.
    """

    # Number of generation attempts before ``generate`` gives up.
    _MAX_ATTEMPTS = 3

    def __init__(self):
        """Load prompts, read configuration from the environment and open clients.

        Creates the visualization and reasoning output directories, builds
        three OpenAI-compatible clients (classifier / simple / complex), opens
        the persistent ChromaDB collection ``drone_docs`` and derives both
        validation schemas from the complex-mode prompt.
        """
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.prompts_dir = os.path.join(self.base_dir, 'prompts')

        # Output directory for rendered visualizations.
        self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
        os.makedirs(self.vis_dir, exist_ok=True)
        # Output directory for captured reasoning (Markdown files).
        self.reasoning_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_reasoning_content'))
        os.makedirs(self.reasoning_dir, exist_ok=True)
        # NOTE(review): this flag is stored but never read inside this class;
        # requests always force JSON output. Confirm whether external code uses it.
        self.enable_reasoning_capture = os.getenv("ENABLE_REASONING_CAPTURE", "true").lower() in ("1", "true", "yes")
        # Maximum number of reasoning lines echoed to the log.
        try:
            self.reasoning_preview_lines = int(os.getenv("REASONING_PREVIEW_LINES", "20"))
        except ValueError:
            self.reasoning_preview_lines = 20

        # Complex mode reuses system_prompt.txt; simple mode and the
        # classifier have their own prompt files.
        self.complex_prompt = self._load_prompt("system_prompt.txt")
        self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
        self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
        # Kept for compatibility with the old attribute name.
        self.system_prompt = self.complex_prompt

        self.orin_ip = os.getenv("ORIN_IP", "localhost")
        # Model name and base URL are configurable per role.
        default_model = os.getenv("OPENAI_MODEL", "local-model")
        default_base_url = f"http://{self.orin_ip}:8081/v1"
        self.classifier_model = os.getenv("CLASSIFIER_MODEL", default_model)
        self.simple_model = os.getenv("SIMPLE_MODEL", default_model)
        self.complex_model = os.getenv("COMPLEX_MODEL", default_model)
        self.classifier_base_url = os.getenv("CLASSIFIER_BASE_URL", default_base_url)
        self.simple_base_url = os.getenv("SIMPLE_BASE_URL", default_base_url)
        self.complex_base_url = os.getenv("COMPLEX_BASE_URL", default_base_url)
        self.api_key = os.getenv("OPENAI_API_KEY", "sk-no-key-required")
        # Output token limits are fixed in code, not read from the environment.
        self.classifier_max_tokens = 512
        self.simple_max_tokens = 8192
        self.complex_max_tokens = 8192

        # One client per role.
        self.classifier_client = openai.OpenAI(api_key=self.api_key, base_url=self.classifier_base_url)
        self.simple_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.simple_base_url)
        self.complex_llm_client = openai.OpenAI(api_key=self.api_key, base_url=self.complex_base_url)

        # --- ChromaDB client setup ---
        vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'rag', 'vector_store'))
        self.chroma_client = chromadb.PersistentClient(path=vector_store_path)

        # Queries are embedded through the remote embedding API.
        embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
        embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
        self.collection = self.chroma_client.get_collection(
            name="drone_docs",
            embedding_function=embedding_func
        )

        # Both schemas take their node names from the complex-mode prompt.
        allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
        self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
        self.simple_schema = _generate_simple_mode_schema(allowed_actions)

    def _load_prompt(self, file_name: str) -> str:
        """Return the text of ``prompts_dir/file_name``, or ``""`` if it is missing."""
        try:
            with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"提示词文件未找到 -> {file_name}")
            return ""

    def _retrieve_context(self, query: str) -> Optional[str]:
        """Return up to five retrieved documents joined by blank lines.

        Returns ``None`` when nothing is found or the lookup fails; failures
        are logged, never raised.
        """
        logging.info("--- 开始从向量数据库检索上下文 ---")
        try:
            results = self.collection.query(query_texts=[query], n_results=5)
            retrieved_docs = results.get("documents", [[]])[0]
            if not retrieved_docs:
                logging.warning("在向量数据库中没有找到相关的上下文信息。")
                return None
            context_str = "\n\n".join(retrieved_docs)
            logging.info("--- 成功检索到上下文信息 ---")
            logging.info(f"📚 检索到的上下文内容:\n{context_str}")
            return context_str
        except Exception as e:
            logging.error(f"从向量数据库检索时发生错误: {e}")
            return None

    def _classify(self, user_prompt: str) -> str:
        """Ask the classifier model for the mode; default to ``"complex"`` on any failure."""
        mode = "complex"
        try:
            classifier_resp = self.classifier_client.chat.completions.create(
                model=self.classifier_model,
                messages=[
                    {"role": "system", "content": self.classifier_prompt or "你是一个分类器,只输出JSON。"},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.0,
                response_format={"type": "json_object"},
                max_tokens=self.classifier_max_tokens,
                # Asks Qwen3-style servers to disable "thinking" output.
                extra_body={"chat_template_kwargs": {"enable_thinking": False}}
            )
            verdict = json.loads(classifier_resp.choices[0].message.content)
            if isinstance(verdict, dict) and verdict.get("mode") in ("simple", "complex"):
                mode = verdict["mode"]
            logging.info(f"分类结果: {mode}")
        except Exception as e:
            logging.warning(f"分类失败,默认按复杂指令处理: {e}")
        return mode

    @staticmethod
    def _extract_message_parts(response):
        """Return ``(content, reasoning_content)`` of the first choice.

        Accepts both attribute-style and dict-style responses. Returns
        ``(None, None)`` when there is no first choice or it has no message,
        so the caller can treat the reply as unparseable and retry.
        """
        try:
            choice = response.choices[0]
        except (AttributeError, IndexError, TypeError):
            return None, None
        if isinstance(choice, dict):
            message = choice.get("message")
        else:
            message = getattr(choice, "message", None)
        if isinstance(message, dict):
            return message.get("content"), message.get("reasoning_content")
        return getattr(message, "content", None), getattr(message, "reasoning_content", None)

    @staticmethod
    def _split_reasoning(text: str):
        """Split *text* into ``(reasoning, remainder)``.

        ``reasoning`` is the stripped body of the first ``<think>`` element
        (``None`` if there is none). When one is found, every ``<think>``
        element is removed from the remainder and the result is stripped;
        otherwise *text* is returned unchanged.
        """
        found = re.search(r"<think>([\s\S]*?)</think>", text)
        if not found:
            return None, text
        remainder = re.sub(r"<think>[\s\S]*?</think>", "", text).strip()
        return found.group(1).strip(), remainder

    @staticmethod
    def _log_raw_response(response) -> None:
        """Log the full response object to help diagnose an unparseable reply."""
        try:
            if hasattr(response, 'model_dump_json'):
                raw_response_dump = response.model_dump_json(indent=2, exclude_none=False)
            elif hasattr(response, 'dict'):
                raw_response_dump = json.dumps(response.dict(), ensure_ascii=False, indent=2, default=str)
            else:
                # Last resort: expand the choices and a few key fields by hand.
                safe_obj = {
                    "id": getattr(response, 'id', None),
                    "model": getattr(response, 'model', None),
                    "object": getattr(response, 'object', None),
                    "usage": getattr(response, 'usage', None),
                    "choices": [
                        {
                            "index": getattr(c, 'index', None),
                            "finish_reason": getattr(c, 'finish_reason', None),
                            "message": {
                                "role": getattr(getattr(c, 'message', None), 'role', None),
                                "content": getattr(getattr(c, 'message', None), 'content', None),
                                "reasoning_content": getattr(getattr(c, 'message', None), 'reasoning_content', None)
                            } if getattr(c, 'message', None) is not None else None
                        }
                        for c in getattr(response, 'choices', [])
                    ] if hasattr(response, 'choices') else None
                }
                raw_response_dump = json.dumps(safe_obj, ensure_ascii=False, indent=2, default=str)
            logging.error(f"—— 完整响应对象 ——\n{raw_response_dump}")
        except Exception:
            # Diagnostics only. %r is formatted lazily by the logging module,
            # which reports a failing repr itself instead of raising here.
            logging.error("响应对象转储失败,repr如下:\n%r", response)

    def _validate_simple(self, pytree_dict) -> bool:
        """Check a simple-mode result: schema-valid and root without children."""
        try:
            jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
        except jsonschema.ValidationError as e:
            logging.warning(f"❌ 简单模式验证失败: {e.message}")
            return False
        # The schema does not forbid children, so that is checked by hand.
        children = pytree_dict.get('root', {}).get('children')
        if isinstance(children, list) and children:
            logging.warning(f"❌ 简单模式验证失败: root节点不能有children,但发现 {len(children)} 个子节点")
            return False
        logging.info("✅ 简单模式JSON Schema验证成功")
        return True

    def _save_reasoning(self, reasoning_text: Optional[str]) -> None:
        """Write the reasoning text to ``reasoning_content.md`` and log a preview.

        Best effort: a failure is logged as a warning and otherwise ignored.
        """
        if not reasoning_text:
            # NOTE(review): the hint below names ENABLE_REASONING_CAPTURE, but
            # nothing in this class reads that setting.
            logging.info("未在模型输出中发现 <think> 推理链片段。若需捕获,请设置 ENABLE_REASONING_CAPTURE=true 以放宽JSON强制格式。")
            return
        try:
            reasoning_path = os.path.join(self.reasoning_dir, "reasoning_content.md")
            with open(reasoning_path, 'w', encoding='utf-8') as rf:
                rf.write(reasoning_text)
            logging.info(f"📝 推理链已保存: {reasoning_path}")
            preview = "\n".join(reasoning_text.splitlines()[: self.reasoning_preview_lines])
            logging.info("🧠 推理链预览(前%d行):\n%s", self.reasoning_preview_lines, preview)
        except Exception as e:
            logging.warning(f"保存推理链Markdown失败: {e}")

    def _finalize_plan(self, pytree_dict: Dict[str, Any], mode: str,
                       reasoning_text: Optional[str], final_prompt: str) -> Dict[str, Any]:
        """Attach ``plan_id``, visualization URL and prompt to a validated tree.

        Visualization is auxiliary: if drawing fails the plan is still
        returned, without ``visualization_url``.
        """
        pytree_dict['plan_id'] = str(uuid.uuid4())

        vis_filename = "py_tree.png"
        try:
            vis_path = os.path.join(self.vis_dir, vis_filename)
            _visualize_pytree(pytree_dict.get('root', {}), os.path.splitext(vis_path)[0])
            pytree_dict['visualization_url'] = f"/static/{vis_filename}"
        except Exception as e:
            mode_label = "简单模式" if mode == "simple" else "复杂模式"
            logging.warning(f"{mode_label}可视化失败: {e}")

        self._save_reasoning(reasoning_text)
        pytree_dict['final_prompt'] = final_prompt
        return pytree_dict

    async def generate(self, user_prompt: str) -> Dict[str, Any]:
        """Generate a validated Pytree dict for *user_prompt*.

        Returns:
            The parsed tree with ``plan_id``, ``final_prompt`` and (when
            drawing succeeded) ``visualization_url`` added.

        Raises:
            RuntimeError: No valid tree was produced within the allowed
                number of attempts.

        NOTE(review): although declared ``async``, this method never awaits;
        the model and vector-store calls are synchronous, so they block the
        caller's event loop while they run.
        """
        logging.info(f"接收到用户请求: {user_prompt}")

        # Step 1: classify the request as simple or complex.
        mode = self._classify(user_prompt)

        # Step 2: pick prompt, client and limits for the mode. Complex mode
        # gets an extra rule so the model does not answer in simple format.
        if mode == "simple":
            use_prompt = self.simple_prompt
            client = self.simple_llm_client
            model_name = self.simple_model
            max_tokens = self.simple_max_tokens
            temperature = 0.0
        else:
            use_prompt = (
                (self.complex_prompt or "") +
                "\n\n【强制规则】仅生成包含root的复杂行为树JSON,不得输出简单模式(不得包含mode字段或仅有action节点)。"
            )
            client = self.complex_llm_client
            model_name = self.complex_model
            max_tokens = self.complex_max_tokens
            temperature = 0.1

        # Retrieved context is appended to the user prompt, not the system prompt.
        final_user_prompt = user_prompt
        retrieved_context = self._retrieve_context(user_prompt)
        if retrieved_context:
            final_user_prompt += (
                "\n\n---\n"
                "参考知识:\n"
                "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n"
                f"{retrieved_context}"
                "\n---"
            )
        else:
            logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")

        # Record of exactly what is sent to the model, returned with the plan.
        final_prompt = f"=== System Prompt ===\n{use_prompt}\n\n=== User Prompt ===\n{final_user_prompt}"

        total = self._MAX_ATTEMPTS
        for attempt in range(1, total + 1):
            logging.info(f"--- 第 {attempt}/{total} 次尝试生成Pytree ---")
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": use_prompt},
                        {"role": "user", "content": final_user_prompt}
                    ],
                    temperature=temperature,
                    response_format={"type": "json_object"},
                    # Asks Qwen3-style servers to disable "thinking" output.
                    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
                    max_tokens=max_tokens,
                )
            except OpenAIError as e:
                logging.error(f"生成Pytree时发生错误: {e}")
                continue

            # A separate reasoning_content field is wrapped in <think> so it
            # is handled the same way as inline reasoning.
            msg_content, msg_reasoning = self._extract_message_parts(response)
            raw_text = ""
            if isinstance(msg_reasoning, str) and msg_reasoning.strip():
                raw_text += f"<think>\n{msg_reasoning}\n</think>\n"
            if isinstance(msg_content, str) and msg_content.strip():
                raw_text += msg_content
            if not raw_text:
                # Missing or non-text content becomes "" and fails JSON parsing below.
                raw_text = msg_content if isinstance(msg_content, str) else ""

            reasoning_text, json_text = self._split_reasoning(raw_text)

            try:
                pytree_dict = json.loads(json_text)
            except json.JSONDecodeError:
                logging.error(f"❌ JSON解析失败(第 {attempt}/{total} 次)。\n—— 完整原始文本(含<think>) ——\n{raw_text}")
                self._log_raw_response(response)
                continue

            if mode == "simple":
                if not self._validate_simple(pytree_dict):
                    continue
            elif _validate_pytree_with_schema(pytree_dict, self.schema):
                logging.info("✅ 成功生成并验证了Pytree")
            else:
                # Log the rejected tree to make the failure easy to inspect.
                preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
                logging.warning(f"❌ 未通过验证的Pytree(第 {attempt}/{total} 次尝试):\n{preview}")
                logging.warning("生成的Pytree验证失败,正在重试...")
                continue

            return self._finalize_plan(pytree_dict, mode, reasoning_text, final_prompt)

        raise RuntimeError(f"在{total}次尝试后,仍未能生成一个有效的Pytree。")
|
||
|
||
# Create a single instance for the application.
# Importing this module therefore runs PyTreeGenerator.__init__, which creates
# the output directories, reads the prompt files and opens the ChromaDB collection.
py_tree_generator = PyTreeGenerator()
|