代码内容转移
This commit is contained in:
0
backend_service/src/__init__.py
Normal file
0
backend_service/src/__init__.py
Normal file
BIN
backend_service/src/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
backend_service/src/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
backend_service/src/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/main.cpython-310.pyc
Normal file
BIN
backend_service/src/__pycache__/main.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/main.cpython-313.pyc
Normal file
BIN
backend_service/src/__pycache__/main.cpython-313.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/models.cpython-310.pyc
Normal file
BIN
backend_service/src/__pycache__/models.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/models.cpython-313.pyc
Normal file
BIN
backend_service/src/__pycache__/models.cpython-313.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend_service/src/__pycache__/ros2_client.cpython-310.pyc
Normal file
BIN
backend_service/src/__pycache__/ros2_client.cpython-310.pyc
Normal file
Binary file not shown.
BIN
backend_service/src/__pycache__/ros2_client.cpython-313.pyc
Normal file
BIN
backend_service/src/__pycache__/ros2_client.cpython-313.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
93
backend_service/src/main.py
Normal file
93
backend_service/src/main.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import asyncio
|
||||
import os
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import logging
|
||||
import threading
|
||||
import rclpy
|
||||
|
||||
from .models import GeneratePlanRequest, ExecuteMissionRequest
|
||||
from .websocket_manager import websocket_manager
|
||||
from .py_tree_generator import py_tree_generator
|
||||
from .ros2_client import MissionActionClient
|
||||
|
||||
# --- Application Setup ---
|
||||
app = FastAPI(
|
||||
title="Drone Backend Service",
|
||||
description="Handles mission planning, generation, and execution for the drone.",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
# --- Mount Static Files for Visualizations ---
|
||||
static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'generated_visualizations'))
|
||||
app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
||||
|
||||
# --- ROS2 Node and Client Initialization ---
|
||||
rclpy.init()
|
||||
ros2_client = MissionActionClient()
|
||||
|
||||
def run_ros2_node():
|
||||
"""Spins the ROS2 node in a dedicated thread."""
|
||||
logging.info("Starting to spin ROS2 node...")
|
||||
rclpy.spin(ros2_client)
|
||||
logging.info("ROS2 node has stopped spinning.")
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
@app.post("/generate_plan", response_model=dict)
|
||||
async def generate_plan_endpoint(request: GeneratePlanRequest):
|
||||
"""
|
||||
Receives a user prompt and returns a generated `py_tree.json` with a visualization URL.
|
||||
"""
|
||||
try:
|
||||
pytree_dict = await py_tree_generator.generate(request.user_prompt)
|
||||
return pytree_dict
|
||||
except RuntimeError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@app.post("/execute_mission", response_model=dict)
|
||||
async def execute_mission_endpoint(request: ExecuteMissionRequest):
|
||||
"""
|
||||
Receives a `py_tree.json` and sends it to the drone for execution.
|
||||
"""
|
||||
ros2_client.send_goal(request.py_tree)
|
||||
return {"status": "execution_started"}
|
||||
|
||||
@app.websocket("/ws/status")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""
|
||||
Handles the WebSocket connection for real-time status updates.
|
||||
"""
|
||||
await websocket_manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
websocket_manager.disconnect(websocket)
|
||||
logging.info("Client disconnected from WebSocket.")
|
||||
|
||||
|
||||
# --- Server Lifecycle ---
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""
|
||||
On startup, get the current asyncio event loop and pass it to the websocket manager.
|
||||
Also, start the ROS2 node in a background thread.
|
||||
"""
|
||||
# Configure WebSocket Manager
|
||||
loop = asyncio.get_running_loop()
|
||||
websocket_manager.set_loop(loop)
|
||||
logging.info("WebSocket event loop configured.")
|
||||
|
||||
# Start ROS2 node in a background thread
|
||||
ros2_thread = threading.Thread(target=run_ros2_node, daemon=True)
|
||||
ros2_thread.start()
|
||||
logging.info("ROS2 node thread started.")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logging.info("Backend service shutting down.")
|
||||
ros2_client.destroy_node()
|
||||
rclpy.shutdown()
|
||||
logging.info("ROS2 node shut down successfully.")
|
||||
12
backend_service/src/models.py
Normal file
12
backend_service/src/models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any
|
||||
|
||||
class GeneratePlanRequest(BaseModel):
|
||||
user_prompt: str
|
||||
|
||||
class ExecuteMissionRequest(BaseModel):
|
||||
py_tree: Dict[str, Any]
|
||||
|
||||
class StatusUpdate(BaseModel):
|
||||
node_id: str
|
||||
status: int
|
||||
143
backend_service/src/prompts/system_prompt.txt
Normal file
143
backend_service/src/prompts/system_prompt.txt
Normal file
@@ -0,0 +1,143 @@
|
||||
你是一个无人机任务规划专家。你的**唯一**任务是根据用户提供的任务指令和参考知识,生成一个结构化、可执行的行为树(Pytree)。
|
||||
|
||||
你的输出**必须**是一个严格的、单一的JSON对象,不包含任何形式的解释、总结或自然语言描述。
|
||||
|
||||
---
|
||||
|
||||
#### 1. 可用节点定义 (必须遵守)
|
||||
|
||||
你**必须**严格从以下JSON定义的列表中选择节点来构建行为树。不允许幻想或使用任何未定义的节点。
|
||||
|
||||
```json
|
||||
{
|
||||
"actions": [
|
||||
{
|
||||
"name": "takeoff",
|
||||
"description": "无人机从当前位置垂直起飞到指定的海拔高度。",
|
||||
"params": {
|
||||
"altitude": "float, 目标海拔高度(米)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "fly_to_waypoint",
|
||||
"description": "导航至一个WGS84坐标航点。无人机到达航点后该动作才算完成。",
|
||||
"params": {
|
||||
"latitude": "float, 目标纬度",
|
||||
"longitude": "float, 目标经度",
|
||||
"altitude": "float, 目标海拔高度(米)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "loiter_for_duration",
|
||||
"description": "在当前位置上空悬停或盘旋一段时间。",
|
||||
"params": {
|
||||
"duration": "float, 悬停时间(秒)"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "land_at_position",
|
||||
"description": "在当前位置降落。",
|
||||
"params": {}
|
||||
},
|
||||
{
|
||||
"name": "return_to_launch",
|
||||
"description": "自动返航并降落到起飞点。",
|
||||
"params": {}
|
||||
}
|
||||
],
|
||||
"conditions": [],
|
||||
"control_flow": [
|
||||
{
|
||||
"name": "Sequence",
|
||||
"description": "序列节点,按顺序执行其子节点。只有当所有子节点都成功时,它才成功。",
|
||||
"params": {},
|
||||
"children": "array, 包含按顺序执行的子节点"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 2. JSON结构规范 (必须遵守)
|
||||
|
||||
生成的JSON对象**必须**有一个名为`root`的键,其值是一个有效的行为树节点对象。每个节点都必须包含 "type" 和 "name" 字段。
|
||||
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "string, 'action' 或 'Sequence'",
|
||||
"name": "string, 来自上方可用节点列表",
|
||||
"params": "object, 包含所需的参数",
|
||||
"children": "array, (可选) 包含子节点对象"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
#### 3. 如何使用参考知识 (必须遵守)
|
||||
|
||||
当系统提供“参考知识”时,你**必须**使用其中的坐标来填充`params`字段。例如,如果任务是“飞到教学楼”,并且参考知识中提供了“教学楼”的坐标是 `(纬度: 31.2304, 经度: 121.4737)`,那么你的 `fly_to_waypoint` 节点应该写成:
|
||||
|
||||
`"params": {"latitude": 31.2304, "longitude": 121.4737, "altitude": 50.0}` (高度为默认或指定值)
|
||||
|
||||
---
|
||||
|
||||
#### 4. 任务规划示例 (Few-shot Learning)
|
||||
|
||||
以下是一些完整的任务规划示例,请学习并模仿它们的思考过程和输出格式。
|
||||
|
||||
##### 示例 1
|
||||
|
||||
**用户任务:**
|
||||
"无人机起飞,飞到教学楼进行30秒的勘察,然后返航降落。"
|
||||
|
||||
**参考知识:**
|
||||
"地理元素 '教学楼' (ID: 123) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2304, 经度: 121.4737)。"
|
||||
|
||||
**生成的Pytree:**
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "Mission",
|
||||
"children": [
|
||||
{ "type": "action", "name": "takeoff", "params": { "altitude": 20.0 } },
|
||||
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2304, "longitude": 121.4737, "altitude": 20.0 } },
|
||||
{ "type": "action", "name": "loiter_for_duration", "params": { "duration": 30.0 } },
|
||||
{ "type": "action", "name": "return_to_launch", "params": {} }
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
##### 示例 2
|
||||
|
||||
**用户任务:**
|
||||
"飞到图书馆,然后去体育馆,最后在体育馆的位置降落。"
|
||||
|
||||
**参考知识:**
|
||||
"地理元素 '图书馆' (ID: 456) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2315, 经度: 121.4758)。
|
||||
地理元素 '体育馆' (ID: 789) 是一个 建筑。 其中心位置坐标大约为 (纬度: 31.2330, 经度: 121.4780)。"
|
||||
|
||||
**生成的Pytree:**
|
||||
```json
|
||||
{
|
||||
"root": {
|
||||
"type": "Sequence",
|
||||
"name": "Mission",
|
||||
"children": [
|
||||
{ "type": "action", "name": "takeoff", "params": { "altitude": 50.0 } },
|
||||
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2315, "longitude": 121.4758, "altitude": 50.0 } },
|
||||
{ "type": "action", "name": "fly_to_waypoint", "params": { "latitude": 31.2330, "longitude": 121.4780, "altitude": 50.0 } },
|
||||
{ "type": "action", "name": "land_at_position", "params": {} }
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
---
|
||||
|
||||
#### 5. 最终指令
|
||||
|
||||
现在,请严格按照以上规则和示例,根据用户提供的最新任务和参考知识,生成行为树JSON。直接输出JSON对象,不要有任何其他内容。
|
||||
330
backend_service/src/py_tree_generator.py
Normal file
330
backend_service/src/py_tree_generator.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Set
|
||||
|
||||
import chromadb
|
||||
import openai
|
||||
from openai import OpenAIError
|
||||
import jsonschema
|
||||
import requests
|
||||
|
||||
# --- 自定义远程嵌入函数 (与ingest.py中定义一致) ---
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
|
||||
def __init__(self, api_url: str):
|
||||
self._api_url = api_url
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
|
||||
return []
|
||||
try:
|
||||
response = requests.post(
|
||||
self._api_url,
|
||||
json={"input": input},
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json().get("data", [])
|
||||
if not data:
|
||||
return []
|
||||
return [item['embedding'] for item in data]
|
||||
except Exception as e:
|
||||
logging.error(f"调用远程嵌入API时出错: {e}")
|
||||
return []
|
||||
|
||||
# --- 日志记录设置 ---
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# VALIDATION LOGIC (from utils/validation.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
|
||||
"""
|
||||
从系统提示词中解析出允许的行动和条件节点。
|
||||
"""
|
||||
try:
|
||||
# 使用正则表达式查找JSON代码块
|
||||
match = re.search(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
|
||||
if not match:
|
||||
logging.error("在系统提示词中未找到可用节点的JSON定义块。")
|
||||
return set(), set()
|
||||
|
||||
json_str = match.group(1)
|
||||
allowed_nodes = json.loads(json_str)
|
||||
|
||||
# 从对象列表中提取节点名称
|
||||
actions = {action['name'] for action in allowed_nodes.get("actions", []) if 'name' in action}
|
||||
conditions = {condition['name'] for condition in allowed_nodes.get("conditions", []) if 'name' in condition}
|
||||
|
||||
if not actions:
|
||||
logging.warning("关键错误:从提示词解析出的行动节点列表为空,无法生成任何有效任务。")
|
||||
|
||||
return actions, conditions
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logging.error("解析系统提示词中的JSON时失败。请检查格式。")
|
||||
return set(), set()
|
||||
except Exception as e:
|
||||
logging.error(f"解析可用节点时发生未知错误: {e}")
|
||||
return set(), set()
|
||||
|
||||
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
|
||||
"""
|
||||
根据允许的行动和条件节点,动态生成一个JSON Schema。
|
||||
"""
|
||||
# 递归节点定义
|
||||
node_definition = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string", "enum": ["action", "condition", "Sequence", "Selector"]},
|
||||
"name": {"type": "string"},
|
||||
"params": {"type": "object"},
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/node"}
|
||||
}
|
||||
},
|
||||
"required": ["type", "name"],
|
||||
# 使用 allOf 和 if/then 来实现基于'type'字段的条件性'name'验证
|
||||
"allOf": [
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "action"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
|
||||
},
|
||||
{
|
||||
"if": {"properties": {"type": {"const": "condition"}}},
|
||||
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 完整的Schema结构
|
||||
schema = {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "Pytree",
|
||||
"definitions": {
|
||||
"node": node_definition
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"root": { "$ref": "#/definitions/node" }
|
||||
},
|
||||
"required": ["root"]
|
||||
}
|
||||
return schema
|
||||
|
||||
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
|
||||
"""
|
||||
使用JSON Schema验证给定的Pytree实例。
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(instance=pytree_instance, schema=schema)
|
||||
logging.info("验证成功!Pytree格式和内容均符合规范。")
|
||||
return True
|
||||
except jsonschema.ValidationError as e:
|
||||
logging.warning("--- Pytree验证失败 ---")
|
||||
logging.warning(f"错误信息: {e.message}")
|
||||
error_path = list(e.path)
|
||||
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
|
||||
logging.warning(f"出错的实例部分: {e.instance}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
# ==============================================================================
|
||||
# VISUALIZATION LOGIC (from utils/visualization.py)
|
||||
# ==============================================================================
|
||||
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname='helvetica')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20')
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
output_path = dot.render(file_path, format='png', cleanup=True, view=False)
|
||||
logging.info("--- 任务树可视化成功 ---")
|
||||
logging.info(f"图形已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
logging.error("--- 错误:生成可视化图形失败 ---")
|
||||
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
|
||||
logging.error(f"错误详情: {e}")
|
||||
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
|
||||
# 为每个节点创建一个唯一的ID
|
||||
current_id = str(id(node))
|
||||
|
||||
# 准备节点标签
|
||||
node_label = f"<{node['name']}<br/><i>({node['type']})</i>"
|
||||
if node.get('params'):
|
||||
params_str = json.dumps(node.get('params'))
|
||||
node_label += f"<br/><font point-size='10'>params: {params_str}</font>"
|
||||
node_label += ">"
|
||||
|
||||
# 根据类型设置节点样式
|
||||
node_type = node.get('type', '').lower()
|
||||
if node_type == 'action':
|
||||
dot.node(current_id, label=node_label, shape='box', color="#cde4ff")
|
||||
elif node_type == 'condition':
|
||||
dot.node(current_id, label=node_label, shape='diamond', color="#fff2cc")
|
||||
else: # Sequence, Selector, etc.
|
||||
dot.node(current_id, label=node_label, shape='ellipse', color='#e6e6e6')
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
last_child_id = current_id
|
||||
for child in node.get("children", []):
|
||||
# 对于序列,边是连续的;对于选择器,所有子节点都连接到父节点
|
||||
if node_type in ['sequence']:
|
||||
last_child_id = _add_nodes_and_edges(child, dot, last_child_id)
|
||||
else: # Selector, Parallel
|
||||
_add_nodes_and_edges(child, dot, current_id)
|
||||
|
||||
return current_id
|
||||
|
||||
# ==============================================================================
|
||||
# CORE PYTREE GENERATOR CLASS
|
||||
# ==============================================================================
|
||||
|
||||
class PyTreeGenerator:
|
||||
def __init__(self):
|
||||
self.base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.prompts_dir = os.path.join(self.base_dir, 'prompts')
|
||||
|
||||
# Updated output directory for visualizations
|
||||
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
|
||||
os.makedirs(self.vis_dir, exist_ok=True)
|
||||
|
||||
self.system_prompt = self._load_prompt("system_prompt.txt")
|
||||
|
||||
self.orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
|
||||
self.llm_client = openai.OpenAI(
|
||||
api_key=os.getenv("OPENAI_API_KEY", "sk-no-key-required"),
|
||||
base_url=f"http://{self.orin_ip}:8081/v1"
|
||||
)
|
||||
|
||||
# --- ChromaDB Client Setup ---
|
||||
vector_store_path = os.path.abspath(os.path.join(self.base_dir, '..', '..', 'tools', 'vector_store'))
|
||||
self.chroma_client = chromadb.PersistentClient(path=vector_store_path)
|
||||
|
||||
# Explicitly use the remote embedding function for queries
|
||||
embedding_api_url = f"http://{self.orin_ip}:8090/v1/embeddings"
|
||||
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
|
||||
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name="drone_docs",
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.system_prompt)
|
||||
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
|
||||
|
||||
def _load_prompt(self, file_name: str) -> str:
|
||||
try:
|
||||
with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logging.error(f"提示词文件未找到 -> {file_name}")
|
||||
return ""
|
||||
|
||||
def _retrieve_context(self, query: str) -> Optional[str]:
|
||||
logging.info("--- 开始从向量数据库检索上下文 ---")
|
||||
try:
|
||||
results = self.collection.query(query_texts=[query], n_results=5)
|
||||
retrieved_docs = results.get("documents", [[]])[0]
|
||||
if not retrieved_docs:
|
||||
logging.warning("在向量数据库中没有找到相关的上下文信息。")
|
||||
return None
|
||||
context_str = "\n\n".join(retrieved_docs)
|
||||
logging.info("--- 成功检索到上下文信息 ---")
|
||||
return context_str
|
||||
except Exception as e:
|
||||
logging.error(f"从向量数据库检索时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def generate(self, user_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates a py_tree.json structure based on the user's prompt.
|
||||
"""
|
||||
logging.info(f"接收到用户请求: {user_prompt}")
|
||||
|
||||
retrieved_context = self._retrieve_context(user_prompt)
|
||||
|
||||
final_user_prompt = user_prompt
|
||||
if retrieved_context:
|
||||
augmentation = (
|
||||
"\n\n---\n"
|
||||
"参考知识:\n"
|
||||
"以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成行为树:\n"
|
||||
f"{retrieved_context}"
|
||||
"\n---"
|
||||
)
|
||||
final_user_prompt += augmentation
|
||||
else:
|
||||
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
|
||||
|
||||
for attempt in range(3):
|
||||
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
|
||||
try:
|
||||
response = self.llm_client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": final_user_prompt}
|
||||
],
|
||||
temperature=0.1,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
pytree_str = response.choices[0].message.content
|
||||
pytree_dict = json.loads(pytree_str)
|
||||
|
||||
if _validate_pytree_with_schema(pytree_dict, self.schema):
|
||||
logging.info("成功生成并验证了Pytree。")
|
||||
plan_id = str(uuid.uuid4())
|
||||
pytree_dict['plan_id'] = plan_id
|
||||
|
||||
# Generate visualization to a static path
|
||||
vis_filename = "py_tree.png"
|
||||
vis_path = os.path.join(self.vis_dir, vis_filename)
|
||||
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
|
||||
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
|
||||
|
||||
return pytree_dict
|
||||
else:
|
||||
logging.warning("生成的Pytree验证失败,正在重试...")
|
||||
|
||||
except (OpenAIError, json.JSONDecodeError) as e:
|
||||
logging.error(f"生成Pytree时发生错误: {e}")
|
||||
|
||||
raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。")
|
||||
|
||||
# Create a single instance for the application
|
||||
py_tree_generator = PyTreeGenerator()
|
||||
67
backend_service/src/ros2_client.py
Normal file
67
backend_service/src/ros2_client.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import rclpy
|
||||
from rclpy.action import ActionClient
|
||||
from rclpy.node import Node
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
import logging
|
||||
|
||||
from drone_interfaces.action import ExecuteMission
|
||||
from .websocket_manager import websocket_manager
|
||||
|
||||
class MissionActionClient(Node):
|
||||
"""
|
||||
Interfaces with the drone's `ExecuteMission` ROS2 Action Server.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__('mission_action_client')
|
||||
self._action_client = ActionClient(self, ExecuteMission, 'execute_mission')
|
||||
self.get_logger().info("MissionActionClient initialized.")
|
||||
|
||||
def send_goal(self, py_tree: Dict[str, Any]):
|
||||
"""
|
||||
Sends the mission (py_tree) to the action server.
|
||||
"""
|
||||
if not self._action_client.server_is_ready():
|
||||
self.get_logger().error("Action server not available, goal not sent.")
|
||||
# Optionally, you could broadcast a status update to the frontend here
|
||||
return
|
||||
|
||||
self.get_logger().info("Received request to send goal to drone.")
|
||||
goal_msg = ExecuteMission.Goal()
|
||||
goal_msg.py_tree_json = json.dumps(py_tree)
|
||||
|
||||
self.get_logger().info(f"Sending goal to action server...")
|
||||
send_goal_future = self._action_client.send_goal_async(
|
||||
goal_msg,
|
||||
feedback_callback=self.feedback_callback
|
||||
)
|
||||
|
||||
send_goal_future.add_done_callback(self.goal_response_callback)
|
||||
|
||||
def goal_response_callback(self, future):
|
||||
goal_handle = future.result()
|
||||
if not goal_handle.accepted:
|
||||
self.get_logger().info('Goal rejected :(')
|
||||
return
|
||||
|
||||
self.get_logger().info('Goal accepted :)')
|
||||
|
||||
self._get_result_future = goal_handle.get_result_async()
|
||||
self._get_result_future.add_done_callback(self.get_result_callback)
|
||||
|
||||
def get_result_callback(self, future):
|
||||
result = future.result().result
|
||||
self.get_logger().info(f'Result: {{success: {result.success}, message: {result.message}}}')
|
||||
# Optionally, you can broadcast the final result via WebSocket here
|
||||
|
||||
def feedback_callback(self, feedback_msg):
|
||||
"""
|
||||
This callback is triggered by the action server.
|
||||
It forwards the status to the QGC plugin via the WebSocket manager in a thread-safe manner.
|
||||
"""
|
||||
feedback = feedback_msg.feedback
|
||||
feedback_payload = json.dumps({"node_id": feedback.node_id, "status": feedback.status})
|
||||
self.get_logger().info(f"Received feedback: {feedback_payload}")
|
||||
websocket_manager.broadcast(feedback_payload)
|
||||
|
||||
# Note: The rclpy.init() and spinning of the node will be handled in main.py
|
||||
48
backend_service/src/websocket_manager.py
Normal file
48
backend_service/src/websocket_manager.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
from fastapi import WebSocket
|
||||
import logging
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
"""Sets the asyncio event loop."""
|
||||
self.loop = loop
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
def broadcast(self, message: str):
|
||||
"""
|
||||
Thread-safely broadcasts a message to all active WebSocket connections.
|
||||
This method is designed to be called from a different thread (e.g., a ROS2 callback).
|
||||
"""
|
||||
if not self.loop:
|
||||
logging.error("Event loop not set in ConnectionManager. Cannot broadcast.")
|
||||
return
|
||||
|
||||
# Schedule the coroutine to be executed in the event loop
|
||||
self.loop.call_soon_threadsafe(self._broadcast_in_loop, message)
|
||||
|
||||
def _broadcast_in_loop(self, message: str):
|
||||
"""
|
||||
Helper to run the broadcast coroutine in the correct event loop.
|
||||
"""
|
||||
asyncio.ensure_future(self._broadcast_async(message), loop=self.loop)
|
||||
|
||||
async def _broadcast_async(self, message: str):
|
||||
"""
|
||||
The actual async method that sends messages.
|
||||
"""
|
||||
tasks = [connection.send_text(message) for connection in self.active_connections]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Create a single instance of the manager to be used across the application
|
||||
websocket_manager = ConnectionManager()
|
||||
Reference in New Issue
Block a user