diff --git a/scripts/ai_agent/groundcontrol b/scripts/ai_agent/groundcontrol deleted file mode 160000 index d026107..0000000 --- a/scripts/ai_agent/groundcontrol +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d026107bc214ea16a425bbb2cac100e8497aadf2 diff --git a/scripts/ai_agent/groundcontrol/.gitignore b/scripts/ai_agent/groundcontrol/.gitignore new file mode 100644 index 0000000..d82fa7a --- /dev/null +++ b/scripts/ai_agent/groundcontrol/.gitignore @@ -0,0 +1,143 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# static files generated from Django application using `collectstatic` +media +static diff --git a/scripts/ai_agent/groundcontrol/README.md b/scripts/ai_agent/groundcontrol/README.md new file mode 100644 index 0000000..9d2dca5 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/README.md @@ -0,0 +1,345 @@ +# 无人机自然语言控制项目 + +本项目构建了一个完整的无人机自然语言控制系统,集成了检索增强生成(RAG)知识库、大型语言模型(LLM)、FastAPI后端服务和ROS2通信,最终实现通过自然语言指令控制无人机执行复杂任务。 + +## 项目结构 + +项目被清晰地划分为几个核心模块: + +``` +. +├── backend_service/ +│ ├── src/ # FastAPI应用核心代码 +│ │ ├── __init__.py +│ │ ├── main.py # 应用主入口,提供Web API +│ │ ├── py_tree_generator.py # RAG与LLM集成,生成py_tree +│ │ ├── prompts/ # LLM 提示词 +│ │ │ ├── system_prompt.txt # 复杂模式提示词(行为树与安全监控) +│ │ │ ├── simple_mode_prompt.txt # 简单模式提示词(单一原子动作JSON) +│ │ │ └── classifier_prompt.txt # 指令简单/复杂分类提示词 +│ │ ├── ... +│ ├── generated_visualizations/ # 存放最新生成的py_tree可视化图像 +│ └── requirements.txt # 后端服务的Python依赖 +│ +├── tools/ +│ ├── map/ # 【数据源】存放原始地图文件(如.world, .json) +│ ├── knowledge_base/ # 【处理后】存放build_knowledge_base.py生成的.ndjson文件 +│ ├── vector_store/ # 【数据库】存放最终的ChromaDB向量数据库 +│ ├── build_knowledge_base.py # 【步骤1】用于将原始数据转换为自然语言知识 +│ └── ingest.py # 【步骤2】用于将自然语言知识摄入向量数据库 +│ +├── / # ROS2接口定义 (保持不变) +└── docs/ + └── README.md # 本说明文件 +``` + +## 核心配置:Orin IP 地址 + +**重要提示:** 本项目的后端服务和知识库工具需要与在NVIDIA Jetson Orin设备上运行的服务进行通信(嵌入模型和LLM推理服务),**默认的IP地址为localhost**,所以使用电脑本地部署的模型服务同样可以,但是需要注意指定模型的端口。 + +在使用前,您必须配置正确的Orin设备IP地址。您可以通过以下两种方式之一进行设置: + +1. **设置环境变量 (推荐)**: + 在您的终端中设置一个名为 `ORIN_IP` 的环境变量。 + ```bash + export ORIN_IP="192.168.1.100" # 请替换为您的Orin设备的实际IP地址 + ``` + 脚本会优先使用这个环境变量。 + +2. **直接修改脚本**: + 如果您不想设置环境变量,可以打开 `tools/ingest.py` 和 `backend_service/src/py_tree_generator.py` 文件,找到 `orin_ip = os.getenv("ORIN_IP", "...")` 这样的行,并将默认的IP地址修改为您的Orin设备的实际IP地址。 + +**在继续后续步骤之前,请务必完成此项配置。** + +## 模型端口启动 + +本项目启动依赖于后端的模型推理服务,即`ORIN_IP`所指向的设备的模型服务端口,目前项目使用instruct模型与embedding模型实现流程,分别部署在8081端口与8090端口。 + +1. **推理模型部署**: + + 在`/llama.cpp/build/bin`路径下执行以下命令启动模型 + ```bash + ./llama-server -m ~/models/gguf/Qwen/Qwen3-8B-GGUF/Qwen3-8B-Q4_K_M.gguf --port 8081 --gpu-layers 36 --host 0.0.0.0 -c 8192 + ``` + +2. **Embedding模型部署** + + 在`/llama.cpp/build/bin`路径下执行以下命令启动模型 + ```bash + ./llama-server -m ~/models/gguf/Qwen/Qwen3-embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf --gpu-layers 36 --port 8090 --embeddings --pooling last --host 0.0.0.0 + ``` + +--- + +## 指令分类与分流 + +后端在生成任务前会先对用户指令进行“简单/复杂”分类,并分流到不同提示词与模型: + +- 分类提示词:`backend_service/src/prompts/classifier_prompt.txt` +- 简单模式提示词:`backend_service/src/prompts/simple_mode_prompt.txt` +- 复杂模式提示词:`backend_service/src/prompts/system_prompt.txt` + +分类仅输出如下JSON之一:`{"mode":"simple"}` 或 `{"mode":"complex"}`。两种模式都会执行检索增强(RAG),将参考知识拼接到用户指令后再进行推理。 + +当为简单模式时,LLM仅输出: +`{"mode":"simple","action":{"name":"","params":{...}}}`。 +后端不会再自动封装为复杂行为树;将直接返回简单JSON,并附加 `plan_id` 与 `visualization_url`(单动作可视化)。 + +### 环境变量(可选) + +支持为“分类/简单/复杂”三类调用分别配置模型与Base URL(未设置时回退到默认本地配置): + +- `CLASSIFIER_MODEL`, `CLASSIFIER_BASE_URL` +- `SIMPLE_MODEL`, `SIMPLE_BASE_URL` +- `COMPLEX_MODEL`, `COMPLEX_BASE_URL` + +通用API Key:`OPENAI_API_KEY` + +示例: +```bash +export CLASSIFIER_MODEL="qwen2.5-1.8b-instruct" +export SIMPLE_MODEL="qwen2.5-1.8b-instruct" +export COMPLEX_MODEL="qwen2.5-7b-instruct" +export CLASSIFIER_BASE_URL="http://$ORIN_IP:8081/v1" +export SIMPLE_BASE_URL="http://$ORIN_IP:8081/v1" +export COMPLEX_BASE_URL="http://$ORIN_IP:8081/v1" +export OPENAI_API_KEY="sk-no-key-required" +``` + +### 测试简单模式 + +启动服务后,运行内置测试脚本: + +```bash +cd tools +python test_api.py +``` + +示例输入:“简单模式,起飞” 或 “起飞到10米”。返回结果为简单JSON(无 `root`):包含 `mode`、`action`、`plan_id`、`visualization_url`。 + +--- + +## 工作流程 + +整个系统的工作流程分为两个主要阶段: + +1. **知识库构建(一次性设置)**: 将环境信息、无人机能力等知识加工并存入向量数据库。 +2. **后端服务运行与交互**: 启动主服务,通过API接收指令、生成并执行任务。 + +### 阶段一:环境设置与编译 + +此阶段为项目准备好运行环境,仅需在初次配置或依赖变更时执行。一个稳定、隔离且兼容的环境是所有后续步骤成功的基础。 + +#### 1. 创建Conda环境 (关键步骤) + +为了从根源上避免本地Python环境与系统ROS 2环境的库版本冲突(特别是Python版本和C++标准库),我们**必须**使用Conda创建一个干净、隔离且版本精确的虚拟环境。 + +```bash +# 1. 创建一个使用Python 3.10的新环境。 +# --name backend: 指定环境名称。 +# python=3.10: 指定Python版本,必须与ROS 2 Humble要求的版本一致。 +# --channel conda-forge: 使用conda-forge社区源,其包通常有更好的兼容性。 +# --no-default-packages: 关键!不安装Conda默认的包(如libgcc),避免与系统ROS 2的C++库冲突。 +conda create --name backend --channel conda-forge --no-default-packages python=3.10 + +# 2. 激活新创建的环境 +conda activate backend +``` + +#### 2. 安装所有Python依赖 + +在激活`backend`环境后,使用`pip`一次性安装所有依赖。`requirements.txt`已包含**运行时**(如fastapi, rclpy)和**编译时**(如empy, catkin-pkg, lark)所需的所有库。 + +```bash +# 确保在项目根目录 (drone/) 下执行 +pip install -r backend_service/requirements.txt +``` + +#### 3. 编译ROS 2接口 + +为了让后端服务能够像导入普通Python包一样导入我们自定义的Action接口 (`drone_interfaces`),你需要先使用`colcon`对其进行编译。 + +```bash +# 确保在项目根目录 (drone/) 下执行 +colcon build +``` +成功后,您会看到`build/`, `install/`, `log/`三个新目录。这一步会将`.action`文件转换为Python和C++代码。 + +--- + +### 阶段二:数据处理流水线 + +此阶段为RAG系统准备数据,让LLM能够理解任务环境。 + +#### 1. 准备原始数据 + +将你的原始数据文件(例如,`.world`, `.json` 文件等)放入 `tools/map/` 目录中。 + +#### 2. 数据预处理 + +运行脚本将原始数据“翻译”成自然语言知识。 + +```bash +# 确保在项目根目录 (drone/) 下,并已激活backend环境 +cd tools +python build_knowledge_base.py +``` +该脚本会扫描 `tools/map/` 目录,并在 `tools/knowledge_base/` 目录下生成对应的 `_knowledge.ndjson` 文件。 + +#### 3. 数据入库(Ingestion) + +运行脚本将处理好的知识加载到向量数据库中。 + +```bash +# 仍在tools/目录下执行 +python ingest.py +``` +该脚本会自动扫描 `tools/knowledge_base/` 目录,并将数据存入 `tools/vector_store/` 目录中。 + +--- + +### 阶段三:服务启动与测试 + +完成前两个阶段后,即可启动并测试后端服务。 + +#### 1. 启动后端服务 + +启动服务的关键在于**按顺序激活环境**:先激活ROS 2工作空间,再激活Conda环境。 + +```bash +# 1. 切换到项目根目录 +cd /path/to/your/drone + +# 2. 激活ROS 2编译环境 +# 作用:将我们编译好的`drone_interfaces`包的路径告知系统,否则Python会报`ModuleNotFoundError`。 +# 注意:此命令必须在每次打开新终端时执行一次。 +source install/setup.bash + +# 3. 激活Conda Python环境 +conda activate backend + +# 4. 启动FastAPI服务 +cd backend_service/ +uvicorn src.main:app --host 0.0.0.0 --port 8000 +``` +当您看到日志中出现 `Uvicorn running on http://0.0.0.0:8000` 时,表示服务已成功启动。 + +#### 2. 运行API接口测试 + +我们提供了一个脚本来验证核心的“任务生成”功能。 + +**打开一个新的终端**,并执行以下命令: + +```bash +# 1. 切换到项目根目录 +cd /path/to/your/drone + +# 2. 激活Conda环境 +conda activate backend + +# 3. 运行测试脚本 +cd tools/ +python test_api.py +``` +如果一切正常,您将在终端看到一系列 `PASS` 信息,以及从服务器返回的Pytree JSON。 + +#### 3. API接口使用说明 + +--- + +## 故障排除 / 常见问题 (FAQ) + +以下是在配置和运行此项目时可能遇到的一些常见问题及其解决方案。 + +#### **Q1: 启动服务时报错 `ModuleNotFoundError: No module named 'drone_interfaces'`** + +- **原因**: 您当前的终端环境没有加载ROS 2工作空间的路径。仅仅激活Conda环境是不够的。 +- **解决方案**: 严格遵循“启动后端服务”章节的说明,在激活Conda环境**之前**,必须先运行 `source install/setup.bash` 命令。 + +#### **Q2: `colcon build` 编译失败,提示 `ModuleNotFoundError: No module named 'em'`, `'catkin_pkg'`, 或 `'lark'`** + +- **原因**: 您的Python环境中缺少ROS 2编译代码时所必需的依赖包。 +- **解决方案**: 我们已将所有已知的编译时依赖(`empy`, `catkin-pkg`, `lark`等)添加到了`requirements.txt`中。请确保您已激活正确的Conda环境,然后运行 `pip install -r backend_service/requirements.txt` 来安装它们。 + +#### **Q3: 启动服务时报错 `ImportError: ... GLIBCXX_... not found` 或 `ModuleNotFoundError: No module named 'rclpy._rclpy_pybind11'`** + +- **原因**: 您的Conda环境与系统ROS 2环境存在核心库冲突。最常见的原因是Python版本不匹配(例如,Conda是Python 3.11而ROS 2 Humble需要3.10),或者Conda自带的C++库与系统库冲突。 +- **解决方案**: 这是最棘手的环境问题。最可靠的解决方法是彻底删除当前的Conda环境 (`conda env remove --name backend`),然后严格按照本文档「环境设置」章节的说明,用正确的命令 (`conda create --name backend --channel conda-forge --no-default-packages python=3.10`) 重建一个干净、兼容的环境。 + +#### **Q4: 服务启动时,日志显示正在从网络上下载模型(例如 `all-MiniLM-L6-v2`)** + +- **原因**: 后端服务在连接向量数据库时,没有正确指定使用远程嵌入模型,导致ChromaDB退回到默认的、需要下载模型的本地嵌入函数。 +- **解决方案**: 此问题在当前代码中**已被修复**。`backend_service/src/py_tree_generator.py`现在会正确地将远程嵌入函数实例传递给ChromaDB。如果您在自己的代码中遇到此问题,请检查您的`get_collection`调用。 + +#### **Q5: 服务启动时,日志停在 `waiting for action server...`,无法访问API** + +- **原因**: 代码中存在阻塞式的`wait_for_server()`调用,它会一直等待直到无人机端的Action服务器上线,从而卡住了Web服务的启动流程。 +- **解决方案**: 此问题在当前代码中**已被修复**。`backend_service/src/ros2_client.py`现在使用非阻塞的方式初始化,并在发送任务时检查服务器是否可用。 + +##### **A. 生成任务计划** + +接收自然语言指令,返回生成的行为树(py_tree)JSON。 + +- **Endpoint**: `POST /generate_plan` +- **Request Body**: + ```json + { + "user_prompt": "无人机起飞到10米,然后前往机库,最后降落。" + } + ``` +- **Success Response(复杂模式)**: + ```json + { + "root": { ... }, + "plan_id": "some-unique-id", + "visualization_url": "/static/py_tree.png" + } + ``` +- **Success Response(简单模式)**: + ```json + { + "mode": "simple", + "action": { "name": "takeoff", "params": { "altitude": 10.0 } }, + "plan_id": "some-unique-id", + "visualization_url": "/static/py_tree.png" + } + ``` + +##### **B. 查看任务可视化** + +获取最新生成的行为树的可视化图像。 + +- **Endpoint**: `GET /static/py_tree.png` +- **Usage**: 在浏览器中直接打开 `http://<服务器IP>:8000/static/py_tree.png` 即可查看。每次成功调用 `/generate_plan` 后,该图像都会被更新。 + +##### **C. 执行任务** + +接收一个py_tree JSON,下发给无人机执行(当前为模拟执行)。 + +- **Endpoint**: `POST /execute_mission` +- **Request Body**: (使用 `/generate_plan` 返回的 `root` 对象) + ```json + { + "py_tree": { + "root": { ... } + } + } + ``` +- **Response**: + ```json + { + "status": "execution_started" + } + ``` + +##### **D. 接收实时状态** + +通过WebSocket连接,实时接收无人机在执行任务时的状态反馈。 + +- **Endpoint**: `WS /ws/status` +- **Usage**: 使用任意WebSocket客户端连接到 `ws://<服务器IP>:8000/ws/status`。当任务执行时,服务器会主动推送JSON消息,例如: + ```json + {"node_id": "takeoff_node_1", "status": 0} // 0: RUNNING + {"node_id": "takeoff_node_1", "status": 1} // 1: SUCCESS + ``` + \ No newline at end of file diff --git a/scripts/ai_agent/groundcontrol/generated_visualizations/py_tree.png b/scripts/ai_agent/groundcontrol/generated_visualizations/py_tree.png new file mode 100644 index 0000000..e23bc82 Binary files /dev/null and b/scripts/ai_agent/groundcontrol/generated_visualizations/py_tree.png differ diff --git a/scripts/ai_agent/groundcontrol/requirements.txt b/scripts/ai_agent/groundcontrol/requirements.txt new file mode 100644 index 0000000..7507870 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/requirements.txt @@ -0,0 +1,39 @@ +# Web Framework and Server +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 +websockets>=12.0 + +# Data Validation and Serialization +pydantic>=2.5.0 +jsonschema>=4.20.0 + +# AI and Vector Database +openai>=1.3.0 +chromadb>=0.4.0 + +# Visualization +graphviz>=0.20.0 + +# ROS 2 Python Client +rclpy>=0.0.1 + +# Document Processing +unstructured[all]>=0.11.0 + +# HTTP Requests +requests>=2.31.0 + +# Progress Bars and UI +rich>=13.7.0 + +# Type Hints Support +typing-extensions>=4.8.0 + +# ROS 2 Build Dependencies +empy==3.3.4 +catkin-pkg>=0.4.0 +lark>=1.1.0 +colcon-common-extensions>=0.3.0 +vcstool>=0.2.0 +rosdep>=0.22.0 diff --git a/scripts/ai_agent/groundcontrol/src/__init__.py b/scripts/ai_agent/groundcontrol/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/ai_agent/groundcontrol/src/main.py b/scripts/ai_agent/groundcontrol/src/main.py new file mode 100644 index 0000000..b4c2bcf --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/main.py @@ -0,0 +1,225 @@ +import asyncio +import os +import sys +import subprocess +import time +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.staticfiles import StaticFiles +import logging +import threading + +# --- Sys Path Injection --- +# Add ai_agent root to sys.path +current_dir = os.path.dirname(os.path.abspath(__file__)) +# Go up 2 levels: src -> groundcontrol -> ai_agent +ai_agent_root = os.path.abspath(os.path.join(current_dir, "../..")) +# Also add models directory specifically for implicit imports if needed +models_dir = os.path.join(ai_agent_root, "models") + +if ai_agent_root not in sys.path: + sys.path.append(ai_agent_root) +if models_dir not in sys.path: + sys.path.append(models_dir) + +# Now we can import from ai_agent modules +from models.models_client import Models_Client +from tools.core.common_functions import read_json_file + +from .models import GeneratePlanRequest, ExecuteMissionRequest +from .websocket_manager import websocket_manager +from .py_tree_generator import py_tree_generator + +# --- Global Variables --- +model_server_process = None +MODEL_SERVER_URL = os.getenv("MODEL_SERVER_URL", "http://localhost:8000") + +# --- Application Setup --- +app = FastAPI( + title="Drone Backend Service", + description="Handles mission planning, generation, and execution for the drone.", + version="1.0.0", +) + +# --- Mount Static Files for Visualizations --- +static_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'generated_visualizations')) +if not os.path.exists(static_dir): + os.makedirs(static_dir) +app.mount("/static", StaticFiles(directory=static_dir), name="static") + +def start_model_server(): + """Starts the model server if it's not already running.""" + global model_server_process + client = Models_Client(MODEL_SERVER_URL) + + # Check if already running + try: + logging.info(f"Checking model server health at {MODEL_SERVER_URL}...") + health = client.check_health() + if health.get("status") in ["healthy", "loading"]: + logging.info("Model server is already running.") + return + except Exception: + pass # Not running + + logging.info("Starting model server...") + + # 读取配置文件 + try: + config_json_file = os.path.join(ai_agent_root, "config", "model_config.json") + config_data = read_json_file(config_json_file) + if config_data is None or "models_server" not in config_data: + logging.error(f"无法读取配置文件或配置文件中缺少 'models_server' 字段: {config_json_file}") + return + + # 如果 models_server 是列表,取第一个;如果是字典,直接使用 + models_server_data = config_data["models_server"] + model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data + + host = model_server_config.get("model_server_host", "localhost") + port = model_server_config.get("model_server_port", 8000) + workers = model_server_config.get("model_controller_workers", 1) + except Exception as e: + logging.error(f"读取配置文件失败: {e},使用默认配置") + host = "localhost" + port = 8000 + workers = 1 + + # Start as subprocess using uvicorn + try: + def log_output(pipe, prefix): + """实时输出子进程的日志""" + try: + for line in iter(pipe.readline, b''): + if line: + line_str = line.decode('utf-8', errors='replace').strip() + if line_str: + logging.info(f"[Model Server {prefix}] {line_str}") + except Exception as e: + logging.error(f"Error reading {prefix}: {e}") + finally: + pipe.close() + + # 使用 uvicorn 启动模型服务器 + # 使用 models_server:create_app 作为应用入口 + uvicorn_cmd = [ + sys.executable, "-m", "uvicorn", + "models.models_server:create_app", + "--host", host, + "--port", str(port), + "--workers", str(workers), + "--factory" + ] + + model_server_process = subprocess.Popen( + uvicorn_cmd, + cwd=ai_agent_root, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + # 启动线程实时输出日志 + stdout_thread = threading.Thread( + target=log_output, + args=(model_server_process.stdout, "STDOUT"), + daemon=True + ) + stderr_thread = threading.Thread( + target=log_output, + args=(model_server_process.stderr, "STDERR"), + daemon=True + ) + stdout_thread.start() + stderr_thread.start() + + # 先等待一小段时间,检查进程是否立即退出 + time.sleep(2) + if model_server_process.poll() is not None: + # 进程已退出,读取错误信息 + out, err = model_server_process.communicate() + error_msg = err.decode('utf-8', errors='replace') if err else out.decode('utf-8', errors='replace') if out else "Unknown error" + logging.error(f"Model server process exited immediately: {error_msg}") + return + + # Wait for it to be ready + logging.info("Waiting for model server to be ready...") + if client.wait_for_server(timeout=300, check_interval=5): + logging.info("Model server started and ready.") + else: + logging.error("Model server failed to start within timeout.") + # Check for errors + if model_server_process.poll() is not None: + out, err = model_server_process.communicate() + error_msg = err.decode('utf-8', errors='replace') if err else out.decode('utf-8', errors='replace') if out else 'Unknown error' + logging.error(f"Model server process exited with error: {error_msg}") + except Exception as e: + logging.error(f"Failed to launch model server: {e}") + +def stop_model_server(): + """Stops the model server subprocess if it was started by this service.""" + global model_server_process + if model_server_process: + logging.info("Stopping model server subprocess...") + model_server_process.terminate() + try: + model_server_process.wait(timeout=10) + except subprocess.TimeoutExpired: + model_server_process.kill() + logging.info("Model server stopped.") + +# --- API Endpoints --- + +@app.post("/generate_plan", response_model=dict) +async def generate_plan_endpoint(request: GeneratePlanRequest): + """ + Receives a user prompt and returns a generated `py_tree.json` with a visualization URL. + """ + try: + pytree_dict = await py_tree_generator.generate(request.user_prompt) + return pytree_dict + except RuntimeError as e: + return {"error": str(e)} + +@app.post("/execute_mission", response_model=dict) +async def execute_mission_endpoint(request: ExecuteMissionRequest): + """ + Receives a `py_tree.json` and sends it to the drone for execution. + """ + # ROS2 execution removed + return {"status": "execution_started (simulation mode)"} + +@app.websocket("/ws/status") +async def websocket_endpoint(websocket: WebSocket): + """ + Handles the WebSocket connection for real-time status updates. + """ + await websocket_manager.connect(websocket) + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + websocket_manager.disconnect(websocket) + logging.info("Client disconnected from WebSocket.") + + +# --- Server Lifecycle --- + +@app.on_event("startup") +async def startup_event(): + """ + On startup, get the current asyncio event loop and pass it to the websocket manager. + Start the ROS2 node in a background thread. + Ensure the Model Server is running. + """ + # Configure WebSocket Manager + loop = asyncio.get_running_loop() + websocket_manager.set_loop(loop) + logging.info("WebSocket event loop configured.") + + # Start Model Server + threading.Thread(target=start_model_server, daemon=True).start() + +@app.on_event("shutdown") +async def shutdown_event(): + logging.info("Backend service shutting down.") + stop_model_server() + logging.info("Backend service shut down successfully.") diff --git a/scripts/ai_agent/groundcontrol/src/models.py b/scripts/ai_agent/groundcontrol/src/models.py new file mode 100644 index 0000000..33416a7 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/models.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel +from typing import Dict, Any + +class GeneratePlanRequest(BaseModel): + user_prompt: str + +class ExecuteMissionRequest(BaseModel): + py_tree: Dict[str, Any] + +class StatusUpdate(BaseModel): + node_id: str + status: int diff --git a/scripts/ai_agent/groundcontrol/src/prompts/classifier_prompt.txt b/scripts/ai_agent/groundcontrol/src/prompts/classifier_prompt.txt new file mode 100644 index 0000000..98d551d --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/prompts/classifier_prompt.txt @@ -0,0 +1,24 @@ +你是一个严格的任务分类器。只输出一个JSON对象,不要输出解释或多余文本。 +根据用户指令与下述可用节点定义,判断其为“简单”或“复杂”。 + +- 简单:单一原子动作即可完成(例如“起飞”“飞机自检”“移动到某地(已给定坐标)”“对着某点环绕XY圈(如‘对着学生宿舍环绕三十两圈’)”等),且无需行为树与安全并行监控。 +- 复杂:需要多步流程、搜索/检测/跟踪/评估、战损确认、或需要模板化任务结构与安全并行监控。 + +输出格式(严格遵守): +{"mode":"simple"} 或 {"mode":"complex"} + +—— 可用节点定义—— +```json +{ + "actions": [ + {"name": "takeoff"}, {"name": "land"}, {"name": "fly_to_waypoint"}, {"name": "move_direction"}, {"name": "orbit_around_point"}, {"name": "orbit_around_target"}, {"name": "loiter"}, + {"name": "object_detect"}, {"name": "strike_target"}, {"name": "battle_damage_assessment"}, + {"name": "search_pattern"}, {"name": "track_object"}, {"name": "deliver_payload"}, + {"name": "preflight_checks"}, {"name": "emergency_return"} + ], + "conditions": [ + {"name": "battery_above"}, {"name": "at_waypoint"}, {"name": "object_detected"}, + {"name": "target_destroyed"}, {"name": "time_elapsed"}, {"name": "gps_status"} + ] +} +``` diff --git a/scripts/ai_agent/groundcontrol/src/prompts/simple_mode_prompt.txt b/scripts/ai_agent/groundcontrol/src/prompts/simple_mode_prompt.txt new file mode 100644 index 0000000..1de6f93 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/prompts/simple_mode_prompt.txt @@ -0,0 +1,63 @@ +你是一个无人机简单指令执行规划器。你的任务:当用户给出“简单指令”(单一原子动作即可完成)时,输出一个严格的JSON对象。 + +输出要求(必须遵守): +- 只输出一个JSON对象,不要任何解释或多余文本。 +- JSON结构: +{"root":{"type":"action","name":"","params":{...}}} +- 与参数定义、取值范围,必须与“复杂模式”提示词(system_prompt.txt)中的定义完全一致。 +- root节点必须是action类型节点,不能是控制流节点。 +- 当用户指令或检索到的参考知识中已经给出了具体坐标(如 (32.118904, 118.955208)),请直接按原值使用该坐标,不进行任何缩放、单位换算或数值变换,只需根据需要映射到对应字段 + +示例: +- “起飞到10米” → {"root":{"type":"action","name":"takeoff","params":{"altitude":10.0}}} +- “移动到(120,80,20)” → {"root":{"type":"action","name":"fly_to_waypoint","params":{"x":120.0,"y":80.0,"z":20.0,"acceptance_radius":2.0}}} +- “飞机自检” → {"root":{"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}}} + +—— 可用节点定义—— +```json +{ + "actions": [ + {"name": "takeoff", "description": "无人机从当前位置垂直起飞到指定的海拔高度。", "params": {"altitude": "float, 目标海拔高度(米),范围[1, 100],默认为2"}}, + {"name": "land", "description": "降落无人机。可选择当前位置或返航点降落。", "params": {"mode": "string, 可选值: 'current'(当前位置), 'home'(返航点)"}}, + {"name": "fly_to_waypoint", "description": "导航至一个指定坐标点。使用相对坐标系(x,y,z),单位为米。", "params": {"x": "float", "y": "float", "z": "float", "acceptance_radius": "float, 可选,默认2.0"}}, + {"name": "move_direction", "description": "按指定方向直线移动。方向可为绝对方位或相对机体朝向。", "params": {"direction": "string: north|south|east|west|forward|backward|left|right", "distance": "float[1,10000], 可选, 不指定则持续移动"}}, + {"name": "orbit_around_point", "description": "以给定中心点为中心,等速圆周飞行指定圈数。", "params": {"center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}}, + {"name": "orbit_around_target", "description": "以目标为中心,等速圆周飞行指定圈数(需已有目标)。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "radius": "float[5,1000]", "laps": "int[1,20]", "clockwise": "boolean, 可选, 默认true", "speed_mps": "float[0.5,15], 可选", "gimbal_lock": "boolean, 可选, 默认true"}}, + {"name": "loiter", "description": "在当前位置上空悬停一段时间或直到条件触发。", "params": {"duration": "float, 可选[1,600]", "until_condition": "string, 可选"}}, + {"name": "object_detect", "description": "识别特定目标对象。一般是用户提到的需要检测的目标;如果用户给出了需要探索的目标的优先级,比如蓝色球危险性大于红色球大于绿色球,需要检测最危险的球,此处应给出检测优先级,描述应当为 '蓝>红>绿'", "params": {"target_class": "string, 要识别的目标类别,必须为以下值之一: balloon,person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic_light, fire_hydrant, stop_sign, parking_meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports_ball, kite, baseball_bat, baseball_glove, skateboard, surfboard, tennis_racket, bottle, wine_glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot_dog, pizza, donut, cake, chair, couch, potted_plant, bed, dining_table, toilet, tv, laptop, mouse, remote, keyboard, cell_phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy_bear, hair_drier, toothbrush", "description": "string, 可选", "count": "int, 可选, 默认1"}}, + {"name": "strike_target", "description": "对已识别目标进行打击。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}}, + {"name": "battle_damage_assessment", "description": "战损评估。", "params": {"target_class": "string", "assessment_time": "float[5-60], 默认15.0"}}, + {"name": "search_pattern", "description": "按模式搜索。", "params": {"pattern_type": "string: spiral|grid", "center_x": "float", "center_y": "float", "center_z": "float", "radius": "float[5,1000]", "target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}}, + {"name": "track_object", "description": "持续跟踪目标。", "params": {"target_class": "string, 取值同object_detect列表", "description": "string, 可选", "track_time": "float[1,600], 默认30.0", "min_confidence": "float[0.5-1.0], 默认0.7", "safe_distance": "float[2-50], 默认10.0"}}, + {"name": "deliver_payload", "description": "投放物资。", "params": {"payload_type": "string", "release_altitude": "float[2,100], 默认5.0"}}, + {"name": "preflight_checks", "description": "飞行前系统自检。", "params": {"check_level": "string: basic|comprehensive"}}, + {"name": "emergency_return", "description": "执行紧急返航程序。", "params": {"reason": "string"}} + ], + "conditions": [ + {"name": "battery_above", "description": "电池电量高于阈值。", "params": {"threshold": "float[0.0,1.0]"}}, + {"name": "at_waypoint", "description": "在指定坐标容差范围内。", "params": {"x": "float", "y": "float", "z": "float", "tolerance": "float, 可选, 默认3.0"}}, + {"name": "object_detected", "description": "检测到特定目标。", "params": {"target_class": "string", "description": "string, 可选", "count": "int, 可选, 默认1"}}, + {"name": "target_destroyed", "description": "目标已被摧毁。", "params": {"target_class": "string", "description": "string, 可选", "confidence": "float[0.5-1.0], 默认0.8"}}, + {"name": "time_elapsed", "description": "时间经过。", "params": {"duration": "float[1,2700]"}}, + {"name": "gps_status", "description": "GPS状态良好。", "params": {"min_satellites": "int[6,15], 默认10"}} + ] +} +``` + +—— 参数约束—— +- takeoff.altitude: [1, 100] +- fly_to_waypoint.z: [1, 5000] +- fly_to_waypoint.x,y: [-10000, 10000] +- search_pattern.radius: [5, 1000] +- move_direction.distance: [1, 10000] +- orbit_around_point.radius: [5, 1000] +- orbit_around_target.radius: [5, 1000] +- orbit_around_point/target.laps: [1, 20] +- orbit_around_point/target.speed_mps: [0.5, 15] +- 若参考知识提供坐标,必须使用并裁剪到约束范围内 + +—— 口令转化规则(环绕类)—— +- “环绕X米Y圈” → 若有目标上下文则使用 `orbit_around_target`,否则根据是否给出中心坐标选择 `orbit_around_point`;`radius=X`,`laps=Y`,默认 `clockwise=true`,`gimbal_lock=true` +- “顺时针/逆时针” → `clockwise=true/false` +- “等速” → 若未给速度则 `speed_mps` 采用默认值(例如3.0);若口令指明速度,裁剪到[0.5,15] +- “以(x,y,z)为中心”/“当前位置为中心” → 选择 `orbit_around_point` 并填充 `center_x/center_y/center_z` \ No newline at end of file diff --git a/scripts/ai_agent/groundcontrol/src/prompts/system_prompt.txt b/scripts/ai_agent/groundcontrol/src/prompts/system_prompt.txt new file mode 100644 index 0000000..10850d6 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/prompts/system_prompt.txt @@ -0,0 +1,117 @@ +任务:根据用户任意任务指令,生成结构化可执行的无人机行为树(Pytree)JSON。**仅输出单一JSON对象,无任何自然语言、注释或额外内容**。 + +## 一、核心节点定义(格式不可修改,确保后端解析) +#### 1. 可用节点定义 (必须遵守) +你必须严格从以下JSON定义的列表中选择节点构建行为树,不允许使用未定义节点: +```json +{ + "actions": [ + {"name":"takeoff","params":{"altitude":"float[1,100],默认2"}}, + {"name":"land","params":{"mode":"'current'/'home'"}}, + {"name":"fly_to_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","acceptance_radius":"默认2.0"}}, + {"name":"move_direction","params":{"direction":"north/south/east/west/forward/backward/left/right","distance":"[1,10000],缺省持续移动"}}, + {"name":"orbit_around_point","params":{"center_x":"±10000","center_y":"±10000","center_z":"[1,5000]","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}}, + {"name":"orbit_around_target","params":{"target_class":"见object_detect列表","description":"可选,目标属性","radius":"[5,1000]","laps":"[1,20]","clockwise":"默认true","speed_mps":"[0.5,15]","gimbal_lock":"默认true"}}, + {"name":"loiter","params":{"duration":"[1,600]秒/until_condition:可选"}}, + {"name":"object_detect","params":{"target_class":"person,bicycle,car,motorcycle,airplane,bus,train,truck,boat,traffic_light,fire_hydrant,stop_sign,parking_meter,bench,bird,cat,dog,horse,sheep,cow,elephant,bear,zebra,giraffe,backpack,umbrella,handbag,tie,suitcase,frisbee,skis,snowboard,sports_ball,kite,baseball_bat,baseball_glove,skateboard,surfboard,tennis_racket,bottle,wine_glass,cup,fork,knife,spoon,bowl,banana,apple,sandwich,orange,broccoli,carrot,hot_dog,pizza,donut,cake,chair,couch,potted_plant,bed,dining_table,toilet,tv,laptop,mouse,remote,keyboard,cell_phone,microwave,oven,toaster,sink,refrigerator,book,clock,vase,scissors,teddy_bear,hair_drier,toothbrush","description":"可选,","count":"默认1"}}, + {"name":"strike_target","params":{"target_class":"同object_detect","description":"可选,目标属性","count":"默认1"}}, + {"name":"battle_damage_assessment","params":{"target_class":"同object_detect","assessment_time":"[5,60],默认15"}}, + {"name":"search_pattern","params":{"pattern_type":"spiral/grid","center_x":"±10000","center_y":"±10000","center_z":"[1,5000]","radius":"[5,1000]","target_class":"同object_detect","description":"可选,目标属性","count":"默认1"}}, + {"name":"track_object","params":{"target_class":"同object_detect","description":"可选,目标属性","track_time":"[1,600]秒(必传,不可用'duration')","min_confidence":"[0.5,1.0]默认0.7","safe_distance":"[2,50]默认10"}}, + {"name":"deliver_payload","params":{"payload_type":"string","release_altitude":"[2,100]默认5"}}, + {"name":"preflight_checks","params":{"check_level":"basic/comprehensive"}}, + {"name":"emergency_return","params":{"reason":"string"}} + ], + "conditions": [ + {"name":"battery_above","params":{"threshold":"[0.0,1.0],必传"}}, + {"name":"at_waypoint","params":{"x":"±10000","y":"±10000","z":"[1,5000]","tolerance":"默认3.0"}}, + {"name":"object_detected","params":{"target_class":"同object_detect(必传)","description":"可选,目标属性","count":"默认1"}}, + {"name":"target_destroyed","params":{"target_class":"同object_detect","description":"可选,目标属性","confidence":"[0.5,1.0]默认0.8"}}, + {"name":"time_elapsed","params":{"duration":"[1,2700]秒"}}, + {"name":"gps_status","params":{"min_satellites":"int[6,15],必传(如8)"}} + ], + "control_flow": [ + {"name":"Sequence","params":{},"children":"子节点数组(按序执行,全成功则成功)"}, + {"name":"Selector","params":{"memory":"默认true"},"children":"子节点数组(执行到成功为止)"}, + {"name":"Parallel","params":{"policy":"all_success"},"children":"子节点数组(同时执行,严禁用'one_success')"} + ] +} +``` + + +## 二、节点必填字段(后端Schema强制要求,缺一验证失败) +每个节点必须包含以下字段,字段名/类型不可自定义: +1. **`type`**: + - 动作节点→`"action"`,条件节点→`"condition"`,控制流节点→`"Sequence"`/`"Selector"`/`"Parallel"`(与`name`字段值完全一致); +2. **`name`**:必须是上述JSON中`actions`/`conditions`/`control_flow`下的`name`值(如“gps_status”不可错写为“gps_check”); +3. **`params`**:严格匹配上述节点的`params`定义,无自定义参数(如优先级排序不可加“priority”字段,仅用`description`); +4. **`children`**:仅控制流节点必含(子节点数组),动作/条件节点无此字段。 + + +## 三、行为树固定结构(通用不变,确保安全验证) +根节点必须是`Parallel`,`children`含`MainTask`(Sequence)和`SafetyMonitor`(Selector),结构不随任务类型(含优先级排序)修改: +```json +{ + "root": { + "type": "Parallel", + "name": "MissionWithSafety", + "params": {"policy": "all_success"}, + "children": [ + { + "type": "Sequence", + "name": "MainTask", + "params": {}, + "children": [ + // 通用主任务步骤(含优先级排序任务示例,需按用户指令替换): + {"type":"action","name":"preflight_checks","params":{"check_level":"comprehensive"}}, + {"type":"action","name":"takeoff","params":{"altitude":10.0}}, + {"type":"action","name":"fly_to_waypoint","params":{"x":200.0,"y":150.0,"z":10.0}}, // 搜索区坐标(用户未给时填合理值) + {"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200.0,"center_y":150.0,"center_z":10.0,"radius":50.0,"target_class":"balloon","description":"红色"}}, + {"type":"condition","name":"object_detected","params":{"target_class":"balloon","description":"红色"}}, // 确认高优先级目标 + {"type":"action","name":"track_object","params":{"target_class":"balloon","description":"红色","track_time":30.0}}, + {"type":"action","name":"strike_target","params":{"target_class":"balloon","description":"红色"}}, + {"type":"action","name":"land","params":{"mode":"home"}} + ] + }, + { + "type": "Selector", + "name": "SafetyMonitor", + "params": {"memory": true}, + "children": [ + {"type":"condition","name":"battery_above","params":{"threshold":0.3}}, + {"type":"condition","name":"gps_status","params":{"min_satellites":8}}, + { + "type":"Sequence", + "name":"EmergencyHandler", + "params": {}, + "children": [ + {"type":"action","name":"emergency_return","params":{"reason":"safety_breach"}}, + {"type":"action","name":"land","params":{"mode":"home"}} + ] + } + ] + } + ] + } +} +``` + + +## 四、优先级排序任务通用示例 +当用户指令中明确提出有多个待考察且具有优先级关系的物体时,节点描述须为优先级关系。比如当指令为已知有三个气球,危险级关系为红色气球大于蓝色气球大于绿色气球,要求优先跟踪最危险的气球时,节点的描述参考下表情形。 +| 用户指令场景 | `target_class` | `description` | 核心节点示例(search_pattern) | +|-----------------------------|-----------------|-------------------------|------------------------------------------------------------------------------------------------| +| 红气球>蓝气球>绿气球 | `balloon` | `(红>蓝>绿)` | `{"type":"action","name":"search_pattern","params":{"pattern_type":"grid","center_x":200,"center_y":150,"center_z":10,"radius":50,"target_class":"balloon","description":"(红>蓝>绿)"}}` | +| 军用卡车>民用卡车>面包车 | `truck` | `(军用卡车>民用卡车>面包车)` | `{"type":"action","name":"object_detect","params":{"target_class":"truck","description":"(军用卡车>民用卡车>面包车)"}}` | + + +## 五、高频错误规避(确保验证通过) +1. 优先级排序不可修改`target_class`:如“民用卡车、面包车与军用卡车中,军用卡车优先”,`target_class`仍为`truck`,仅用`description`填排序规则; +2. 在没有明确指出物体之间的优先级关系情况下,`description`字段只描述物体属性本身,严禁与用户指令中不存在的物体进行排序; +3. `track_object`必传`track_time`:不可用`duration`替代(如跟踪30秒填`"track_time":30.0`); +4. `gps_status`的`min_satellites`必须在6-15之间(如8,不可缺省); +5. 无自定义节点:“锁定高优先级目标”需通过`object_detect`+`object_detected`实现,不可用“lock_high_risk_target”。 + + +## 六、输出要求 +仅输出1个严格符合上述所有规则的JSON对象,**确保:1. 优先级排序逻辑正确填入`description`;2. `target_class`匹配预定义列表;3. 行为树结构不变;4. 后端解析与Schema验证无错误**,无任何冗余内容;5. 当用户指令或检索到的参考知识中已经给出了具体坐标(如 (32.118904, 118.955208)),请直接按原值使用该坐标,不进行任何缩放、单位换算或数值变换,只需根据需要映射到对应字段 diff --git a/scripts/ai_agent/groundcontrol/src/py_tree_generator.py b/scripts/ai_agent/groundcontrol/src/py_tree_generator.py new file mode 100644 index 0000000..56d75ac --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/py_tree_generator.py @@ -0,0 +1,934 @@ +import json +import os +import logging +import uuid +import re +import time +from typing import Dict, Any, Optional, Set, List +import jsonschema +import platform + +# --- Imports from ai_agent --- +# Note: sys.path is updated in main.py to include ai_agent root +from models.models_client import Models_Client +from models.model_definition import TextInferenceRequest +from tools.memory_mag.rag.rag import set_embeddings, load_vector_database, retrieve_relevant_info +from tools.core.common_functions import read_json_file + +# --- Logging Setup --- +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) + +# ============================================================================== +# VALIDATION LOGIC (from utils/validation.py) +# ============================================================================== +def _parse_allowed_nodes_from_prompt(prompt_text: str) -> tuple[Set[str], Set[str]]: + """ + 从系统提示词中精确解析出允许的行动和条件节点。 + """ + try: + # 使用更精确的正则表达式匹配节点定义部分 + # 匹配 "#### 1. 可用节点定义" 或 "#### 2. 可用节点定义" 等 + node_section_pattern = r"####\s*\d+\.\s*可用节点定义.*?```json\s*({.*?})\s*```" + match = re.search(node_section_pattern, prompt_text, re.DOTALL | re.IGNORECASE) + + if not match: + logging.error("在系统提示词中未找到'可用节点定义'部分的JSON代码块。") + # 备用方案:尝试查找所有JSON块并识别节点定义 + return _fallback_parse_nodes(prompt_text) + + json_str = match.group(1) + logging.info("成功找到节点定义JSON代码块") + + # 解析JSON + allowed_nodes = json.loads(json_str) + + # 从对象列表中提取节点名称 + actions = set() + conditions = set() + + # 提取动作节点 + if "actions" in allowed_nodes and isinstance(allowed_nodes["actions"], list): + for action in allowed_nodes["actions"]: + if isinstance(action, dict) and "name" in action: + actions.add(action["name"]) + + # 提取条件节点 + if "conditions" in allowed_nodes and isinstance(allowed_nodes["conditions"], list): + for condition in allowed_nodes["conditions"]: + if isinstance(condition, dict) and "name" in condition: + conditions.add(condition["name"]) + + if not actions: + logging.warning("关键错误:从提示词解析出的行动节点列表为空。") + + logging.info(f"成功解析出动作节点: {sorted(actions)}") + logging.info(f"成功解析出条件节点: {sorted(conditions)}") + + return actions, conditions + + except json.JSONDecodeError as e: + logging.error(f"解析节点定义JSON时失败: {e}") + return set(), set() + except Exception as e: + logging.error(f"解析可用节点时发生未知错误: {e}") + return set(), set() + +def _fallback_parse_nodes(prompt_text: str) -> tuple[Set[str], Set[str]]: + """ + 备用解析方案:当精确匹配失败时使用。 + """ + logging.warning("使用备用方案解析节点定义...") + + # 查找所有JSON代码块 + matches = re.findall(r"```json\s*({.*?})\s*```", prompt_text, re.DOTALL) + if not matches: + logging.error("在系统提示词中未找到任何JSON代码块。") + return set(), set() + + # 尝试从每个JSON块中解析节点定义 + for i, json_str in enumerate(matches): + try: + data = json.loads(json_str) + + # 检查是否是节点定义的结构(包含actions、conditions、control_flow) + if ("actions" in data and isinstance(data["actions"], list) and + "conditions" in data and isinstance(data["conditions"], list) and + "control_flow" in data and isinstance(data["control_flow"], list)): + + actions = set() + conditions = set() + + # 提取动作节点 + for action in data["actions"]: + if isinstance(action, dict) and "name" in action: + actions.add(action["name"]) + + # 提取条件节点 + for condition in data["conditions"]: + if isinstance(condition, dict) and "name" in condition: + conditions.add(condition["name"]) + + if actions: + logging.info(f"从第{i+1}个JSON块中成功解析出节点定义") + logging.info(f"动作节点: {sorted(actions)}") + logging.info(f"条件节点: {sorted(conditions)}") + return actions, conditions + + except json.JSONDecodeError: + continue # 尝试下一个JSON块 + + logging.error("在所有JSON代码块中都没有找到有效的节点定义结构。") + return set(), set() + +def _find_nodes_by_name(node: Dict, target_name: str) -> List[Dict]: + """递归查找所有指定名称的节点""" + nodes_found = [] + + if node.get("name") == target_name: + nodes_found.append(node) + + # 递归搜索子节点 + for child in node.get("children", []): + nodes_found.extend(_find_nodes_by_name(child, target_name)) + + return nodes_found + +def _validate_safety_monitoring(pytree_instance: dict) -> bool: + """验证行为树是否包含必要的安全监控""" + root_node = pytree_instance.get("root", {}) + + # 查找所有电池监控节点 + battery_nodes = _find_nodes_by_name(root_node, "battery_above") + + # 检查是否包含安全监控结构 + safety_monitors = _find_nodes_by_name(root_node, "SafetyMonitor") + + if not battery_nodes and not safety_monitors: + logging.warning("⚠️ 安全警告: 行为树中没有发现电池监控节点或安全监控器") + return False + + # 检查电池阈值设置是否合理 + for battery_node in battery_nodes: + threshold = battery_node.get("params", {}).get("threshold") + if threshold is not None: + if threshold < 0.25: + logging.warning(f"⚠️ 安全警告: 电池阈值设置过低 ({threshold}),建议不低于0.25") + elif threshold > 0.5: + logging.warning(f"⚠️ 安全警告: 电池阈值设置过高 ({threshold}),可能影响任务执行") + + logging.info("✅ 安全监控验证通过") + return True + +def _generate_pytree_schema(allowed_actions: set, allowed_conditions: set) -> dict: + """ + 根据允许的行动和条件节点,动态生成一个JSON Schema。 + """ + # 所有可能的节点类型 + node_types = ["action", "condition", "Sequence", "Selector", "Parallel"] + + # 目标检测相关的类别枚举 + target_classes = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", + "traffic_light", "fire_hydrant", "stop_sign", "parking_meter", "bench", "bird", "cat", + "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", + "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball", + "kite", "baseball_bat", "baseball_glove", "skateboard", "surfboard", "tennis_racket", + "bottle", "wine_glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair", + "couch", "potted_plant", "bed", "dining_table", "toilet", "tv", "laptop", "mouse", "remote", + "keyboard", "cell_phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", + "clock", "vase", "scissors", "teddy_bear", "hair_drier", "toothbrush","balloon" + ] + + # 递归节点定义 + node_definition = { + "type": "object", + "properties": { + "type": {"type": "string", "enum": node_types}, + "name": {"type": "string"}, + "params": {"type": "object"}, + "children": { + "type": "array", + "items": {"$ref": "#/definitions/node"} + } + }, + "required": ["type", "name"], + "allOf": [ + # 动作节点验证 + { + "if": {"properties": {"type": {"const": "action"}}}, + "then": {"properties": {"name": {"enum": sorted(list(allowed_actions))}}} + }, + # 条件节点验证 + { + "if": {"properties": {"type": {"const": "condition"}}}, + "then": {"properties": {"name": {"enum": sorted(list(allowed_conditions))}}} + }, + # 目标检测动作节点的参数验证 + { + "if": { + "properties": { + "type": {"const": "action"}, + "name": {"const": "object_detect"} + } + }, + "then": { + "properties": { + "params": { + "type": "object", + "properties": { + "target_class": {"type": "string", "enum": target_classes}, + "description": {"type": "string"}, + "count": {"type": "integer", "minimum": 1} + }, + "required": ["target_class"], + "additionalProperties": False + } + } + } + }, + # 目标检测条件节点的参数验证 + { + "if": { + "properties": { + "type": {"const": "condition"}, + "name": {"const": "object_detected"} + } + }, + "then": { + "properties": { + "params": { + "type": "object", + "properties": { + "target_class": {"type": "string", "enum": target_classes}, + "description": {"type": "string"}, + "count": {"type": "integer", "minimum": 1} + }, + "required": ["target_class"], + "additionalProperties": False + } + } + } + }, + # 电池监控节点的参数验证 + { + "if": { + "properties": { + "type": {"const": "condition"}, + "name": {"const": "battery_above"} + } + }, + "then": { + "properties": { + "params": { + "type": "object", + "properties": { + "threshold": {"type": "number", "minimum": 0.0, "maximum": 1.0} + }, + "required": ["threshold"], + "additionalProperties": False + } + } + } + }, + # GPS状态节点的参数验证 + { + "if": { + "properties": { + "type": {"const": "condition"}, + "name": {"const": "gps_status"} + } + }, + "then": { + "properties": { + "params": { + "type": "object", + "properties": { + "min_satellites": {"type": "integer", "minimum": 6, "maximum": 15} + }, + "required": ["min_satellites"], + "additionalProperties": False + } + } + } + } + ] + } + + # 完整的Schema结构 + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Pytree", + "definitions": { + "node": node_definition + }, + "type": "object", + "properties": { + "root": { "$ref": "#/definitions/node" } + }, + "required": ["root"] + } + + return schema + +def _generate_simple_mode_schema(allowed_actions: set) -> dict: + """ + 生成简单模式JSON Schema:{"root":{"type":"action","name":"","params":{...}}} + 仅校验动作名称在允许集合内,以及基本结构完整性;参数按对象形状放宽,由上游提示词与运行时再约束。 + """ + # 使用复杂模式Schema中的node定义,但限制root节点必须是action类型 + node_definition = { + "type": "object", + "properties": { + "type": {"type": "string", "const": "action"}, + "name": {"type": "string", "enum": sorted(list(allowed_actions))}, + "params": {"type": "object"} + }, + "required": ["type", "name"], + "additionalProperties": False + } + + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "SimpleMode", + "definitions": { + "node": node_definition + }, + "type": "object", + "properties": { + "root": { "$ref": "#/definitions/node" } + }, + "required": ["root"], + "additionalProperties": False + } + + return schema + +def _validate_pytree_with_schema(pytree_instance: dict, schema: dict) -> bool: + """ + 使用JSON Schema验证给定的Pytree实例。 + """ + try: + jsonschema.validate(instance=pytree_instance, schema=schema) + logging.info("✅ JSON Schema验证成功") + + # 额外验证安全监控 + safety_valid = _validate_safety_monitoring(pytree_instance) + + return True and safety_valid + except jsonschema.ValidationError as e: + logging.warning("❌ Pytree验证失败") + logging.warning(f"错误信息: {e.message}") + error_path = list(e.path) + logging.warning(f"错误路径: {' -> '.join(map(str, error_path)) if error_path else '根节点'}") + + # 提供更具体的错误信息 + if "object_detect" in str(e.message) or "object_detected" in str(e.message): + logging.warning("💡 提示: 请确保目标类别是预定义列表中的有效值") + elif "battery_above" in str(e.message): + logging.warning("💡 提示: 电池阈值必须在0.0到1.0之间") + elif "gps_status" in str(e.message): + logging.warning("💡 提示: 最小卫星数量必须在6到15之间") + + return False + except Exception as e: + logging.error(f"进行JSON Schema验证时发生未知错误: {e}") + return False + +# ============================================================================== +# VISUALIZATION LOGIC (from utils/visualization.py) +# ============================================================================== +def _visualize_pytree(node: Dict, file_path: str): + """ + 使用Graphviz将Pytree字典可视化,并保存到指定路径。 + """ + try: + from graphviz import Digraph + except ImportError: + logging.critical("错误:未安装graphviz库。请运行: pip install graphviz") + return + + # 选择合适的中文字体,避免中文乱码 + def _pick_zh_font(): + sys = platform.system() + if sys == "Windows": + return "Microsoft YaHei" + elif sys == "Darwin": + return "PingFang SC" + else: + return "Noto Sans CJK SC" + + fontname = _pick_zh_font() + + dot = Digraph('Pytree', comment='Drone Mission Plan') + dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname) + dot.attr('node', shape='box', style='rounded,filled', fontname=fontname) + dot.attr('edge', fontname=fontname) + + _add_nodes_and_edges(node, dot) + + try: + # 确保输出目录存在,并避免生成 .png.png + base_path, ext = os.path.splitext(file_path) + render_path = base_path if ext.lower() == '.png' else file_path + + out_dir = os.path.dirname(render_path) + if out_dir and not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok=True) + + # 保存为 .png 文件,并自动删除源码 .gv 文件 + output_path = dot.render(render_path, format='png', cleanup=True, view=False) + logging.info("✅ 任务树可视化成功") + logging.info(f"图形已保存到: {output_path}") + except Exception as e: + logging.error("❌ 生成可视化图形失败") + logging.error("请确保您的系统已经正确安装了Graphviz图形库。") + logging.error(f"错误详情: {e}") + +def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str: + """递归辅助函数,用于添加节点和边。""" + + # 为每个节点创建一个唯一的ID(加上随机数避免冲突) + import random + import html + + current_id = f"{id(node)}_{random.randint(1000, 9999)}" + + # 准备节点标签(HTML-like,正确换行与转义) + name = html.escape(str(node.get('name', ''))) + ntype = html.escape(str(node.get('type', ''))) + label_parts = [f"{name} ({ntype})"] + + # 格式化参数显示 + params = node.get('params') or {} + if params: + params_lines = [] + for key, value in params.items(): + k = html.escape(str(key)) + if isinstance(value, float): + value_str = f"{value:.2f}".rstrip('0').rstrip('.') + else: + value_str = str(value) + v = html.escape(value_str) + params_lines.append(f"{k}: {v}") + params_text = "
".join(params_lines) + label_parts.append(f"{params_text}") + + node_label = f"<{'
'.join(label_parts)}>" + + # 根据类型设置节点样式和颜色(使用 fillcolor 控制填充色) + node_type = (node.get('type') or '').lower() + shape = 'ellipse' + style = 'filled' + fillcolor = '#e6e6e6' # 默认灰色填充 + border_color = '#666666' # 默认描边色 + + if node_type == 'action': + shape = 'box' + style = 'rounded,filled' + fillcolor = "#cde4ff" # 浅蓝 + elif node_type == 'condition': + shape = 'diamond' + style = 'filled' + fillcolor = "#fff2cc" # 浅黄 + elif node_type == 'sequence': + shape = 'ellipse' + style = 'filled' + fillcolor = '#d5e8d4' # 绿色 + elif node_type == 'selector': + shape = 'ellipse' + style = 'filled' + fillcolor = '#ffe6cc' # 橙色 + elif node_type == 'parallel': + shape = 'ellipse' + style = 'filled' + fillcolor = '#e1d5e7' # 紫色 + + # 特别标记安全相关节点 + if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']: + border_color = '#ff0000' # 红色边框突出显示安全节点 + style = 'filled,bold' # 加粗 + + dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color) + + # 连接父节点 + if parent_id: + dot.edge(parent_id, current_id) + + # 递归处理子节点 + children = node.get("children", []) + if not children: + return current_id + + # 记录所有子节点的ID + child_ids = [] + + # 正确的递归连接:每个子节点都连接到当前节点 + for child in children: + child_id = _add_nodes_and_edges(child, dot, current_id) + child_ids.append(child_id) + + # 子节点同级排列(横向排布,更直观地表现同层) + if len(child_ids) > 1: + with dot.subgraph(name=f"rank_{current_id}") as s: + s.attr(rank='same') + for cid in child_ids: + s.node(cid) + + # 行为树中,所有类型的节点都只是父连子,不需要子节点间的额外连接 + # Sequence、Selector、Parallel 的执行逻辑由行为树引擎处理,不需要在可视化中体现 + + return current_id + +# ============================================================================== +# CORE PYTREE GENERATOR CLASS +# ============================================================================== +class PyTreeGenerator: + def __init__(self): + self.base_dir = os.path.dirname(os.path.abspath(__file__)) + self.prompts_dir = os.path.join(self.base_dir, 'prompts') + + # Updated output directory for visualizations + self.vis_dir = os.path.abspath(os.path.join(self.base_dir, '..', 'generated_visualizations')) + os.makedirs(self.vis_dir, exist_ok=True) + + # Load prompts + self.complex_prompt = self._load_prompt("system_prompt.txt") + self.simple_prompt = self._load_prompt("simple_mode_prompt.txt") + self.classifier_prompt = self._load_prompt("classifier_prompt.txt") + # Legacy variable compatibility + self.system_prompt = self.complex_prompt + + # --- Models Client Setup --- + # Unified client for all models + self.model_server_url = os.getenv("MODEL_SERVER_URL", "http://localhost:8000") + self.models_client = Models_Client(server_url=self.model_server_url) + + # --- RAG Setup --- + # 计算 ai_agent 根目录(从 groundcontrol/src 向上2级) + ai_agent_root = os.path.abspath(os.path.join(self.base_dir, "../..")) + + # 尝试从配置文件读取 RAG 配置 + rag_config_path = os.path.join(ai_agent_root, "config", "rag_config.json") + rag_config = None + if os.path.exists(rag_config_path): + try: + rag_config = read_json_file(rag_config_path) + logging.info(f"成功读取 RAG 配置文件: {rag_config_path}") + except Exception as e: + logging.warning(f"读取 RAG 配置文件失败: {e},将使用默认配置") + + # 确定向量数据库路径(优先使用配置文件,否则使用默认路径) + if rag_config and "rag_mag" in rag_config: + vector_store_path = rag_config["rag_mag"].get("vectorstore_persist_directory") + collection_name = rag_config["rag_mag"].get("collection_name", "osm_map_docs") + embedding_model_path = rag_config["rag_mag"].get("embedding_model_path") + embedding_type = rag_config["rag_mag"].get("embedding_framework_type", "llamacpp_embedding") + + # 如果配置文件中是绝对路径,直接使用;否则转换为绝对路径 + if vector_store_path and not os.path.isabs(vector_store_path): + vector_store_path = os.path.join(ai_agent_root, vector_store_path) + + # 读取模型配置 + if embedding_type == "llamacpp_embedding": + model_config = rag_config["rag_mag"].get("model_config_llamacpp", {}) + else: + model_config = rag_config["rag_mag"].get("model_config_huggingFace", {}) + else: + # 使用默认配置:指向 scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1 + vector_store_path = os.path.join(ai_agent_root, "memory", "knowledge_base", "map", "vector_store", "osm_map1") + collection_name = "osm_map_docs" + embedding_model_path = os.getenv("EMBEDDING_MODEL_PATH", "/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf") + embedding_type = "llamacpp_embeddings" + # 默认模型配置 + model_config = { + "n_ctx": 512, + "n_threads": 4, + "n_gpu_layers": 0, + "n_seq_max": 256, + "n_threads_batch": 4, + "flash_attn": False, + "verbose": False + } + + # 确保路径是绝对路径 + vector_store_path = os.path.abspath(vector_store_path) + logging.info(f"向量数据库路径: {vector_store_path}") + logging.info(f"集合名称: {collection_name}") + logging.info(f"初始化 RAG,嵌入模型路径: {embedding_model_path}") + + if os.path.exists(embedding_model_path): + # 设置嵌入类型名称(配置文件使用 "llamacpp_embedding",函数期望 "llamacpp_embeddings") + if embedding_type == "llamacpp_embedding": + embedding_type = "llamacpp_embeddings" + + self.embeddings = set_embeddings( + embedding_model_path=embedding_model_path, + embedding_type=embedding_type, + model_config=model_config + ) + # 加载现有向量数据库 + self.vector_db = load_vector_database( + embeddings=self.embeddings, + persist_directory=vector_store_path, + collection_name=collection_name + ) + if not self.vector_db: + logging.warning(f"向量数据库未找到于 {vector_store_path},集合名称: {collection_name}。RAG 检索将返回空结果,直到数据库被填充。") + else: + logging.info(f"✅ 成功加载向量数据库: {vector_store_path}") + else: + logging.warning(f"嵌入模型未找到于 {embedding_model_path}。RAG 将被禁用。") + self.embeddings = None + self.vector_db = None + + # Parse allowed nodes from prompt for validation + allowed_actions, allowed_conditions = _parse_allowed_nodes_from_prompt(self.complex_prompt) + self.schema = _generate_pytree_schema(allowed_actions, allowed_conditions) + self.simple_schema = _generate_simple_mode_schema(allowed_actions) + + def _load_prompt(self, file_name: str) -> str: + try: + with open(os.path.join(self.prompts_dir, file_name), 'r', encoding='utf-8') as f: + return f.read() + except FileNotFoundError: + logging.error(f"提示词文件未找到 -> {file_name}") + return "" + + def _retrieve_context(self, query: str, timeout: int = 15) -> Optional[str]: + """ + Retrieve context from vector database with timeout protection. + """ + if not self.vector_db: + return None + + logging.info("--- Retrieving context from Vector DB ---") + + try: + import threading + result_container = [] + exception_container = [] + + def query_func(): + try: + # 直接从 Chroma 拿原始距离分数 + docs_and_scores = self.vector_db.similarity_search_with_score(query, k=5) + + if not docs_and_scores: + result_container.append([]) + return + + # 提取距离,并做一次查询内的 min-max 归一化 -> 相似度 ∈ [0,1] + distances = [score for _, score in docs_and_scores] + min_d, max_d = min(distances), max(distances) + + normalized_results = [] + for (doc, dist) in docs_and_scores: + if max_d == min_d: + sim = 1.0 # 所有距离一样时,认为同等相似 + else: + sim = (max_d - dist) / (max_d - min_d) # 距离小 -> 相似度高 + + # 使用绝对阈值 0.2 的相似度过滤 + if sim >= 0.2: + normalized_results.append({ + "content": doc.page_content, + "metadata": doc.metadata, + "similarity_score": float(sim) + }) + + result_container.append(normalized_results) + except Exception as e: + exception_container.append(e) + + thread = threading.Thread(target=query_func) + thread.daemon = True + thread.start() + thread.join(timeout=timeout) + + if thread.is_alive(): + logging.warning(f"⚠️ Vector DB query timed out (> {timeout}s), skipping RAG.") + return None + + if exception_container: + raise exception_container[0] + + results = result_container[0] if result_container else [] + + if not results: + logging.warning("Vector DB returned no results.") + return None + + # Extract content from results (List[Dict]) + # Each result has 'content', 'metadata', 'similarity_score' + context_list = [doc['content'] for doc in results] + context_str = "\n\n".join(context_list) + logging.info(f"✅ Successfully retrieved {len(results)} documents.") + return context_str + + except Exception as e: + logging.error(f"RAG retrieval error: {e}") + return None + + async def generate(self, user_prompt: str) -> Dict[str, Any]: + """ + Generates a py_tree.json structure based on the user's prompt. + """ + total_start_time = time.time() + logging.info(f"接收到用户请求: {user_prompt}") + + # 第一步:分类(简单/复杂) + mode = "complex" + classify_start = time.time() + try: + # Construct request for classification + # 分类请求使用简单的 JSON Schema 约束输出格式 + classifier_schema = { + "type": "object", + "properties": { + "mode": {"type": "string", "enum": ["simple", "complex"]} + }, + "required": ["mode"], + "additionalProperties": False + } + + req = TextInferenceRequest( + user_prompt=user_prompt, + system_prompt=self.classifier_prompt or "你是一个分类器,只输出JSON。", + temperature=0.0, + max_tokens=100, + top_p=0.95, + json_schema=classifier_schema # 传递 JSON Schema 约束分类输出 + ) + + # Call unified models client + resp = self.models_client.text_inference(req) + + if "error" in resp: + logging.warning(f"分类请求失败: {resp['error']}") + else: + class_str = resp.get("result", "{}") + # Try to clean markdown if present + class_str = re.sub(r'```json\s*', '', class_str).replace('```', '').strip() + + try: + class_obj = json.loads(class_str) + if isinstance(class_obj, dict) and class_obj.get("mode") in ("simple", "complex"): + mode = class_obj.get("mode") + except json.JSONDecodeError: + pass # Default to complex + + classify_time = time.time() - classify_start + logging.info(f"分类结果: {mode}, ⏱️ 耗时: {classify_time:.2f}秒") + except Exception as e: + classify_time = time.time() - classify_start + logging.warning(f"分类过程发生异常,默认按复杂指令处理: {e}, ⏱️ 耗时: {classify_time:.2f}秒") + + # 第二步:根据模式准备提示词与上下文(简单与复杂都执行检索增强) + # 基于模式选择提示词;复杂模式追加一条强制规则,避免模型误输出简单结构 + use_prompt = self.simple_prompt if mode == "simple" else ( + (self.complex_prompt or "") + + "\n\n【强制规则】仅生成包含root的复杂行为树JSON,不得输出简单模式(不得包含mode字段或仅有action节点)。" + ) + final_user_prompt = user_prompt + retrieval_start = time.time() + retrieved_context = self._retrieve_context(user_prompt) + retrieval_time = time.time() - retrieval_start + logging.info(f"⏱️ 检索耗时: {retrieval_time:.2f}秒") + if retrieved_context: + augmentation = ( + "\n\n---\n" + "参考知识:\n" + "以下是从知识库中检索到的、与当前任务最相关的信息,请优先参考这些信息来生成结果:\n" + f"{retrieved_context}" + "\n---" + ) + final_user_prompt += augmentation + else: + logging.warning("未检索到上下文或检索失败,将使用原始用户提示词。") + + for attempt in range(3): + logging.info(f"--- 第 {attempt + 1}/3 次尝试生成Pytree ---") + try: + generation_start = time.time() + + # Construct generation request with JSON Schema + # 根据模式选择对应的 JSON Schema + selected_schema = self.simple_schema if mode == "simple" else self.schema + + req = TextInferenceRequest( + user_prompt=final_user_prompt, + system_prompt=use_prompt, + temperature=0.1 if mode == "complex" else 0.0, + max_tokens=4096, # Large enough for complex trees + top_p=0.95, + json_schema=selected_schema # 传递 JSON Schema 约束输出格式 + ) + + resp = self.models_client.text_inference(req) + + if "error" in resp: + logging.error(f"生成请求失败: {resp['error']}") + continue + + pytree_str = resp.get("result", "") + + generation_time = time.time() - generation_start + logging.info(f"⏱️ LLM生成耗时: {generation_time:.2f}秒") + + # Clean markdown code blocks if present + clean_pytree_str = re.sub(r'```json\s*', '', pytree_str).replace('```', '').strip() + + # 单独捕获JSON解析错误并打印原始响应 + try: + pytree_dict = json.loads(clean_pytree_str) + except json.JSONDecodeError as e: + logging.error(f"❌ JSON解析失败(第 {attempt + 1}/3 次)。原始响应如下:\n{pytree_str}") + continue + + # 简单/复杂分别验证与返回 + if mode == "simple": + try: + jsonschema.validate(instance=pytree_dict, schema=self.simple_schema) + logging.info("✅ 简单模式JSON Schema验证成功") + except jsonschema.ValidationError as e: + logging.warning(f"❌ 简单模式验证失败: {e.message}") + continue + # 附加元信息并生成简单可视化(单动作) + plan_id = str(uuid.uuid4()) + pytree_dict['plan_id'] = plan_id + # 简单模式可视化:直接使用root节点 + try: + vis_start = time.time() + vis_filename = "py_tree.png" + vis_path = os.path.join(self.vis_dir, vis_filename) + _visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0]) + vis_time = time.time() - vis_start + logging.info(f"⏱️ 可视化耗时: {vis_time:.2f}秒") + pytree_dict['visualization_url'] = f"/static/{vis_filename}" + except Exception as e: + logging.warning(f"简单模式可视化失败: {e}") + + total_time = time.time() - total_start_time + logging.info(f"🎉 ⏱️ 总耗时: {total_time:.2f}秒 (分类:{classify_time:.1f}s + 检索:{retrieval_time:.1f}s + 生成:{generation_time:.1f}s + 可视化:{vis_time if 'vis_time' in locals() else 0:.1f}s)") + + return pytree_dict + + # 复杂模式回退:若模型误返回简单结构,则自动包装为含安全监控的行为树 + if mode == "complex" and isinstance(pytree_dict, dict) and 'root' not in pytree_dict: + try: + jsonschema.validate(instance=pytree_dict, schema=self.simple_schema) + logging.warning("⚠️ 复杂模式生成了简单结构,触发自动包装为完整行为树的回退逻辑。") + simple_action_obj = pytree_dict.get('action') or {} + action_name = simple_action_obj.get('name') + action_params = simple_action_obj.get('params') if isinstance(simple_action_obj.get('params'), dict) else {} + + safety_selector = { + "type": "Selector", + "name": "SafetyMonitor", + "params": {"memory": True}, + "children": [ + {"type": "condition", "name": "battery_above", "params": {"threshold": 0.3}}, + {"type": "condition", "name": "gps_status", "params": {"min_satellites": 8}}, + {"type": "Sequence", "name": "EmergencyHandler", "children": [ + {"type": "action", "name": "emergency_return", "params": {"reason": "safety_breach"}}, + {"type": "action", "name": "land", "params": {"mode": "home"}} + ]} + ] + } + + main_children = [{"type": "action", "name": action_name, "params": action_params}] + if action_name != "land": + main_children.append({"type": "action", "name": "land", "params": {"mode": "home"}}) + + root_parallel = { + "type": "Parallel", + "name": "MissionWithSafety", + "params": {"policy": "all_success"}, + "children": [ + {"type": "Sequence", "name": "MainTask", "children": main_children}, + safety_selector + ] + } + pytree_dict = {"root": root_parallel} + except jsonschema.ValidationError: + # 不符合简单结构,按正常复杂验证继续 + pass + if _validate_pytree_with_schema(pytree_dict, self.schema): + logging.info("✅ 成功生成并验证了Pytree") + plan_id = str(uuid.uuid4()) + pytree_dict['plan_id'] = plan_id + + # Generate visualization to a static path + vis_start = time.time() + vis_filename = "py_tree.png" + vis_path = os.path.join(self.vis_dir, vis_filename) + _visualize_pytree(pytree_dict['root'], os.path.splitext(vis_path)[0]) + vis_time = time.time() - vis_start + logging.info(f"⏱️ 可视化耗时: {vis_time:.2f}秒") + + pytree_dict['visualization_url'] = f"/static/{vis_filename}" + + total_time = time.time() - total_start_time + logging.info(f"🎉 ⏱️ 总耗时: {total_time:.2f}秒 (分类:{classify_time:.1f}s + 检索:{retrieval_time:.1f}s + 生成:{generation_time:.1f}s + 可视化:{vis_time:.1f}s)") + + return pytree_dict + else: + # 打印未通过验证的Pytree以便排查 + preview = json.dumps(pytree_dict, ensure_ascii=False, indent=2) + logging.warning(f"❌ 未通过验证的Pytree(第 {attempt + 1}/3 次尝试):\n{preview}") + logging.warning("生成的Pytree验证失败,正在重试...") + except Exception as e: + logging.error(f"生成Pytree时发生错误: {e}") + + raise RuntimeError("在3次尝试后,仍未能生成一个有效的Pytree。") + +# Create a single instance for the application +py_tree_generator = PyTreeGenerator() + diff --git a/scripts/ai_agent/groundcontrol/src/test_api.py b/scripts/ai_agent/groundcontrol/src/test_api.py new file mode 100644 index 0000000..4d35fb8 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/test_api.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import requests +import json + +# --- Configuration --- +# The base URL of your running FastAPI service +BASE_URL = "http://127.0.0.1:8001" + +# The endpoint we want to test +ENDPOINT = "/generate_plan" + +# The user prompt we will send for the test +TEST_PROMPT = "起飞,飞到匡亚明学院" + +def test_generate_plan(): + """ + Sends a request to the /generate_plan endpoint and validates the response. + """ + url = BASE_URL + ENDPOINT + payload = {"user_prompt": TEST_PROMPT} + headers = {"Content-Type": "application/json"} + + print("--- API Test: Generate Plan ---") + print(f"✅ URL: {url}") + print(f"✅ Sending Prompt: \"{TEST_PROMPT}\"") + + try: + # Send the POST request + response = requests.post(url, data=json.dumps(payload), headers=headers) + + # Check for HTTP errors (e.g., 404, 500) + response.raise_for_status() + + # Parse the JSON response + data = response.json() + + print("✅ Received Response:") + print(json.dumps(data, indent=2, ensure_ascii=False)) + + # --- Validation --- + print("\n--- Validation Checks ---") + + # 1. Check if the response is a dictionary + if isinstance(data, dict): + print("PASS: Response is a valid JSON object.") + else: + print("FAIL: Response is not a valid JSON object.") + return + + # 2. Check for the existence of the 'root' key + if "root" in data and isinstance(data['root'], dict): + print("PASS: Response contains a valid 'root' key.") + else: + print("FAIL: Response does not contain a valid 'root' key.") + + # 3. Check for the existence and format of the 'visualization_url' key + if "visualization_url" in data and data["visualization_url"].endswith(".png"): + print(f"PASS: Response contains a valid 'visualization_url': {data['visualization_url']}") + else: + print("FAIL: Response does not contain a valid 'visualization_url'.") + + except requests.exceptions.RequestException as e: + print(f"\n❌ TEST FAILED: Could not connect to the server.") + print(" Please make sure the backend service is running.") + print(f" Error details: {e}") + except json.JSONDecodeError: + print(f"\n❌ TEST FAILED: The server response was not valid JSON.") + print(f" Response text: {response.text}") + except Exception as e: + print(f"\n❌ TEST FAILED: An unexpected error occurred: {e}") + +if __name__ == "__main__": + test_generate_plan() diff --git a/scripts/ai_agent/groundcontrol/src/test_rag_score.py b/scripts/ai_agent/groundcontrol/src/test_rag_score.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/ai_agent/groundcontrol/src/websocket_manager.py b/scripts/ai_agent/groundcontrol/src/websocket_manager.py new file mode 100644 index 0000000..dbc42f2 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/src/websocket_manager.py @@ -0,0 +1,48 @@ +import asyncio +from typing import List +from fastapi import WebSocket +import logging + +class ConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + self.loop: asyncio.AbstractEventLoop | None = None + + def set_loop(self, loop: asyncio.AbstractEventLoop): + """Sets the asyncio event loop.""" + self.loop = loop + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + + def broadcast(self, message: str): + """ + Thread-safely broadcasts a message to all active WebSocket connections. + This method is designed to be called from a different thread (e.g., a ROS2 callback). + """ + if not self.loop: + logging.error("Event loop not set in ConnectionManager. Cannot broadcast.") + return + + # Schedule the coroutine to be executed in the event loop + self.loop.call_soon_threadsafe(self._broadcast_in_loop, message) + + def _broadcast_in_loop(self, message: str): + """ + Helper to run the broadcast coroutine in the correct event loop. + """ + asyncio.ensure_future(self._broadcast_async(message), loop=self.loop) + + async def _broadcast_async(self, message: str): + """ + The actual async method that sends messages. + """ + tasks = [connection.send_text(message) for connection in self.active_connections] + await asyncio.gather(*tasks, return_exceptions=True) + +# Create a single instance of the manager to be used across the application +websocket_manager = ConnectionManager() diff --git a/scripts/ai_agent/groundcontrol/start_backend.sh b/scripts/ai_agent/groundcontrol/start_backend.sh new file mode 100755 index 0000000..4ee6161 --- /dev/null +++ b/scripts/ai_agent/groundcontrol/start_backend.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# 后端服务启动脚本 +# 用于启动 FastAPI 后端服务 + +# 获取脚本所在目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# 设置 Python 路径 +export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH}" + +# 启动后端服务 +echo "正在启动后端服务..." +echo "工作目录: $SCRIPT_DIR" +echo "监听地址: 0.0.0.0:8001" +echo "" + +python -m uvicorn src.main:app --host 0.0.0.0 --port 8001 +