代码内容转移

This commit is contained in:
2025-08-17 22:41:54 +08:00
commit 0b50022af1
38 changed files with 72624 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,93 @@
import asyncio
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
import logging
import threading
import rclpy
from .models import GeneratePlanRequest, ExecuteMissionRequest
from .websocket_manager import websocket_manager
from .py_tree_generator import py_tree_generator
from .ros2_client import MissionActionClient
# --- Application Setup ---
app = FastAPI(
    title="Drone Backend Service",
    description="Handles mission planning, generation, and execution for the drone.",
    version="1.0.0",
)

# --- Mount Static Files for Visualizations ---
# The visualization images live in ``generated_visualizations``, one directory
# above this module's directory, and are served under the ``/static`` prefix.
static_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', 'generated_visualizations')
)
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# --- ROS2 Node and Client Initialization ---
# Both calls run at import time; the node is only spun later, from the
# background thread started in the startup hook.
rclpy.init()
ros2_client = MissionActionClient()
def run_ros2_node():
    """Spin the shared ROS2 client node, logging when spinning starts and stops.

    Intended as a thread target: the call blocks inside ``rclpy.spin`` and
    only returns once spinning has ended.
    """
    logging.info("Starting to spin ROS2 node...")
    rclpy.spin(ros2_client)
    logging.info("ROS2 node has stopped spinning.")
# --- API Endpoints ---
@app.post("/generate_plan", response_model=dict)
async def generate_plan_endpoint(request: GeneratePlanRequest):
    """
    Generate a mission plan from the user's prompt.

    Returns the dictionary produced by the generator. A ``RuntimeError``
    raised by the generator is not propagated; it is reported to the client
    as ``{"error": "<message>"}`` instead.
    """
    try:
        plan = await py_tree_generator.generate(request.user_prompt)
    except RuntimeError as exc:
        return {"error": str(exc)}
    return plan
@app.post("/execute_mission", response_model=dict)
async def execute_mission_endpoint(request: ExecuteMissionRequest):
    """
    Forward the submitted `py_tree.json` to the drone via the ROS2 action client.

    Always answers ``{"status": "execution_started"}``.
    """
    mission_tree = request.py_tree
    # NOTE(review): the outcome of send_goal is not inspected, so the response
    # below is returned even if the client decided not to send the goal.
    ros2_client.send_goal(mission_tree)
    return {"status": "execution_started"}
@app.websocket("/ws/status")
async def websocket_endpoint(websocket: WebSocket):
    """
    Handles the WebSocket connection for real-time status updates.

    The socket is registered with the WebSocket manager for the lifetime of
    the connection and is always unregistered when this handler exits.
    """
    await websocket_manager.connect(websocket)
    try:
        # Incoming messages are discarded; receiving only keeps the connection
        # open and lets us notice when the client goes away.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        logging.info("Client disconnected from WebSocket.")
    finally:
        # Unregister on every exit path (clean disconnect, unexpected error,
        # cancellation) so a dead socket is never left in the broadcast list.
        websocket_manager.disconnect(websocket)
# --- Server Lifecycle ---
@app.on_event("startup")
async def startup_event():
    """
    Wire up cross-thread plumbing when the server starts.

    Gives the WebSocket manager the running asyncio loop, then launches the
    ROS2 spin function on a background thread.
    """
    # Configure WebSocket Manager
    websocket_manager.set_loop(asyncio.get_running_loop())
    logging.info("WebSocket event loop configured.")

    # Start ROS2 node in a background thread (daemon, so it does not keep the
    # process alive on exit).
    spinner = threading.Thread(target=run_ros2_node, daemon=True)
    spinner.start()
    logging.info("ROS2 node thread started.")
@app.on_event("shutdown")
async def shutdown_event():
    """Release ROS2 resources when the server stops: destroy the node, then shut rclpy down."""
    logging.info("Backend service shutting down.")
    ros2_client.destroy_node()
    rclpy.shutdown()
    logging.info("ROS2 node shut down successfully.")

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from typing import Dict, Any
# Request body for plan generation.
class GeneratePlanRequest(BaseModel):
    # Free-form task description supplied by the user.
    user_prompt: str
# Request body carrying a behaviour tree to execute.
class ExecuteMissionRequest(BaseModel):
    # The behaviour tree as a JSON-compatible mapping.
    py_tree: Dict[str, Any]
# Status report for a single behaviour-tree node.
class StatusUpdate(BaseModel):
    # Identifier of the node the report refers to.
    node_id: str
    # Integer status code. NOTE(review): the meaning of individual values is
    # not defined in this module - confirm against the producer.
    status: int

View File

@@ -0,0 +1,143 @@
你是一个无人机任务规划专家。你的**唯一**任务是根据用户提供的任务指令和参考知识,生成一个结构化、可执行的行为树(Pytree)。
你的输出**必须**是一个严格的、单一的JSON对象,不包含任何形式的解释、总结或自然语言描述。
---
#### 1. 可用节点定义 (必须遵守)
你**必须**严格从以下JSON定义的列表中选择节点来构建行为树。不允许幻想或使用任何未定义的节点。
```json
{
"actions": [
{
"name": "takeoff",
"description": "无人机从当前位置垂直起飞到指定的海拔高度。",
"params": {
"altitude": "float, 目标海拔高度(米)"
}
},
{
"name": "fly_to_waypoint",
"description": "导航至一个WGS84坐标航点。无人机到达航点后该动作才算完成。",
"params": {
"latitude": "float, 目标纬度",
"longitude": "float, 目标经度",
"altitude": "float, 目标海拔高度(米)"
}
},
{
"name": "loiter_for_duration",
"description": "在当前位置上空悬停或盘旋一段时间。",
"params": {
"duration": "float, 悬停时间(秒)"
}
},
{
"name": "land_at_position",
"description": "在当前位置降落。",
"params": {}
},
{
"name": "return_to_launch",
"description": "自动返航并降落到起飞点。",
"params": {}
}
],
"conditions": [],
"control_flow": [
{
"name": "Sequence",
"description": "序列节点,按顺序执行其子节点。只有当所有子节点都成功时,它才成功。",
"params": {},
"children": "array, 包含按顺序执行的子节点"
}
]
}
```
---
#### 2. JSON结构规范 (必须遵守)
生成的JSON对象**必须**有一个名为`root`的键,其值是一个有效的行为树节点对象。每个节点都必须包含 "type" 和 "name" 字段。
```json
{
"root": {
"type": "string, 'action' 或 'Sequence'",
"name": "string, 动作节点的名称必须来自上方可用节点列表;Sequence 节点可使用描述性名称(例如 Mission)",
"params": "object, 包含所需的参数",
"children": "array, (可选) 包含子节点对象"
}
}
```
---
#### 3. 如何使用参考知识 (必须遵守)
当系统提供“参考知识”时,你**必须**使用其中的坐标来填充`params`字段。例如,如果任务是“飞到教学楼”,并且参考知识中提供了“教学楼”的坐标是 `(纬度: 31.2304, 经度: 121.4737)`,那么你的 `fly_to_waypoint` 节点应该写成:
`"params": {"latitude": 31.2304, "longitude": 121.4737, "altitude": 50.0}` (高度为默认或指定值)
---
#### 4. 任务规划示例 (Few-shot Learning)
以下是一些完整的任务规划示例,请学习并模仿它们的思考过程和输出格式。
##### 示例 1
**用户任务:**
"无人机起飞,飞到教学楼进行30秒的勘察,然后返航降落。"
**参考知识:**
"地理元素 '教学楼' (ID: 123) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2304, 经度: 121.4737)。"
**生成的Pytree:**
```json
{
"root": {
"type": "Sequence",
"name": "Mission",
"children": [
{ "type": "action", "name": "takeoff", "params": { "altitude": 20.0 } },
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2304, "longitude": 121.4737, "altitude": 20.0 } },
{ "type": "action", "name": "loiter_for_duration", "params": { "duration": 30.0 } },
{ "type": "action", "name": "return_to_launch", "params": {} }
]
}
}
```
##### 示例 2
**用户任务:**
"飞到图书馆,然后去体育馆,最后在体育馆的位置降落。"
**参考知识:**
"地理元素 '图书馆' (ID: 456) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2315, 经度: 121.4758)。
地理元素 '体育馆' (ID: 789) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2330, 经度: 121.4780)。"
**生成的Pytree:**
```json
{
"root": {
"type": "Sequence",
"name": "Mission",
"children": [
{ "type": "action", "name": "takeoff", "params": { "altitude": 50.0 } },
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2315, "longitude": 121.4758, "altitude": 50.0 } },
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2330, "longitude": 121.4780, "altitude": 50.0 } },
{ "type": "action", "name": "land_at_position", "params": {} }
]
}
}
```
---
#### 5. 最终指令
现在,请严格按照以上规则和示例,根据用户提供的最新任务和参考知识生成行为树JSON。直接输出JSON对象,不要有任何其他内容。

View File

@@ -0,0 +1,330 @@
import json
import os
import logging
import uuid
import re
from typing import Dict, Any, Optional, Set
import chromadb
import openai
from openai import OpenAIError
import jsonschema
import requests
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """Embedding function that delegates to a remote HTTP embeddings endpoint.

    The endpoint receives ``{"input": [...]}`` as JSON and its reply is read
    as ``{"data": [{"embedding": ...}, ...]}``.
    """

    def __init__(self, api_url: str, timeout: float = 30.0):
        """
        Args:
            api_url: Full URL of the embeddings endpoint.
            timeout: Seconds to wait for the endpoint before giving up.
                Without a timeout a stalled server would block the caller
                indefinitely.
        """
        self._api_url = api_url
        self._timeout = timeout

    def __call__(self, input: Embeddable) -> Embeddings:
        """Return one embedding per input document.

        Returns an empty list when ``input`` is not a list of strings, when
        the reply carries no data, or when the request fails for any reason
        (the failure is logged rather than raised).
        """
        # The parameter is named ``input`` (shadowing the builtin) to keep
        # the original call signature.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            return []
        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json().get("data", [])
            if not data:
                return []
            return [item['embedding'] for item in data]
        except Exception as e:
            # Best effort by design: any failure (network, timeout, malformed
            # reply) degrades to "no embeddings".
            logging.error(f"调用远程嵌入API时出错: {e}")
            return []
# --- Logging setup ---
# Root logger at INFO level, writing timestamped records through a single
# stream handler.
_log_handlers = [logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
从系统提示词中解析出允许的行动和条件节点。
"""
try:
# 使用正则表达式查找JSON代码块
match = re.search(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
if not match:
logging.error("在系统提示词中未找到可用节点的JSON定义块。")
return set(), set()
json_str = match.group(1)
allowed_nodes = json.loads(json_str)
# 从对象列表中提取节点名称
actions = {action['name'] for action in allowed_nodes.get("actions", []) if 'name' in action}
conditions = {condition['name'] for condition in allowed_nodes.get("conditions", []) if 'name' in condition}
if not actions:
logging.warning("关键错误:从提示词解析出的行动节点列表为空,无法生成任何有效任务。")
return actions, conditions
except json.JSONDecodeError:
logging.error("解析系统提示词中的JSON时失败。请检查格式。")
return set(), set()
except Exception as e:
logging.error(f"解析可用节点时发生未知错误: {e}")
return set(), set()
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
"""
# 递归节点定义
node_definition = {
"type": "object",
"properties": {
"type": {"type": "string", "enum": ["action", "condition", "Sequence", "Selector"]},
"name": {"type": "string"},
"params": {"type": "object"},
"children": {
"type": "array",
"items": {"$ref": "#/definitions/node"}
}
},
"required": ["type", "name"],
# 使用 allOf 和 if/then 来实现基于'type'字段的条件性'name'验证
"allOf": [
{
"if": {"properties": {"type": {"const": "action"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
},
{
"if": {"properties": {"type": {"const": "condition"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
}
]
}
# 完整的Schema结构
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Pytree",
"definitions": {
"node": node_definition
},
"type": "object",
"properties": {
"root": { "$ref": "#/definitions/node" }
},
"required": ["root"]
}
return schema
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
    """
    Check a Pytree instance against a JSON Schema.

    Returns:
        ``True`` when validation passes. ``False`` when the instance violates
        the schema (message, path and offending fragment are logged as
        warnings) or when any other error occurs during validation.
    """
    try:
        jsonschema.validate(instance=pytree_instance, schema=schema)
        logging.info("验证成功Pytree格式和内容均符合规范。")
        return True
    except jsonschema.ValidationError as e:
        # Render the path to the offending element, or a placeholder when the
        # error is at the document root.
        steps = [str(step) for step in e.path]
        location = ' -> '.join(steps) if steps else '根节点'
        logging.warning("--- Pytree验证失败 ---")
        logging.warning(f"错误信息: {e.message}")
        logging.warning(f"错误路径: {location}")
        logging.warning(f"出错的实例部分: {e.instance}")
        return False
    except Exception as e:
        logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
        return False
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
    """
    Render a Pytree as a PNG image with Graphviz.

    Args:
        node: Root node of the tree to draw.
        file_path: Output path passed to Graphviz ``render`` (the caller in
            this file passes it without a file extension).

    Nothing is raised on failure: a missing ``graphviz`` package or a
    rendering error is logged and the function returns.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
        return

    graph = Digraph('Pytree', comment='Drone Mission Plan')
    graph.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
    graph.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
    _add_nodes_and_edges(node, graph)

    try:
        # Write a .png and remove the intermediate source file afterwards.
        rendered_path = graph.render(file_path, format='png', cleanup=True, view=False)
    except Exception as e:
        logging.error("--- 错误:生成可视化图形失败 ---")
        logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
        logging.error(f"错误详情: {e}")
    else:
        logging.info("--- 任务树可视化成功 ---")
        logging.info(f"图形已保存到: {rendered_path}")
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID
current_id = str(id(node))
# 准备节点标签
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
if node.get('params'):
params_str = json.dumps(node.get('params'))
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
node_label += ">"
# 根据类型设置节点样式
node_type = node.get('type', '').lower()
if node_type == 'action':
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
elif node_type == 'condition':
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
else: # Sequence, Selector, etc.
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
last_child_id = current_id
for child in node.get("children", []):
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
if node_type in ['sequence']:
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
else: # Selector, Parallel
_add_nodes_and_edges(child, dot, current_id)
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
    """Generates, validates and visualises behaviour trees ("Pytrees") from user prompts.

    Construction loads the system prompt, creates the LLM client, opens the
    persistent vector store and derives the validation schema from the node
    list embedded in the system prompt.
    """

    # How many LLM round-trips generate() attempts before giving up.
    _MAX_ATTEMPTS = 3

    def __init__(self):
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.prompts_dir = os.path.join(self.base_dir, 'prompts')
        # Updated output directory for visualizations
        self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
        os.makedirs(self.vis_dir, exist_ok=True)
        self.system_prompt = self._load_prompt("system_prompt.txt")

        # Host serving both the LLM and the embedding endpoints.
        self.orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
        self.llm_client = openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
            base_url=f"http://{self.orin_ip}:8081/v1"
        )

        # --- ChromaDB Client Setup ---
        vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
        self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
        # Explicitly use the remote embedding function for queries
        embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
        embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
        self.collection = self.chroma_client.get_collection(
            name="drone_docs",
            embedding_function=embedding_func
        )

        allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
        self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)

    def _load_prompt(self, file_name: str) -> str:
        """Return the text of ``file_name`` from the prompts directory, or "" if it is missing."""
        try:
            with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"提示词文件未找到 -> {file_name}")
            return ""

    def _retrieve_context(self, query: str) -> Optional[str]:
        """Query the vector store and return up to five documents joined by blank lines.

        Returns ``None`` when nothing is found or the lookup fails (the
        failure is logged, not raised).
        """
        logging.info("--- 开始从向量数据库检索上下文 ---")
        try:
            results = self.collection.query(query_texts=[query], n_results=5)
            retrieved_docs = results.get("documents", [[]])[0]
            if not retrieved_docs:
                logging.warning("在向量数据库中没有找到相关的上下文信息。")
                return None
            context_str = "\n\n".join(retrieved_docs)
            logging.info("--- 成功检索到上下文信息 ---")
            return context_str
        except Exception as e:
            logging.error(f"从向量数据库检索时发生错误: {e}")
            return None

    def _request_pytree(self, final_user_prompt: str) -> Any:
        """Ask the LLM for a Pytree once and return the decoded JSON reply (blocking).

        Raises:
            json.JSONDecodeError: If the reply is empty or not valid JSON.
        """
        response = self.llm_client.chat.completions.create(
            model="local-model",
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": final_user_prompt}
            ],
            temperature=0.1,
            response_format={"type": "json_object"}
        )
        pytree_str = response.choices[0].message.content
        # The message content can be None; decoding "" instead raises
        # JSONDecodeError, which the caller treats as a retryable failure
        # (json.loads(None) would raise an unhandled TypeError).
        return json.loads(pytree_str or "")

    async def generate(self, user_prompt: str) -> Dict[str, Any]:
        """
        Generates a py_tree.json structure based on the user's prompt.

        The prompt is augmented with context retrieved from the vector store
        when available. The LLM is queried up to ``_MAX_ATTEMPTS`` times until
        its reply passes schema validation.

        Returns:
            The validated tree with ``plan_id`` and ``visualization_url`` added.

        Raises:
            RuntimeError: If no valid tree was produced within the attempt limit.
        """
        import asyncio  # function-scope import: only this coroutine needs it

        logging.info(f"接收到用户请求: {user_prompt}")

        # Retrieval and the LLM request are blocking network calls; running
        # them in worker threads keeps the event loop free while they wait.
        retrieved_context = await asyncio.to_thread(self._retrieve_context, user_prompt)
        final_user_prompt = user_prompt
        if retrieved_context:
            augmentation = (
                "\n\n---\n"
                "参考知识:\n"
                "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成行为树:\n"
                f"{retrieved_context}"
                "\n---"
            )
            final_user_prompt += augmentation
        else:
            logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")

        for attempt in range(self._MAX_ATTEMPTS):
            logging.info(f"--- 第 {attempt + 1}/{self._MAX_ATTEMPTS} 次尝试生成Pytree ---")
            try:
                pytree_dict = await asyncio.to_thread(self._request_pytree, final_user_prompt)
            except (OpenAIError, json.JSONDecodeError) as e:
                logging.error(f"生成Pytree时发生错误: {e}")
                continue

            if not _validate_pytree_with_schema(pytree_dict, self.schema):
                logging.warning("生成的Pytree验证失败正在重试...")
                continue

            logging.info("成功生成并验证了Pytree。")
            plan_id = str(uuid.uuid4())
            pytree_dict['plan_id'] = plan_id
            # Generate visualization to a static path. This deliberately stays
            # on the event loop: every request writes the same file name, so
            # rendering in threads could let two requests write it at once.
            vis_filename = "py_tree.png"
            vis_path = os.path.join(self.vis_dir, vis_filename)
            _visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
            pytree_dict['visualization_url'] = f"/static/{vis_filename}"
            return pytree_dict

        raise RuntimeError(f"在{self._MAX_ATTEMPTS}次尝试后仍未能生成一个有效的Pytree。")
# Create a single instance for the application.
# NOTE: this runs PyTreeGenerator.__init__ at import time, i.e. the system
# prompt is loaded and the LLM client and vector-store collection are set up
# as soon as this module is imported.
py_tree_generator = PyTreeGenerator()

View File

@@ -0,0 +1,67 @@
import rclpy
from rclpy.action import ActionClient
from rclpy.node import Node
import json
from typing import Dict, Any
import logging
from drone_interfaces.action import ExecuteMission
from .websocket_manager import websocket_manager
class MissionActionClient(Node):
    """
    Interfaces with the drone's `ExecuteMission` ROS2 Action Server.

    Missions are sent as JSON-encoded behaviour trees; feedback received while
    a mission runs is forwarded to the WebSocket manager.
    """

    def __init__(self):
        super().__init__('mission_action_client')
        self._action_client = ActionClient(self, ExecuteMission, 'execute_mission')
        self.get_logger().info("MissionActionClient initialized.")

    def send_goal(self, py_tree: Dict[str, Any]) -> bool:
        """
        Sends the mission (py_tree) to the action server.

        Args:
            py_tree: The behaviour tree; it is serialised with ``json.dumps``.

        Returns:
            ``True`` if the goal was handed to the action client, ``False`` if
            it was dropped because the action server is not ready. Callers
            that ignore the return value are unaffected.
        """
        if not self._action_client.server_is_ready():
            self.get_logger().error("Action server not available, goal not sent.")
            # Optionally, you could broadcast a status update to the frontend here
            return False

        self.get_logger().info("Received request to send goal to drone.")
        goal_msg = ExecuteMission.Goal()
        goal_msg.py_tree_json = json.dumps(py_tree)

        self.get_logger().info("Sending goal to action server...")
        send_goal_future = self._action_client.send_goal_async(
            goal_msg,
            feedback_callback=self.feedback_callback
        )
        # Acceptance or rejection is reported asynchronously via the callback.
        send_goal_future.add_done_callback(self.goal_response_callback)
        return True

    def goal_response_callback(self, future):
        """Log whether the goal was accepted and, if so, request its result."""
        goal_handle = future.result()
        if not goal_handle.accepted:
            self.get_logger().info('Goal rejected :(')
            return

        self.get_logger().info('Goal accepted :)')
        self._get_result_future = goal_handle.get_result_async()
        self._get_result_future.add_done_callback(self.get_result_callback)

    def get_result_callback(self, future):
        """Log the final result (``success`` and ``message``) of the mission."""
        result = future.result().result
        self.get_logger().info(f'Result: {{success: {result.success}, message: {result.message}}}')
        # Optionally, you can broadcast the final result via WebSocket here

    def feedback_callback(self, feedback_msg):
        """
        This callback is triggered by the action server.
        It serialises the feedback's ``node_id`` and ``status`` to JSON and
        hands the payload to the WebSocket manager's ``broadcast``.
        """
        feedback = feedback_msg.feedback
        feedback_payload = json.dumps({"node_id": feedback.node_id, "status": feedback.status})
        self.get_logger().info(f"Received feedback: {feedback_payload}")
        websocket_manager.broadcast(feedback_payload)
# Note: The rclpy.init() and spinning of the node will be handled in main.py

View File

@@ -0,0 +1,48 @@
import asyncio
from typing import List
from fastapi import WebSocket
import logging
class ConnectionManager:
    """Keeps track of open WebSocket connections and broadcasts messages to them.

    ``broadcast`` may be called from another thread; it schedules the actual
    sending onto the event loop registered with ``set_loop``.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.loop: asyncio.AbstractEventLoop | None = None

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Sets the asyncio event loop."""
        self.loop = loop

    async def connect(self, websocket: "WebSocket"):
        """Accept the connection and register it for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: "WebSocket"):
        """Unregister a connection; does nothing if it is not registered.

        Being idempotent matters because a connection can be dropped twice:
        once when a send to it fails and again when its handler exits.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    def broadcast(self, message: str):
        """
        Thread-safely broadcasts a message to all active WebSocket connections.
        This method is designed to be called from a different thread (e.g., a ROS2 callback).
        """
        if not self.loop:
            logging.error("Event loop not set in ConnectionManager. Cannot broadcast.")
            return

        # Schedule the coroutine to be executed in the event loop
        self.loop.call_soon_threadsafe(self._broadcast_in_loop, message)

    def _broadcast_in_loop(self, message: str):
        """
        Helper to run the broadcast coroutine in the correct event loop.
        """
        asyncio.ensure_future(self._broadcast_async(message), loop=self.loop)

    async def _broadcast_async(self, message: str):
        """
        The actual async method that sends messages.

        Sends to all registered connections concurrently. A connection whose
        send fails is logged and unregistered so later broadcasts skip it.
        """
        # Snapshot: the registry may change while the sends are awaited.
        recipients = list(self.active_connections)
        outcomes = await asyncio.gather(
            *(connection.send_text(message) for connection in recipients),
            return_exceptions=True,
        )
        for connection, outcome in zip(recipients, outcomes):
            if isinstance(outcome, Exception):
                logging.warning(f"Dropping WebSocket connection after failed send: {outcome!r}")
                self.disconnect(connection)
# Create a single instance of the manager to be used across the application.
# Its event loop is unset until set_loop() is called; broadcast() logs an
# error and does nothing before that.
websocket_manager = ConnectionManager()