将 groundcontrol 从子模块转换为普通目录

This commit is contained in:
2025-11-29 17:54:59 +08:00
parent 6acda62380
commit 6260ef4f7c
16 changed files with 2045 additions and 1 deletions

Submodule scripts/ai_agent/groundcontrol deleted from d026107bc2

View File

@@ -0,0 +1,143 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# static files generated from Django application using `collectstatic`
media
static

View File

@@ -0,0 +1,345 @@
# 无人机自然语言控制项目
本项目构建了一个完整的无人机自然语言控制系统集成了检索增强生成RAG知识库、大型语言模型LLM、FastAPI后端服务和ROS2通信最终实现通过自然语言指令控制无人机执行复杂任务。
## 项目结构
项目被清晰地划分为几个核心模块:
```
.
├── backend_service/
│ ├── src/ # FastAPI应用核心代码
│ │ ├── __init__.py
│ │ ├── main.py # 应用主入口提供Web API
│ │ ├── py_tree_generator.py # RAG与LLM集成生成py_tree
│ │ ├── prompts/ # LLM 提示词
│ │ │ ├── system_prompt.txt # 复杂模式提示词(行为树与安全监控)
│ │ │ ├── simple_mode_prompt.txt # 简单模式提示词单一原子动作JSON
│ │ │ └── classifier_prompt.txt # 指令简单/复杂分类提示词
│ │ ├── ...
│ ├── generated_visualizations/ # 存放最新生成的py_tree可视化图像
│ └── requirements.txt # 后端服务的Python依赖
├── tools/
│ ├── map/ # 【数据源】存放原始地图文件(如.world, .json)
│ ├── knowledge_base/ # 【处理后】存放build_knowledge_base.py生成的.ndjson文件
│ ├── vector_store/ # 【数据库】存放最终的ChromaDB向量数据库
│ ├── build_knowledge_base.py # 【步骤1】用于将原始数据转换为自然语言知识
│ └── ingest.py # 【步骤2】用于将自然语言知识摄入向量数据库
├── / # ROS2接口定义 (保持不变)
└── docs/
└── README.md # 本说明文件
```
## 核心配置Orin IP 地址
**重要提示:** 本项目的后端服务和知识库工具需要与在NVIDIA Jetson Orin设备上运行的服务进行通信嵌入模型和LLM推理服务**默认的IP地址为localhost**,所以使用电脑本地部署的模型服务同样可以,但是需要注意指定模型的端口。
在使用前您必须配置正确的Orin设备IP地址。您可以通过以下两种方式之一进行设置
1. **设置环境变量 (推荐)**:
在您的终端中设置一个名为 `ORIN_IP` 的环境变量。
```bash
export ORIN_IP="192.168.1.100" # 请替换为您的Orin设备的实际IP地址
```
脚本会优先使用这个环境变量。
2. **直接修改脚本**:
如果您不想设置环境变量,可以打开 `tools/ingest.py` 和 `backend_service/src/py_tree_generator.py` 文件,找到 `orin_ip = os.getenv("ORIN_IP", "...")` 这样的行并将默认的IP地址修改为您的Orin设备的实际IP地址。
**在继续后续步骤之前,请务必完成此项配置。**
## 模型端口启动
本项目启动依赖于后端的模型推理服务,即`ORIN_IP`所指向的设备的模型服务端口目前项目使用instruct模型与embedding模型实现流程分别部署在8081端口与8090端口。
1. **推理模型部署**:
在`/llama.cpp/build/bin`路径下执行以下命令启动模型
```bash
./llama-server -m ~/models/gguf/Qwen/Qwen3-8B-GGUF/Qwen3-8B-Q4_K_M.gguf --port 8081 --gpu-layers 36 --host 0.0.0.0 -c 8192
```
2. **Embedding模型部署**
在`/llama.cpp/build/bin`路径下执行以下命令启动模型
```bash
./llama-server -m ~/models/gguf/Qwen/Qwen3-embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf --gpu-layers 36 --port 8090 --embeddings --pooling last --host 0.0.0.0
```
---
## 指令分类与分流
后端在生成任务前会先对用户指令进行“简单/复杂”分类,并分流到不同提示词与模型:
- 分类提示词:`backend_service/src/prompts/classifier_prompt.txt`
- 简单模式提示词:`backend_service/src/prompts/simple_mode_prompt.txt`
- 复杂模式提示词:`backend_service/src/prompts/system_prompt.txt`
分类仅输出如下JSON之一`{"mode":"simple"}` 或 `{"mode":"complex"}`。两种模式都会执行检索增强RAG将参考知识拼接到用户指令后再进行推理。
当为简单模式时LLM仅输出
`{"mode":"simple","action":{"name":"<action>","params":{...}}}`。
后端不会再自动封装为复杂行为树将直接返回简单JSON并附加 `plan_id` 与 `visualization_url`(单动作可视化)。
### 环境变量(可选)
支持为“分类/简单/复杂”三类调用分别配置模型与Base URL未设置时回退到默认本地配置
- `CLASSIFIER_MODEL`, `CLASSIFIER_BASE_URL`
- `SIMPLE_MODEL`, `SIMPLE_BASE_URL`
- `COMPLEX_MODEL`, `COMPLEX_BASE_URL`
通用API Key`OPENAI_API_KEY`
示例:
```bash
export CLASSIFIER_MODEL="qwen2.5-1.8b-instruct"
export SIMPLE_MODEL="qwen2.5-1.8b-instruct"
export COMPLEX_MODEL="qwen2.5-7b-instruct"
export CLASSIFIER_BASE_URL="http://$ORIN_IP:8081/v1"
export SIMPLE_BASE_URL="http://$ORIN_IP:8081/v1"
export COMPLEX_BASE_URL="http://$ORIN_IP:8081/v1"
export OPENAI_API_KEY="sk-no-key-required"
```
### 测试简单模式
启动服务后,运行内置测试脚本:
```bash
cd tools
python test_api.py
```
示例输入:“简单模式,起飞” 或 “起飞到10米”。返回结果为简单JSON无 `root`):包含 `mode`、`action`、`plan_id`、`visualization_url`。
---
## 工作流程
整个系统的工作流程分为两个主要阶段:
1. **知识库构建(一次性设置)**: 将环境信息、无人机能力等知识加工并存入向量数据库。
2. **后端服务运行与交互**: 启动主服务通过API接收指令、生成并执行任务。
### 阶段一:环境设置与编译
此阶段为项目准备好运行环境,仅需在初次配置或依赖变更时执行。一个稳定、隔离且兼容的环境是所有后续步骤成功的基础。
#### 1. 创建Conda环境 (关键步骤)
为了从根源上避免本地Python环境与系统ROS 2环境的库版本冲突特别是Python版本和C++标准库),我们**必须**使用Conda创建一个干净、隔离且版本精确的虚拟环境。
```bash
# 1. 创建一个使用Python 3.10的新环境。
# --name backend: 指定环境名称。
# python=3.10: 指定Python版本必须与ROS 2 Humble要求的版本一致。
# --channel conda-forge: 使用conda-forge社区源其包通常有更好的兼容性。
# --no-default-packages: 关键不安装Conda默认的包如libgcc避免与系统ROS 2的C++库冲突。
conda create --name backend --channel conda-forge --no-default-packages python=3.10
# 2. 激活新创建的环境
conda activate backend
```
#### 2. 安装所有Python依赖
在激活`backend`环境后,使用`pip`一次性安装所有依赖。`requirements.txt`已包含**运行时**如fastapi, rclpy和**编译时**如empy, catkin-pkg, lark所需的所有库。
```bash
# 确保在项目根目录 (drone/) 下执行
pip install -r backend_service/requirements.txt
```
#### 3. 编译ROS 2接口
为了让后端服务能够像导入普通Python包一样导入我们自定义的Action接口 (`drone_interfaces`),你需要先使用`colcon`对其进行编译。
```bash
# 确保在项目根目录 (drone/) 下执行
colcon build
```
成功后,您会看到`build/`, `install/`, `log/`三个新目录。这一步会将`.action`文件转换为Python和C++代码。
---
### 阶段二:数据处理流水线
此阶段为RAG系统准备数据让LLM能够理解任务环境。
#### 1. 准备原始数据
将你的原始数据文件(例如,`.world`, `.json` 文件等)放入 `tools/map/` 目录中。
#### 2. 数据预处理
运行脚本将原始数据“翻译”成自然语言知识。
```bash
# 确保在项目根目录 (drone/) 下并已激活backend环境
cd tools
python build_knowledge_base.py
```
该脚本会扫描 `tools/map/` 目录,并在 `tools/knowledge_base/` 目录下生成对应的 `_knowledge.ndjson` 文件。
#### 3. 数据入库Ingestion
运行脚本将处理好的知识加载到向量数据库中。
```bash
# 仍在tools/目录下执行
python ingest.py
```
该脚本会自动扫描 `tools/knowledge_base/` 目录,并将数据存入 `tools/vector_store/` 目录中。
---
### 阶段三:服务启动与测试
完成前两个阶段后,即可启动并测试后端服务。
#### 1. 启动后端服务
启动服务的关键在于**按顺序激活环境**先激活ROS 2工作空间再激活Conda环境。
```bash
# 1. 切换到项目根目录
cd /path/to/your/drone
# 2. 激活ROS 2编译环境
# 作用:将我们编译好的`drone_interfaces`包的路径告知系统否则Python会报`ModuleNotFoundError`。
# 注意:此命令必须在每次打开新终端时执行一次。
source install/setup.bash
# 3. 激活Conda Python环境
conda activate backend
# 4. 启动FastAPI服务
cd backend_service/
uvicorn src.main:app --host 0.0.0.0 --port 8000
```
当您看到日志中出现 `Uvicorn running on http://0.0.0.0:8000` 时,表示服务已成功启动。
#### 2. 运行API接口测试
我们提供了一个脚本来验证核心的“任务生成”功能。
**打开一个新的终端**,并执行以下命令:
```bash
# 1. 切换到项目根目录
cd /path/to/your/drone
# 2. 激活Conda环境
conda activate backend
# 3. 运行测试脚本
cd tools/
python test_api.py
```
如果一切正常,您将在终端看到一系列 `PASS` 信息以及从服务器返回的Pytree JSON。
#### 3. API接口使用说明
---
## 故障排除 / 常见问题 (FAQ)
以下是在配置和运行此项目时可能遇到的一些常见问题及其解决方案。
#### **Q1: 启动服务时报错 `ModuleNotFoundError: No module named 'drone_interfaces'`**
- **原因**: 您当前的终端环境没有加载ROS 2工作空间的路径。仅仅激活Conda环境是不够的。
- **解决方案**: 严格遵循“启动后端服务”章节的说明在激活Conda环境**之前**,必须先运行 `source install/setup.bash` 命令。
#### **Q2: `colcon build` 编译失败,提示 `ModuleNotFoundError: No module named 'em'`, `'catkin_pkg'`, 或 `'lark'`**
- **原因**: 您的Python环境中缺少ROS 2编译代码时所必需的依赖包。
- **解决方案**: 我们已将所有已知的编译时依赖(`empy`, `catkin-pkg`, `lark`等)添加到了`requirements.txt`中。请确保您已激活正确的Conda环境然后运行 `pip install -r backend_service/requirements.txt` 来安装它们。
#### **Q3: 启动服务时报错 `ImportError: ... GLIBCXX_... not found` 或 `ModuleNotFoundError: No module named 'rclpy._rclpy_pybind11'`**
- **原因**: 您的Conda环境与系统ROS 2环境存在核心库冲突。最常见的原因是Python版本不匹配例如Conda是Python 3.11而ROS 2 Humble需要3.10或者Conda自带的C++库与系统库冲突。
- **解决方案**: 这是最棘手的环境问题。最可靠的解决方法是彻底删除当前的Conda环境 (`conda env remove --name backend`),然后严格按照本文档「环境设置」章节的说明,用正确的命令 (`conda create --name backend --channel conda-forge --no-default-packages python=3.10`) 重建一个干净、兼容的环境。
#### **Q4: 服务启动时,日志显示正在从网络上下载模型(例如 `all-MiniLM-L6-v2`**
- **原因**: 后端服务在连接向量数据库时没有正确指定使用远程嵌入模型导致ChromaDB退回到默认的、需要下载模型的本地嵌入函数。
- **解决方案**: 此问题在当前代码中**已被修复**。`backend_service/src/py_tree_generator.py`现在会正确地将远程嵌入函数实例传递给ChromaDB。如果您在自己的代码中遇到此问题请检查您的`get_collection`调用。
#### **Q5: 服务启动时,日志停在 `waiting for action server...`无法访问API**
- **原因**: 代码中存在阻塞式的`wait_for_server()`调用它会一直等待直到无人机端的Action服务器上线从而卡住了Web服务的启动流程。
- **解决方案**: 此问题在当前代码中**已被修复**。`backend_service/src/ros2_client.py`现在使用非阻塞的方式初始化,并在发送任务时检查服务器是否可用。
##### **A. 生成任务计划**
接收自然语言指令返回生成的行为树py_treeJSON。
- **Endpoint**: `POST /generate_plan`
- **Request Body**:
```json
{
"user_prompt": "无人机起飞到10米然后前往机库最后降落。"
}
```
- **Success Response复杂模式**:
```json
{
"root": { ... },
"plan_id": "some-unique-id",
"visualization_url": "/static/py_tree.png"
}
```
- **Success Response简单模式**:
```json
{
"mode": "simple",
"action": { "name": "takeoff", "params": { "altitude": 10.0 } },
"plan_id": "some-unique-id",
"visualization_url": "/static/py_tree.png"
}
```
##### **B. 查看任务可视化**
获取最新生成的行为树的可视化图像。
- **Endpoint**: `GET /static/py_tree.png`
- **Usage**: 在浏览器中直接打开 `http://<服务器IP>:8000/static/py_tree.png` 即可查看。每次成功调用 `/generate_plan` 后,该图像都会被更新。
##### **C. 执行任务**
接收一个py_tree JSON下发给无人机执行当前为模拟执行
- **Endpoint**: `POST /execute_mission`
- **Request Body**: (使用 `/generate_plan` 返回的 `root` 对象)
```json
{
"py_tree": {
"root": { ... }
}
}
```
- **Response**:
```json
{
"status": "execution_started"
}
```
##### **D. 接收实时状态**
通过WebSocket连接实时接收无人机在执行任务时的状态反馈。
- **Endpoint**: `WS /ws/status`
- **Usage**: 使用任意WebSocket客户端连接到 `ws://<服务器IP>:8000/ws/status`。当任务执行时服务器会主动推送JSON消息例如
```json
{"node_id": "takeoff_node_1", "status": 0} // 0: RUNNING
{"node_id": "takeoff_node_1", "status": 1} // 1: SUCCESS
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

View File

@@ -0,0 +1,39 @@
# Web Framework and Server
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
websockets>=12.0
# Data Validation and Serialization
pydantic>=2.5.0
jsonschema>=4.20.0
# AI and Vector Database
openai>=1.3.0
chromadb>=0.4.0
# Visualization
graphviz>=0.20.0
# ROS 2 Python Client
rclpy>=0.0.1
# Document Processing
unstructured[all]>=0.11.0
# HTTP Requests
requests>=2.31.0
# Progress Bars and UI
rich>=13.7.0
# Type Hints Support
typing-extensions>=4.8.0
# ROS 2 Build Dependencies
empy==3.3.4
catkin-pkg>=0.4.0
lark>=1.1.0
colcon-common-extensions>=0.3.0
vcstool>=0.2.0
rosdep>=0.22.0

View File

@@ -0,0 +1,225 @@
import asyncio
import os
import sys
import subprocess
import time
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
import logging
import threading
# --- Sys Path Injection ---
# Add ai_agent root to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
# Go up 2 levels: src -> groundcontrol -> ai_agent
ai_agent_root = os.path.abspath(os.path.join(current_dir, "../.."))
# Also add models directory specifically for implicit imports if needed
models_dir = os.path.join(ai_agent_root, "models")
if ai_agent_root not in sys.path:
sys.path.append(ai_agent_root)
if models_dir not in sys.path:
sys.path.append(models_dir)
# Now we can import from ai_agent modules
from models.models_client import Models_Client
from tools.core.common_functions import read_json_file
from .models import GeneratePlanRequest, ExecuteMissionRequest
from .websocket_manager import websocket_manager
from .py_tree_generator import py_tree_generator
# --- Global Variables ---
model_server_process = None
MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://localhost:8000")
# --- Application Setup ---
app = FastAPI(
title="Drone Backend Service",
description="Handles mission planning, generation, and execution for the drone.",
version="1.0.0",
)
# --- Mount Static Files for Visualizations ---
static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'generated_visualizations'))
if not os.path.exists(static_dir):
os.makedirs(static_dir)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
def start_model_server():
"""Starts the model server if it's not already running."""
global model_server_process
client = Models_Client(MODEL_SERVER_URL)
# Check if already running
try:
logging.info(f"Checking model server health at {MODEL_SERVER_URL}...")
health = client.check_health()
if health.get("status") in ["healthy", "loading"]:
logging.info("Model server is already running.")
return
except Exception:
pass # Not running
logging.info("Starting model server...")
# 读取配置文件
try:
config_json_file = os.path.join(ai_agent_root, "config", "model_config.json")
config_data = read_json_file(config_json_file)
if config_data is None or "models_server" not in config_data:
logging.error(f"无法读取配置文件或配置文件中缺少 'models_server' 字段: {config_json_file}")
return
# 如果 models_server 是列表,取第一个;如果是字典,直接使用
models_server_data = config_data["models_server"]
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
host = model_server_config.get("model_server_host", "localhost")
port = model_server_config.get("model_server_port", 8000)
workers = model_server_config.get("model_controller_workers", 1)
except Exception as e:
logging.error(f"读取配置文件失败: {e},使用默认配置")
host = "localhost"
port = 8000
workers = 1
# Start as subprocess using uvicorn
try:
def log_output(pipe, prefix):
"""实时输出子进程的日志"""
try:
for line in iter(pipe.readline, b''):
if line:
line_str = line.decode('utf-8', errors='replace').strip()
if line_str:
logging.info(f"[Model Server {prefix}] {line_str}")
except Exception as e:
logging.error(f"Error reading {prefix}: {e}")
finally:
pipe.close()
# 使用 uvicorn 启动模型服务器
# 使用 models_server:create_app 作为应用入口
uvicorn_cmd = [
sys.executable, "-m", "uvicorn",
"models.models_server:create_app",
"--host", host,
"--port", str(port),
"--workers", str(workers),
"--factory"
]
model_server_process = subprocess.Popen(
uvicorn_cmd,
cwd=ai_agent_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# 启动线程实时输出日志
stdout_thread = threading.Thread(
target=log_output,
args=(model_server_process.stdout, "STDOUT"),
daemon=True
)
stderr_thread = threading.Thread(
target=log_output,
args=(model_server_process.stderr, "STDERR"),
daemon=True
)
stdout_thread.start()
stderr_thread.start()
# 先等待一小段时间,检查进程是否立即退出
time.sleep(2)
if model_server_process.poll() is not None:
# 进程已退出,读取错误信息
out, err = model_server_process.communicate()
error_msg = err.decode('utf-8', errors='replace') if err else out.decode('utf-8', errors='replace') if out else "Unknown error"
logging.error(f"Model server process exited immediately: {error_msg}")
return
# Wait for it to be ready
logging.info("Waiting for model server to be ready...")
if client.wait_for_server(timeout=300, check_interval=5):
logging.info("Model server started and ready.")
else:
logging.error("Model server failed to start within timeout.")
# Check for errors
if model_server_process.poll() is not None:
out, err = model_server_process.communicate()
error_msg = err.decode('utf-8', errors='replace') if err else out.decode('utf-8', errors='replace') if out else 'Unknown error'
logging.error(f"Model server process exited with error: {error_msg}")
except Exception as e:
logging.error(f"Failed to launch model server: {e}")
def stop_model_server():
"""Stops the model server subprocess if it was started by this service."""
global model_server_process
if model_server_process:
logging.info("Stopping model server subprocess...")
model_server_process.terminate()
try:
model_server_process.wait(timeout=10)
except subprocess.TimeoutExpired:
model_server_process.kill()
logging.info("Model server stopped.")
# --- API Endpoints ---
@app.post("/generate_plan", response_model=dict)
async def generate_plan_endpoint(request: GeneratePlanRequest):
"""
Receives a user prompt and returns a generated `py_tree.json` with a visualization URL.
"""
try:
pytree_dict = await py_tree_generator.generate(request.user_prompt)
return pytree_dict
except RuntimeError as e:
return {"error": str(e)}
@app.post("/execute_mission", response_model=dict)
async def execute_mission_endpoint(request: ExecuteMissionRequest):
"""
Receives a `py_tree.json` and sends it to the drone for execution.
"""
# ROS2 execution removed
return {"status": "execution_started (simulation mode)"}
@app.websocket("/ws/status")
async def websocket_endpoint(websocket: WebSocket):
"""
Handles the WebSocket connection for real-time status updates.
"""
await websocket_manager.connect(websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
websocket_manager.disconnect(websocket)
logging.info("Client disconnected from WebSocket.")
# --- Server Lifecycle ---
@app.on_event("startup")
async def startup_event():
"""
On startup, get the current asyncio event loop and pass it to the websocket manager.
Start the ROS2 node in a background thread.
Ensure the Model Server is running.
"""
# Configure WebSocket Manager
loop = asyncio.get_running_loop()
websocket_manager.set_loop(loop)
logging.info("WebSocket event loop configured.")
# Start Model Server
threading.Thread(target=start_model_server, daemon=True).start()
@app.on_event("shutdown")
async def shutdown_event():
logging.info("Backend service shutting down.")
stop_model_server()
logging.info("Backend service shut down successfully.")

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from typing import Dict, Any
class GeneratePlanRequest(BaseModel):
user_prompt: str
class ExecuteMissionRequest(BaseModel):
py_tree: Dict[str, Any]
class StatusUpdate(BaseModel):
node_id: str
status: int

View File

@@ -0,0 +1,24 @@
你是一个严格的任务分类器。只输出一个JSON对象不要输出解释或多余文本。
根据用户指令与下述可用节点定义,判断其为“简单”或“复杂”。
- 简单:单一原子动作即可完成(例如“起飞”“飞机自检”“移动到某地(已给定坐标)”“对着某点环绕XY圈对着学生宿舍环绕三十两圈”等且无需行为树与安全并行监控。
- 复杂:需要多步流程、搜索/检测/跟踪/评估、战损确认、或需要模板化任务结构与安全并行监控。
输出格式(严格遵守):
{"mode":"simple"} 或 {"mode":"complex"}
—— 可用节点定义——
```json
{
"actions": [
{"name": "takeoff"}, {"name": "land"}, {"name": "fly_to_waypoint"}, {"name": "move_direction"}, {"name": "orbit_around_point"}, {"name": "orbit_around_target"}, {"name": "loiter"},
{"name": "object_detect"}, {"name": "strike_target"}, {"name": "battle_damage_assessment"},
{"name": "search_pattern"}, {"name": "track_object"}, {"name": "deliver_payload"},
{"name": "preflight_checks"}, {"name": "emergency_return"}
],
"conditions": [
{"name": "battery_above"}, {"name": "at_waypoint"}, {"name": "object_detected"},
{"name": "target_destroyed"}, {"name": "time_elapsed"}, {"name": "gps_status"}
]
}
```

View File

@@ -0,0 +1,63 @@
你是一个无人机简单指令执行规划器。你的任务当用户给出“简单指令”单一原子动作即可完成输出一个严格的JSON对象。
输出要求(必须遵守):
- 只输出一个JSON对象不要任何解释或多余文本。
- JSON结构
{"root":{"type":"action","name":"<action_name>","params":{...}}}
- <action_name> 与参数定义、取值范围必须与“复杂模式”提示词system_prompt.txt中的定义完全一致。
- root节点必须是action类型节点不能是控制流节点。
- 当用户指令或检索到的参考知识中已经给出了具体坐标(如 (32.118904, 118.955208)),请直接按原值使用该坐标,不进行任何缩放、单位换算或数值变换,只需根据需要映射到对应字段
示例:
- “起飞到10米” → {"root":{"type":"action","name":"takeoff","params":{"altitude":10.0}}}
- “移动到(120,80,20)” → {"root":{"type":"action","name":"fly_to_waypoint","params":{"x":120.0,"y":80.0,"z":20.0,"acceptance_radius":2.0}}}
- “飞机自检” → {"root":{"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}}}
—— 可用节点定义——
```json
{
"actions": [
{"name": "takeoff", "description": "无人机从当前位置垂直起飞到指定的海拔高度。", "params": {"altitude": "float, 目标海拔高度(米),范围[1, 100]默认为2"}},
{"name": "land", "description": "降落无人机。可选择当前位置或返航点降落。", "params": {"mode": "string, 可选值: 'current'(当前位置), 'home'(返航点)"}},
{"name": "fly_to_waypoint", "description": "导航至一个指定坐标点。使用相对坐标系x,y,z单位为米。", "params": {"x": "float", "y": "float", "z": "float", "acceptance_radius": "float, 可选默认2.0"}},
{"name": "move_direction", "description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。", "params": {"direction": "string: north|south|east|west|forward|backward|left|right", "distance": "float[1,10000], 可选, 不指定则持续移动"}},
{"name": "orbit_around_point", "description": "以给定中心点为中心,等速圆周飞行指定圈数。", "params": {"center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
{"name": "orbit_around_target", "description": "以目标为中心,等速圆周飞行指定圈数(需已有目标)。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}},
{"name": "loiter", "description": "在当前位置上空悬停一段时间或直到条件触发。", "params": {"duration": "float, 可选[1,600]", "until_condition": "string, 可选"}},
{"name": "object_detect", "description": "识别特定目标对象。一般是用户提到的需要检测的目标;如果用户给出了需要探索的目标的优先级,比如蓝色球危险性大于红色球大于绿色球,需要检测最危险的球,此处应给出检测优先级,描述应当为 '蓝>红>绿'", "params": {"target_class": "string, 要识别的目标类别,必须为以下值之一: balloon,person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush", "description": "string, 可选", "count": "int, 可选, 默认1"}},
{"name": "strike_target", "description": "对已识别目标进行打击。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
{"name": "battle_damage_assessment", "description": "战损评估。", "params": {"target_class": "string", "assessment_time": "float[5-60], 默认15.0"}},
{"name": "search_pattern", "description": "按模式搜索。", "params": {"pattern_type": "string: spiral|grid", "center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
{"name": "track_object", "description": "持续跟踪目标。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "track_time": "float[1,600], 默认30.0", "min_confidence": "float[0.5-1.0], 默认0.7", "safe_distance": "float[2-50], 默认10.0"}},
{"name": "deliver_payload", "description": "投放物资。", "params": {"payload_type": "string", "release_altitude": "float[2,100], 默认5.0"}},
{"name": "preflight_checks", "description": "飞行前系统自检。", "params": {"check_level": "string: basic|comprehensive"}},
{"name": "emergency_return", "description": "执行紧急返航程序。", "params": {"reason": "string"}}
],
"conditions": [
{"name": "battery_above", "description": "电池电量高于阈值。", "params": {"threshold": "float[0.0,1.0]"}},
{"name": "at_waypoint", "description": "在指定坐标容差范围内。", "params": {"x": "float", "y": "float", "z": "float", "tolerance": "float, 可选, 默认3.0"}},
{"name": "object_detected", "description": "检测到特定目标。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}},
{"name": "target_destroyed", "description": "目标已被摧毁。", "params": {"target_class": "string", "description": "string, 可选", "confidence": "float[0.5-1.0], 默认0.8"}},
{"name": "time_elapsed", "description": "时间经过。", "params": {"duration": "float[1,2700]"}},
{"name": "gps_status", "description": "GPS状态良好。", "params": {"min_satellites": "int[6,15], 默认10"}}
]
}
```
—— 参数约束——
- takeoff.altitude: [1, 100]
- fly_to_waypoint.z: [1, 5000]
- fly_to_waypoint.x,y: [-10000, 10000]
- search_pattern.radius: [5, 1000]
- move_direction.distance: [1, 10000]
- orbit_around_point.radius: [5, 1000]
- orbit_around_target.radius: [5, 1000]
- orbit_around_point/target.laps: [1, 20]
- orbit_around_point/target.speed_mps: [0.5, 15]
- 若参考知识提供坐标,必须使用并裁剪到约束范围内
—— 口令转化规则(环绕类)——
- “环绕X米Y圈” → 若有目标上下文则使用 `orbit_around_target`,否则根据是否给出中心坐标选择 `orbit_around_point``radius=X``laps=Y`,默认 `clockwise=true``gimbal_lock=true`
- “顺时针/逆时针” → `clockwise=true/false`
- “等速” → 若未给速度则 `speed_mps` 采用默认值例如3.0);若口令指明速度,裁剪到[0.5,15]
- “以(x,y,z)为中心”/“当前位置为中心” → 选择 `orbit_around_point` 并填充 `center_x/center_y/center_z`

View File

@@ -0,0 +1,117 @@
任务根据用户任意任务指令生成结构化可执行的无人机行为树PytreeJSON。**仅输出单一JSON对象无任何自然语言、注释或额外内容**。
## 一、核心节点定义(格式不可修改,确保后端解析)
#### 1. 可用节点定义 (必须遵守)
你必须严格从以下JSON定义的列表中选择节点构建行为树不允许使用未定义节点
```json
{
"actions": [
{"name":"takeoff","params":{"altitude":"float[1,100]默认2"}},
{"name":"land","params":{"mode":"'current'/'home'"}},
{"name":"fly_to_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","acceptance_radius":"默认2.0"}},
{"name":"move_direction","params":{"direction":"north/south/east/west/forward/backward/left/right","distance":"[1,10000],缺省持续移动"}},
{"name":"orbit_around_point","params":{"center_x":"±10000","center_y":"±10000","center_z":"[1,5000]","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}},
{"name":"orbit_around_target","params":{"target_class":"见object_detect列表","description":"可选,目标属性","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}},
{"name":"loiter","params":{"duration":"[1,600]秒/until_condition:可选"}},
{"name":"object_detect","params":{"target_class":"person,bicycle,car,motorcycle,airplane,bus,train,truck,boat,traffic_light,fire_hydrant,stop_sign,parking_meter,bench,bird,cat,dog,horse,sheep,cow,elephant,bear,zebra,giraffe,backpack,umbrella,handbag,tie,suitcase,frisbee,skis,snowboard,sports_ball,kite,baseball_bat,baseball_glove,skateboard,surfboard,tennis_racket,bottle,wine_glass,cup,fork,knife,spoon,bowl,banana,apple,sandwich,orange,broccoli,carrot,hot_dog,pizza,donut,cake,chair,couch,potted_plant,bed,dining_table,toilet,tv,laptop,mouse,remote,keyboard,cell_phone,microwave,oven,toaster,sink,refrigerator,book,clock,vase,scissors,teddy_bear,hair_drier,toothbrush","description":"可选,","count":"默认1"}},
{"name":"strike_target","params":{"target_class":"同object_detect","description":"可选,目标属性","count":"默认1"}},
{"name":"battle_damage_assessment","params":{"target_class":"同object_detect","assessment_time":"[5,60]默认15"}},
{"name":"search_pattern","params":{"pattern_type":"spiral/grid","center_x":"±10000","center_y":"±10000","center_z":"[1,5000]","radius":"[5,1000]","target_class":"同object_detect","description":"可选,目标属性","count":"默认1"}},
{"name":"track_object","params":{"target_class":"同object_detect","description":"可选,目标属性","track_time":"[1,600]秒(必传,不可用'duration'","min_confidence":"[0.5,1.0]默认0.7","safe_distance":"[2,50]默认10"}},
{"name":"deliver_payload","params":{"payload_type":"string","release_altitude":"[2,100]默认5"}},
{"name":"preflight_checks","params":{"check_level":"basic/comprehensive"}},
{"name":"emergency_return","params":{"reason":"string"}}
],
"conditions": [
{"name":"battery_above","params":{"threshold":"[0.0,1.0],必传"}},
{"name":"at_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","tolerance":"默认3.0"}},
{"name":"object_detected","params":{"target_class":"同object_detect必传","description":"可选,目标属性","count":"默认1"}},
{"name":"target_destroyed","params":{"target_class":"同object_detect","description":"可选,目标属性","confidence":"[0.5,1.0]默认0.8"}},
{"name":"time_elapsed","params":{"duration":"[1,2700]秒"}},
{"name":"gps_status","params":{"min_satellites":"int[6,15]必传如8"}}
],
"control_flow": [
{"name":"Sequence","params":{},"children":"子节点数组(按序执行,全成功则成功)"},
{"name":"Selector","params":{"memory":"默认true"},"children":"子节点数组(执行到成功为止)"},
{"name":"Parallel","params":{"policy":"all_success"},"children":"子节点数组(同时执行,严禁用'one_success'"}
]
}
```
## 二、节点必填字段后端Schema强制要求缺一验证失败
每个节点必须包含以下字段,字段名/类型不可自定义:
1. **`type`**
- 动作节点→`"action"`,条件节点→`"condition"`,控制流节点→`"Sequence"`/`"Selector"`/`"Parallel"`(与`name`字段值完全一致);
2. **`name`**必须是上述JSON中`actions`/`conditions`/`control_flow`下的`name`值如“gps_status”不可错写为“gps_check”
3. **`params`**:严格匹配上述节点的`params`定义无自定义参数如优先级排序不可加“priority”字段仅用`description`
4. **`children`**:仅控制流节点必含(子节点数组),动作/条件节点无此字段。
## 三、行为树固定结构(通用不变,确保安全验证)
根节点必须是`Parallel``children`含`MainTask`Sequence和`SafetyMonitor`Selector结构不随任务类型含优先级排序修改
```json
{
"root": {
"type": "Parallel",
"name": "MissionWithSafety",
"params": {"policy": "all_success"},
"children": [
{
"type": "Sequence",
"name": "MainTask",
"params": {},
"children": [
// 通用主任务步骤(含优先级排序任务示例,需按用户指令替换):
{"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}},
{"type":"action","name":"takeoff","params":{"altitude":10.0}},
{"type":"action","name":"fly_to_waypoint","params":{"x":200.0,"y":150.0,"z":10.0}}, // 搜索区坐标(用户未给时填合理值)
{"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200.0,"center_y":150.0,"center_z":10.0,"radius":50.0,"target_class":"balloon","description":"红色"}},
{"type":"condition","name":"object_detected","params":{"target_class":"balloon","description":"红色"}}, // 确认高优先级目标
{"type":"action","name":"track_object","params":{"target_class":"balloon","description":"红色","track_time":30.0}},
{"type":"action","name":"strike_target","params":{"target_class":"balloon","description":"红色"}},
{"type":"action","name":"land","params":{"mode":"home"}}
]
},
{
"type": "Selector",
"name": "SafetyMonitor",
"params": {"memory": true},
"children": [
{"type":"condition","name":"battery_above","params":{"threshold":0.3}},
{"type":"condition","name":"gps_status","params":{"min_satellites":8}},
{
"type":"Sequence",
"name":"EmergencyHandler",
"params": {},
"children": [
{"type":"action","name":"emergency_return","params":{"reason":"safety_breach"}},
{"type":"action","name":"land","params":{"mode":"home"}}
]
}
]
}
]
}
}
```
## 四、优先级排序任务通用示例
当用户指令中明确提出有多个待考察且具有优先级关系的物体时,节点描述须为优先级关系。比如当指令为已知有三个气球,危险级关系为红色气球大于蓝色气球大于绿色气球,要求优先跟踪最危险的气球时,节点的描述参考下表情形。
| 用户指令场景 | `target_class` | `description` | 核心节点示例search_pattern |
|-----------------------------|-----------------|-------------------------|------------------------------------------------------------------------------------------------|
| 红气球>蓝气球>绿气球 | `balloon` | `(红>蓝>绿)` | `{"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200,"center_y":150,"center_z":10,"radius":50,"target_class":"balloon","description":"(红>蓝>绿)"}}` |
| 军用卡车>民用卡车>面包车 | `truck` | `(军用卡车>民用卡车>面包车)` | `{"type":"action","name":"object_detect","params":{"target_class":"truck","description":"(军用卡车>民用卡车>面包车)"}}` |
## 五、高频错误规避(确保验证通过)
1. 优先级排序不可修改`target_class`:如“民用卡车、面包车与军用卡车中,军用卡车优先”,`target_class`仍为`truck`,仅用`description`填排序规则;
2. 在没有明确指出物体之间的优先级关系情况下,`description`字段只描述物体属性本身,严禁与用户指令中不存在的物体进行排序;
3. `track_object`必传`track_time`:不可用`duration`替代如跟踪30秒填`"track_time":30.0`
4. `gps_status`的`min_satellites`必须在6-15之间如8不可缺省
5. 无自定义节点:“锁定高优先级目标”需通过`object_detect`+`object_detected`实现不可用“lock_high_risk_target”。
## 六、输出要求
仅输出1个严格符合上述所有规则的JSON对象**确保1. 优先级排序逻辑正确填入`description`2. `target_class`匹配预定义列表3. 行为树结构不变4. 后端解析与Schema验证无错误**无任何冗余内容5. 当用户指令或检索到的参考知识中已经给出了具体坐标(如 (32.118904, 118.955208)),请直接按原值使用该坐标,不进行任何缩放、单位换算或数值变换,只需根据需要映射到对应字段

View File

@@ -0,0 +1,934 @@
import json
import os
import logging
import uuid
import re
import time
from typing import Dict, Any, Optional, Set, List
import jsonschema
import platform
# --- Imports from ai_agent ---
# Note: sys.path is updated in main.py to include ai_agent root
from models.models_client import Models_Client
from models.model_definition import TextInferenceRequest
from tools.memory_mag.rag.rag import set_embeddings, load_vector_database, retrieve_relevant_info
from tools.core.common_functions import read_json_file
# --- Logging Setup ---
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
# ==============================================================================
# VALIDATION LOGIC (from utils/validation.py)
# ==============================================================================
def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
从系统提示词中精确解析出允许的行动和条件节点。
"""
try:
# 使用更精确的正则表达式匹配节点定义部分
# 匹配 "#### 1. 可用节点定义" 或 "#### 2. 可用节点定义" 等
node_section_pattern = r"####\s*\d+\.\s*可用节点定义.*?```json\s*({.*?})\s*```"
match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE)
if not match:
logging.error("在系统提示词中未找到'可用节点定义'部分的JSON代码块。")
# 备用方案尝试查找所有JSON块并识别节点定义
return _fallback_parse_nodes(prompt_text)
json_str = match.group(1)
logging.info("成功找到节点定义JSON代码块")
# 解析JSON
allowed_nodes = json.loads(json_str)
# 从对象列表中提取节点名称
actions = set()
conditions = set()
# 提取动作节点
if "actions" in allowed_nodes and isinstance(allowed_nodes["actions"], list):
for action in allowed_nodes["actions"]:
if isinstance(action, dict) and "name" in action:
actions.add(action["name"])
# 提取条件节点
if "conditions" in allowed_nodes and isinstance(allowed_nodes["conditions"], list):
for condition in allowed_nodes["conditions"]:
if isinstance(condition, dict) and "name" in condition:
conditions.add(condition["name"])
if not actions:
logging.warning("关键错误:从提示词解析出的行动节点列表为空。")
logging.info(f"成功解析出动作节点: {sorted(actions)}")
logging.info(f"成功解析出条件节点: {sorted(conditions)}")
return actions, conditions
except json.JSONDecodeError as e:
logging.error(f"解析节点定义JSON时失败: {e}")
return set(), set()
except Exception as e:
logging.error(f"解析可用节点时发生未知错误: {e}")
return set(), set()
def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]:
"""
备用解析方案:当精确匹配失败时使用。
"""
logging.warning("使用备用方案解析节点定义...")
# 查找所有JSON代码块
matches = re.findall(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL)
if not matches:
logging.error("在系统提示词中未找到任何JSON代码块。")
return set(), set()
# 尝试从每个JSON块中解析节点定义
for i, json_str in enumerate(matches):
try:
data = json.loads(json_str)
# 检查是否是节点定义的结构包含actions、conditions、control_flow
if ("actions" in data and isinstance(data["actions"], list) and
"conditions" in data and isinstance(data["conditions"], list) and
"control_flow" in data and isinstance(data["control_flow"], list)):
actions = set()
conditions = set()
# 提取动作节点
for action in data["actions"]:
if isinstance(action, dict) and "name" in action:
actions.add(action["name"])
# 提取条件节点
for condition in data["conditions"]:
if isinstance(condition, dict) and "name" in condition:
conditions.add(condition["name"])
if actions:
logging.info(f"从第{i+1}个JSON块中成功解析出节点定义")
logging.info(f"动作节点: {sorted(actions)}")
logging.info(f"条件节点: {sorted(conditions)}")
return actions, conditions
except json.JSONDecodeError:
continue # 尝试下一个JSON块
logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。")
return set(), set()
def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]:
"""递归查找所有指定名称的节点"""
nodes_found = []
if node.get("name") == target_name:
nodes_found.append(node)
# 递归搜索子节点
for child in node.get("children", []):
nodes_found.extend(_find_nodes_by_name(child, target_name))
return nodes_found
def _validate_safety_monitoring(pytree_instance: dict) -> bool:
"""验证行为树是否包含必要的安全监控"""
root_node = pytree_instance.get("root", {})
# 查找所有电池监控节点
battery_nodes = _find_nodes_by_name(root_node, "battery_above")
# 检查是否包含安全监控结构
safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor")
if not battery_nodes and not safety_monitors:
logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器")
return False
# 检查电池阈值设置是否合理
for battery_node in battery_nodes:
threshold = battery_node.get("params", {}).get("threshold")
if threshold is not None:
if threshold < 0.25:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold})建议不低于0.25")
elif threshold > 0.5:
logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行")
logging.info("✅ 安全监控验证通过")
return True
def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict:
"""
根据允许的行动和条件节点动态生成一个JSON Schema。
"""
# 所有可能的节点类型
node_types = ["action", "condition", "Sequence", "Selector", "Parallel"]
# 目标检测相关的类别枚举
target_classes = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic_light", "fire_hydrant", "stop_sign", "parking_meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball",
"kite", "baseball_bat", "baseball_glove", "skateboard", "surfboard", "tennis_racket",
"bottle", "wine_glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair",
"couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote",
"keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
"clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon"
]
# 递归节点定义
node_definition = {
"type": "object",
"properties": {
"type": {"type": "string", "enum": node_types},
"name": {"type": "string"},
"params": {"type": "object"},
"children": {
"type": "array",
"items": {"$ref": "#/definitions/node"}
}
},
"required": ["type", "name"],
"allOf": [
# 动作节点验证
{
"if": {"properties": {"type": {"const": "action"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}}
},
# 条件节点验证
{
"if": {"properties": {"type": {"const": "condition"}}},
"then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}}
},
# 目标检测动作节点的参数验证
{
"if": {
"properties": {
"type": {"const": "action"},
"name": {"const": "object_detect"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"target_class": {"type": "string", "enum": target_classes},
"description": {"type": "string"},
"count": {"type": "integer", "minimum": 1}
},
"required": ["target_class"],
"additionalProperties": False
}
}
}
},
# 目标检测条件节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "object_detected"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"target_class": {"type": "string", "enum": target_classes},
"description": {"type": "string"},
"count": {"type": "integer", "minimum": 1}
},
"required": ["target_class"],
"additionalProperties": False
}
}
}
},
# 电池监控节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "battery_above"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0}
},
"required": ["threshold"],
"additionalProperties": False
}
}
}
},
# GPS状态节点的参数验证
{
"if": {
"properties": {
"type": {"const": "condition"},
"name": {"const": "gps_status"}
}
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"min_satellites": {"type": "integer", "minimum": 6, "maximum": 15}
},
"required": ["min_satellites"],
"additionalProperties": False
}
}
}
}
]
}
# 完整的Schema结构
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Pytree",
"definitions": {
"node": node_definition
},
"type": "object",
"properties": {
"root": { "$ref": "#/definitions/node" }
},
"required": ["root"]
}
return schema
def _generate_simple_mode_schema(allowed_actions: set) -> dict:
"""
生成简单模式JSON Schema{"root":{"type":"action","name":"<action_name>","params":{...}}}
仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。
"""
# 使用复杂模式Schema中的node定义但限制root节点必须是action类型
node_definition = {
"type": "object",
"properties": {
"type": {"type": "string", "const": "action"},
"name": {"type": "string", "enum": sorted(list(allowed_actions))},
"params": {"type": "object"}
},
"required": ["type", "name"],
"additionalProperties": False
}
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "SimpleMode",
"definitions": {
"node": node_definition
},
"type": "object",
"properties": {
"root": { "$ref": "#/definitions/node" }
},
"required": ["root"],
"additionalProperties": False
}
return schema
def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool:
"""
使用JSON Schema验证给定的Pytree实例。
"""
try:
jsonschema.validate(instance=pytree_instance, schema=schema)
logging.info("✅ JSON Schema验证成功")
# 额外验证安全监控
safety_valid = _validate_safety_monitoring(pytree_instance)
return True and safety_valid
except jsonschema.ValidationError as e:
logging.warning("❌ Pytree验证失败")
logging.warning(f"错误信息: {e.message}")
error_path = list(e.path)
logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}")
# 提供更具体的错误信息
if "object_detect" in str(e.message) or "object_detected" in str(e.message):
logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值")
elif "battery_above" in str(e.message):
logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间")
elif "gps_status" in str(e.message):
logging.warning("💡 提示: 最小卫星数量必须在6到15之间")
return False
except Exception as e:
logging.error(f"进行JSON Schema验证时发生未知错误: {e}")
return False
# ==============================================================================
# VISUALIZATION LOGIC (from utils/visualization.py)
# ==============================================================================
def _visualize_pytree(node: Dict, file_path: str):
"""
使用Graphviz将Pytree字典可视化并保存到指定路径。
"""
try:
from graphviz import Digraph
except ImportError:
logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
return
# 选择合适的中文字体,避免中文乱码
def _pick_zh_font():
sys = platform.system()
if sys == "Windows":
return "Microsoft YaHei"
elif sys == "Darwin":
return "PingFang SC"
else:
return "Noto Sans CJK SC"
fontname = _pick_zh_font()
dot = Digraph('Pytree', comment='Drone Mission Plan')
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
dot.attr('edge', fontname=fontname)
_add_nodes_and_edges(node, dot)
try:
# 确保输出目录存在,并避免生成 .png.png
base_path, ext = os.path.splitext(file_path)
render_path = base_path if ext.lower() == '.png' else file_path
out_dir = os.path.dirname(render_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
# 保存为 .png 文件,并自动删除源码 .gv 文件
output_path = dot.render(render_path, format='png', cleanup=True, view=False)
logging.info("✅ 任务树可视化成功")
logging.info(f"图形已保存到: {output_path}")
except Exception as e:
logging.error("❌ 生成可视化图形失败")
logging.error("请确保您的系统已经正确安装了Graphviz图形库。")
logging.error(f"错误详情: {e}")
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
# 为每个节点创建一个唯一的ID加上随机数避免冲突
import random
import html
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
# 准备节点标签HTML-like正确换行与转义
name = html.escape(str(node.get('name', '')))
ntype = html.escape(str(node.get('type', '')))
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
# 格式化参数显示
params = node.get('params') or {}
if params:
params_lines = []
for key, value in params.items():
k = html.escape(str(key))
if isinstance(value, float):
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
else:
value_str = str(value)
v = html.escape(value_str)
params_lines.append(f"{k}: {v}")
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
node_label = f"<{'<BR/>'.join(label_parts)}>"
# 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色)
node_type = (node.get('type') or '').lower()
shape = 'ellipse'
style = 'filled'
fillcolor = '#e6e6e6' # 默认灰色填充
border_color = '#666666' # 默认描边色
if node_type == 'action':
shape = 'box'
style = 'rounded,filled'
fillcolor = "#cde4ff" # 浅蓝
elif node_type == 'condition':
shape = 'diamond'
style = 'filled'
fillcolor = "#fff2cc" # 浅黄
elif node_type == 'sequence':
shape = 'ellipse'
style = 'filled'
fillcolor = '#d5e8d4' # 绿色
elif node_type == 'selector':
shape = 'ellipse'
style = 'filled'
fillcolor = '#ffe6cc' # 橙色
elif node_type == 'parallel':
shape = 'ellipse'
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
children = node.get("children", [])
if not children:
return current_id
# 记录所有子节点的ID
child_ids = []
# 正确的递归连接:每个子节点都连接到当前节点
for child in children:
child_id = _add_nodes_and_edges(child, dot, current_id)
child_ids.append(child_id)
# 子节点同级排列(横向排布,更直观地表现同层)
if len(child_ids) > 1:
with dot.subgraph(name=f"rank_{current_id}") as s:
s.attr(rank='same')
for cid in child_ids:
s.node(cid)
# 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接
# Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现
return current_id
# ==============================================================================
# CORE PYTREE GENERATOR CLASS
# ==============================================================================
class PyTreeGenerator:
def __init__(self):
self.base_dir = os.path.dirname(os.path.abspath(__file__))
self.prompts_dir = os.path.join(self.base_dir, 'prompts')
# Updated output directory for visualizations
self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations'))
os.makedirs(self.vis_dir, exist_ok=True)
# Load prompts
self.complex_prompt = self._load_prompt("system_prompt.txt")
self.simple_prompt = self._load_prompt("simple_mode_prompt.txt")
self.classifier_prompt = self._load_prompt("classifier_prompt.txt")
# Legacy variable compatibility
self.system_prompt = self.complex_prompt
# --- Models Client Setup ---
# Unified client for all models
self.model_server_url = os.getenv("MODEL_SERVER_URL", "http://localhost:8000")
self.models_client = Models_Client(server_url=self.model_server_url)
# --- RAG Setup ---
# 计算 ai_agent 根目录(从 groundcontrol/src 向上2级
ai_agent_root = os.path.abspath(os.path.join(self.base_dir, "../.."))
# 尝试从配置文件读取 RAG 配置
rag_config_path = os.path.join(ai_agent_root, "config", "rag_config.json")
rag_config = None
if os.path.exists(rag_config_path):
try:
rag_config = read_json_file(rag_config_path)
logging.info(f"成功读取 RAG 配置文件: {rag_config_path}")
except Exception as e:
logging.warning(f"读取 RAG 配置文件失败: {e},将使用默认配置")
# 确定向量数据库路径(优先使用配置文件,否则使用默认路径)
if rag_config and "rag_mag" in rag_config:
vector_store_path = rag_config["rag_mag"].get("vectorstore_persist_directory")
collection_name = rag_config["rag_mag"].get("collection_name", "osm_map_docs")
embedding_model_path = rag_config["rag_mag"].get("embedding_model_path")
embedding_type = rag_config["rag_mag"].get("embedding_framework_type", "llamacpp_embedding")
# 如果配置文件中是绝对路径,直接使用;否则转换为绝对路径
if vector_store_path and not os.path.isabs(vector_store_path):
vector_store_path = os.path.join(ai_agent_root, vector_store_path)
# 读取模型配置
if embedding_type == "llamacpp_embedding":
model_config = rag_config["rag_mag"].get("model_config_llamacpp", {})
else:
model_config = rag_config["rag_mag"].get("model_config_huggingFace", {})
else:
# 使用默认配置:指向 scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1
vector_store_path = os.path.join(ai_agent_root, "memory", "knowledge_base", "map", "vector_store", "osm_map1")
collection_name = "osm_map_docs"
embedding_model_path = os.getenv("EMBEDDING_MODEL_PATH", "/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf")
embedding_type = "llamacpp_embeddings"
# 默认模型配置
model_config = {
"n_ctx": 512,
"n_threads": 4,
"n_gpu_layers": 0,
"n_seq_max": 256,
"n_threads_batch": 4,
"flash_attn": False,
"verbose": False
}
# 确保路径是绝对路径
vector_store_path = os.path.abspath(vector_store_path)
logging.info(f"向量数据库路径: {vector_store_path}")
logging.info(f"集合名称: {collection_name}")
logging.info(f"初始化 RAG嵌入模型路径: {embedding_model_path}")
if os.path.exists(embedding_model_path):
# 设置嵌入类型名称(配置文件使用 "llamacpp_embedding",函数期望 "llamacpp_embeddings"
if embedding_type == "llamacpp_embedding":
embedding_type = "llamacpp_embeddings"
self.embeddings = set_embeddings(
embedding_model_path=embedding_model_path,
embedding_type=embedding_type,
model_config=model_config
)
# 加载现有向量数据库
self.vector_db = load_vector_database(
embeddings=self.embeddings,
persist_directory=vector_store_path,
collection_name=collection_name
)
if not self.vector_db:
logging.warning(f"向量数据库未找到于 {vector_store_path},集合名称: {collection_name}。RAG 检索将返回空结果,直到数据库被填充。")
else:
logging.info(f"✅ 成功加载向量数据库: {vector_store_path}")
else:
logging.warning(f"嵌入模型未找到于 {embedding_model_path}。RAG 将被禁用。")
self.embeddings = None
self.vector_db = None
# Parse allowed nodes from prompt for validation
allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt)
self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions)
self.simple_schema = _generate_simple_mode_schema(allowed_actions)
def _load_prompt(self, file_name: str) -> str:
try:
with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f:
return f.read()
except FileNotFoundError:
logging.error(f"提示词文件未找到 -> {file_name}")
return ""
def _retrieve_context(self, query: str, timeout: int = 15) -> Optional[str]:
"""
Retrieve context from vector database with timeout protection.
"""
if not self.vector_db:
return None
logging.info("--- Retrieving context from Vector DB ---")
try:
import threading
result_container = []
exception_container = []
def query_func():
try:
# 直接从 Chroma 拿原始距离分数
docs_and_scores = self.vector_db.similarity_search_with_score(query, k=5)
if not docs_and_scores:
result_container.append([])
return
# 提取距离,并做一次查询内的 min-max 归一化 -> 相似度 ∈ [0,1]
distances = [score for _, score in docs_and_scores]
min_d, max_d = min(distances), max(distances)
normalized_results = []
for (doc, dist) in docs_and_scores:
if max_d == min_d:
sim = 1.0 # 所有距离一样时,认为同等相似
else:
sim = (max_d - dist) / (max_d - min_d) # 距离小 -> 相似度高
# 使用绝对阈值 0.2 的相似度过滤
if sim >= 0.2:
normalized_results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"similarity_score": float(sim)
})
result_container.append(normalized_results)
except Exception as e:
exception_container.append(e)
thread = threading.Thread(target=query_func)
thread.daemon = True
thread.start()
thread.join(timeout=timeout)
if thread.is_alive():
logging.warning(f"⚠️ Vector DB query timed out (> {timeout}s), skipping RAG.")
return None
if exception_container:
raise exception_container[0]
results = result_container[0] if result_container else []
if not results:
logging.warning("Vector DB returned no results.")
return None
# Extract content from results (List[Dict])
# Each result has 'content', 'metadata', 'similarity_score'
context_list = [doc['content'] for doc in results]
context_str = "\n\n".join(context_list)
logging.info(f"✅ Successfully retrieved {len(results)} documents.")
return context_str
except Exception as e:
logging.error(f"RAG retrieval error: {e}")
return None
async def generate(self, user_prompt: str) -> Dict[str, Any]:
"""
Generates a py_tree.json structure based on the user's prompt.
"""
total_start_time = time.time()
logging.info(f"接收到用户请求: {user_prompt}")
# 第一步:分类(简单/复杂)
mode = "complex"
classify_start = time.time()
try:
# Construct request for classification
# 分类请求使用简单的 JSON Schema 约束输出格式
classifier_schema = {
"type": "object",
"properties": {
"mode": {"type": "string", "enum": ["simple", "complex"]}
},
"required": ["mode"],
"additionalProperties": False
}
req = TextInferenceRequest(
user_prompt=user_prompt,
system_prompt=self.classifier_prompt or "你是一个分类器只输出JSON。",
temperature=0.0,
max_tokens=100,
top_p=0.95,
json_schema=classifier_schema # 传递 JSON Schema 约束分类输出
)
# Call unified models client
resp = self.models_client.text_inference(req)
if "error" in resp:
logging.warning(f"分类请求失败: {resp['error']}")
else:
class_str = resp.get("result", "{}")
# Try to clean markdown if present
class_str = re.sub(r'```json\s*', '', class_str).replace('```', '').strip()
try:
class_obj = json.loads(class_str)
if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"):
mode = class_obj.get("mode")
except json.JSONDecodeError:
pass # Default to complex
classify_time = time.time() - classify_start
logging.info(f"分类结果: {mode}, ⏱️ 耗时: {classify_time:.2f}")
except Exception as e:
classify_time = time.time() - classify_start
logging.warning(f"分类过程发生异常,默认按复杂指令处理: {e}, ⏱️ 耗时: {classify_time:.2f}")
# 第二步:根据模式准备提示词与上下文(简单与复杂都执行检索增强)
# 基于模式选择提示词;复杂模式追加一条强制规则,避免模型误输出简单结构
use_prompt = self.simple_prompt if mode == "simple" else (
(self.complex_prompt or "") +
"\n\n【强制规则】仅生成包含root的复杂行为树JSON不得输出简单模式不得包含mode字段或仅有action节点"
)
final_user_prompt = user_prompt
retrieval_start = time.time()
retrieved_context = self._retrieve_context(user_prompt)
retrieval_time = time.time() - retrieval_start
logging.info(f"⏱️ 检索耗时: {retrieval_time:.2f}")
if retrieved_context:
augmentation = (
"\n\n---\n"
"参考知识:\n"
"以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n"
f"{retrieved_context}"
"\n---"
)
final_user_prompt += augmentation
else:
logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。")
for attempt in range(3):
logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---")
try:
generation_start = time.time()
# Construct generation request with JSON Schema
# 根据模式选择对应的 JSON Schema
selected_schema = self.simple_schema if mode == "simple" else self.schema
req = TextInferenceRequest(
user_prompt=final_user_prompt,
system_prompt=use_prompt,
temperature=0.1 if mode == "complex" else 0.0,
max_tokens=4096, # Large enough for complex trees
top_p=0.95,
json_schema=selected_schema # 传递 JSON Schema 约束输出格式
)
resp = self.models_client.text_inference(req)
if "error" in resp:
logging.error(f"生成请求失败: {resp['error']}")
continue
pytree_str = resp.get("result", "")
generation_time = time.time() - generation_start
logging.info(f"⏱️ LLM生成耗时: {generation_time:.2f}")
# Clean markdown code blocks if present
clean_pytree_str = re.sub(r'```json\s*', '', pytree_str).replace('```', '').strip()
# 单独捕获JSON解析错误并打印原始响应
try:
pytree_dict = json.loads(clean_pytree_str)
except json.JSONDecodeError as e:
logging.error(f"❌ JSON解析失败{attempt + 1}/3 次)。原始响应如下:\n{pytree_str}")
continue
# 简单/复杂分别验证与返回
if mode == "simple":
try:
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
logging.info("✅ 简单模式JSON Schema验证成功")
except jsonschema.ValidationError as e:
logging.warning(f"❌ 简单模式验证失败: {e.message}")
continue
# 附加元信息并生成简单可视化(单动作)
plan_id = str(uuid.uuid4())
pytree_dict['plan_id'] = plan_id
# 简单模式可视化直接使用root节点
try:
vis_start = time.time()
vis_filename = "py_tree.png"
vis_path = os.path.join(self.vis_dir, vis_filename)
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
vis_time = time.time() - vis_start
logging.info(f"⏱️ 可视化耗时: {vis_time:.2f}")
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
except Exception as e:
logging.warning(f"简单模式可视化失败: {e}")
total_time = time.time() - total_start_time
logging.info(f"🎉 ⏱️ 总耗时: {total_time:.2f}秒 (分类:{classify_time:.1f}s + 检索:{retrieval_time:.1f}s + 生成:{generation_time:.1f}s + 可视化:{vis_time if 'vis_time' in locals() else 0:.1f}s)")
return pytree_dict
# 复杂模式回退:若模型误返回简单结构,则自动包装为含安全监控的行为树
if mode == "complex" and isinstance(pytree_dict, dict) and 'root' not in pytree_dict:
try:
jsonschema.validate(instance=pytree_dict, schema=self.simple_schema)
logging.warning("⚠️ 复杂模式生成了简单结构,触发自动包装为完整行为树的回退逻辑。")
simple_action_obj = pytree_dict.get('action') or {}
action_name = simple_action_obj.get('name')
action_params = simple_action_obj.get('params') if isinstance(simple_action_obj.get('params'), dict) else {}
safety_selector = {
"type": "Selector",
"name": "SafetyMonitor",
"params": {"memory": True},
"children": [
{"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}},
{"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}},
{"type": "Sequence", "name": "EmergencyHandler", "children": [
{"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}},
{"type": "action", "name": "land", "params": {"mode": "home"}}
]}
]
}
main_children = [{"type": "action", "name": action_name, "params": action_params}]
if action_name != "land":
main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}})
root_parallel = {
"type": "Parallel",
"name": "MissionWithSafety",
"params": {"policy": "all_success"},
"children": [
{"type": "Sequence", "name": "MainTask", "children": main_children},
safety_selector
]
}
pytree_dict = {"root": root_parallel}
except jsonschema.ValidationError:
# 不符合简单结构,按正常复杂验证继续
pass
if _validate_pytree_with_schema(pytree_dict, self.schema):
logging.info("✅ 成功生成并验证了Pytree")
plan_id = str(uuid.uuid4())
pytree_dict['plan_id'] = plan_id
# Generate visualization to a static path
vis_start = time.time()
vis_filename = "py_tree.png"
vis_path = os.path.join(self.vis_dir, vis_filename)
_visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0])
vis_time = time.time() - vis_start
logging.info(f"⏱️ 可视化耗时: {vis_time:.2f}")
pytree_dict['visualization_url'] = f"/static/{vis_filename}"
total_time = time.time() - total_start_time
logging.info(f"🎉 ⏱️ 总耗时: {total_time:.2f}秒 (分类:{classify_time:.1f}s + 检索:{retrieval_time:.1f}s + 生成:{generation_time:.1f}s + 可视化:{vis_time:.1f}s)")
return pytree_dict
else:
# 打印未通过验证的Pytree以便排查
preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2)
logging.warning(f"❌ 未通过验证的Pytree{attempt + 1}/3 次尝试):\n{preview}")
logging.warning("生成的Pytree验证失败正在重试...")
except Exception as e:
logging.error(f"生成Pytree时发生错误: {e}")
raise RuntimeError("在3次尝试后仍未能生成一个有效的Pytree。")
# Create a single instance for the application
py_tree_generator = PyTreeGenerator()

View File

@@ -0,0 +1,75 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import requests
import json
# --- Configuration ---
# The base URL of your running FastAPI service
BASE_URL = "http://127.0.0.1:8001"
# The endpoint we want to test
ENDPOINT = "/generate_plan"
# The user prompt we will send for the test
TEST_PROMPT = "起飞,飞到匡亚明学院"
def test_generate_plan():
"""
Sends a request to the /generate_plan endpoint and validates the response.
"""
url = BASE_URL + ENDPOINT
payload = {"user_prompt": TEST_PROMPT}
headers = {"Content-Type": "application/json"}
print("--- API Test: Generate Plan ---")
print(f"✅ URL: {url}")
print(f"✅ Sending Prompt: \"{TEST_PROMPT}\"")
try:
# Send the POST request
response = requests.post(url, data=json.dumps(payload), headers=headers)
# Check for HTTP errors (e.g., 404, 500)
response.raise_for_status()
# Parse the JSON response
data = response.json()
print("✅ Received Response:")
print(json.dumps(data, indent=2, ensure_ascii=False))
# --- Validation ---
print("\n--- Validation Checks ---")
# 1. Check if the response is a dictionary
if isinstance(data, dict):
print("PASS: Response is a valid JSON object.")
else:
print("FAIL: Response is not a valid JSON object.")
return
# 2. Check for the existence of the 'root' key
if "root" in data and isinstance(data['root'], dict):
print("PASS: Response contains a valid 'root' key.")
else:
print("FAIL: Response does not contain a valid 'root' key.")
# 3. Check for the existence and format of the 'visualization_url' key
if "visualization_url" in data and data["visualization_url"].endswith(".png"):
print(f"PASS: Response contains a valid 'visualization_url': {data['visualization_url']}")
else:
print("FAIL: Response does not contain a valid 'visualization_url'.")
except requests.exceptions.RequestException as e:
print(f"\n❌ TEST FAILED: Could not connect to the server.")
print(" Please make sure the backend service is running.")
print(f" Error details: {e}")
except json.JSONDecodeError:
print(f"\n❌ TEST FAILED: The server response was not valid JSON.")
print(f" Response text: {response.text}")
except Exception as e:
print(f"\n❌ TEST FAILED: An unexpected error occurred: {e}")
if __name__ == "__main__":
test_generate_plan()

View File

@@ -0,0 +1,48 @@
import asyncio
from typing import List
from fastapi import WebSocket
import logging
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
self.loop: asyncio.AbstractEventLoop | None = None
def set_loop(self, loop: asyncio.AbstractEventLoop):
"""Sets the asyncio event loop."""
self.loop = loop
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
def broadcast(self, message: str):
"""
Thread-safely broadcasts a message to all active WebSocket connections.
This method is designed to be called from a different thread (e.g., a ROS2 callback).
"""
if not self.loop:
logging.error("Event loop not set in ConnectionManager. Cannot broadcast.")
return
# Schedule the coroutine to be executed in the event loop
self.loop.call_soon_threadsafe(self._broadcast_in_loop, message)
def _broadcast_in_loop(self, message: str):
"""
Helper to run the broadcast coroutine in the correct event loop.
"""
asyncio.ensure_future(self._broadcast_async(message), loop=self.loop)
async def _broadcast_async(self, message: str):
"""
The actual async method that sends messages.
"""
tasks = [connection.send_text(message) for connection in self.active_connections]
await asyncio.gather(*tasks, return_exceptions=True)
# Create a single instance of the manager to be used across the application
websocket_manager = ConnectionManager()

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# 后端服务启动脚本
# 用于启动 FastAPI 后端服务
# 获取脚本所在目录
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# 设置 Python 路径
export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH}"
# 启动后端服务
echo "正在启动后端服务..."
echo "工作目录: $SCRIPT_DIR"
echo "监听地址: 0.0.0.0:8001"
echo ""
python -m uvicorn src.main:app --host 0.0.0.0 --port 8001