优化交互式测试验证脚本,针对场景4修改提示词以及代码
This commit is contained in:
57
tools/test_validate/README.md
Normal file
57
tools/test_validate/README.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Test & Validation Tools (Unified)
|
||||
|
||||
该目录包含用于测试无人机规划系统、API 接口及 LLM 服务的集成验证工具集。
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
使用统一入口脚本启动交互式菜单:
|
||||
|
||||
```bash
|
||||
python run_tests.py
|
||||
```
|
||||
|
||||
## 🛠️ 测试模式
|
||||
|
||||
### 1. 交互式单次测试 (Mode 1)
|
||||
- **场景**: 快速验证单条指令,调试 Prompt。
|
||||
- **操作**: 在终端输入自然语言指令,即时获取结果。
|
||||
- **输出**: `validation/temporary/{指令名}/`
|
||||
- `response.json`: 完整 API 响应
|
||||
- `plan.png`: 可视化任务树
|
||||
- `process.log`: 请求与响应日志
|
||||
|
||||
### 2. 批量/场景测试 (Mode 2)
|
||||
- **场景**:
|
||||
- **场景测试**: 验证一组预定义指令的正确性(默认运行 1 次)。
|
||||
- **稳定性测试**: 对同一组指令进行高频重复测试(如运行 10 次),检测成功率和延迟抖动。
|
||||
- **操作**:
|
||||
1. 选择指令文件(位于 `instructions/` 目录)。
|
||||
2. 输入每条指令的运行次数(默认 1)。
|
||||
- **输出**: `validation/{时间戳}/`
|
||||
- `test_summary.csv`: 统计摘要(成功率、平均耗时)
|
||||
- `test_details.csv`: 每次运行的详细记录
|
||||
- `instructions_backup.txt`: 本次测试使用的指令备份
|
||||
- `{指令名}/`: 包含所有运行的 `.json` 和 `.png` 产物
|
||||
|
||||
## 📂 目录结构
|
||||
|
||||
```text
|
||||
tools/test_validate/
|
||||
├── instructions/ # 指令集文件 (.txt)
|
||||
├── modules/ # 功能模块
|
||||
│ ├── api_client.py # API 客户端核心
|
||||
│ ├── interactive_test.py # 交互式测试逻辑
|
||||
│ ├── batch_runner.py # 批量测试逻辑
|
||||
│ ├── visualizer.py # 可视化工具库
|
||||
│ ├── llm_tester.py # LLM 连接测试
|
||||
│ └── drone_uploader.py # 任务上传工具
|
||||
├── validation/ # 测试产物输出
|
||||
│ ├── temporary/ # 交互式测试结果
|
||||
│ └── {时间戳}/ # 批量测试结果
|
||||
└── run_tests.py # 主程序入口
|
||||
```
|
||||
|
||||
## 📄 配置文件
|
||||
|
||||
- **instructions/validate_instructions.txt**: 默认的预定义场景指令集。
|
||||
- 您可以在 `instructions/` 下添加任意 `.txt` 文件,测试时会在菜单中自动列出供选择。
|
||||
@@ -0,0 +1,5 @@
|
||||
去研究所正大门,搜索扎辫子女子并拍照。
|
||||
查找戴帽子的女子,找到后近距离拍照。
|
||||
到研究所广场,对长发可疑男子进行拍照。
|
||||
到研究所广场,寻找黄色衣服男子,我确认后再对其拍照。
|
||||
立即返航。
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
60
tools/test_validate/modules/api_client.py
Normal file
60
tools/test_validate/modules/api_client.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url="http://127.0.0.1:8000"):
|
||||
self.base_url = base_url
|
||||
self.endpoint = "/generate_plan"
|
||||
|
||||
def send_request(self, prompt, timeout=60):
|
||||
"""
|
||||
Sends a request to the API and returns a structured result.
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool,
|
||||
"data": dict or None,
|
||||
"latency": float (seconds),
|
||||
"error": str or None,
|
||||
"http_status": int or None
|
||||
}
|
||||
"""
|
||||
url = f"{self.base_url}{self.endpoint}"
|
||||
payload = {"user_prompt": prompt}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=timeout)
|
||||
latency = time.time() - start_time
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"latency": latency,
|
||||
"error": None,
|
||||
"http_status": response.status_code
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {
|
||||
"success": False,
|
||||
"data": None,
|
||||
"latency": latency,
|
||||
"error": f"Invalid JSON response: {response.text[:200]}",
|
||||
"http_status": response.status_code
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
latency = time.time() - start_time
|
||||
return {
|
||||
"success": False,
|
||||
"data": None,
|
||||
"latency": latency,
|
||||
"error": str(e),
|
||||
"http_status": getattr(e.response, 'status_code', None)
|
||||
}
|
||||
|
||||
129
tools/test_validate/modules/batch_runner.py
Normal file
129
tools/test_validate/modules/batch_runner.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import csv
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from .api_client import APIClient
|
||||
from .visualizer import generate_visualization, sanitize_filename
|
||||
|
||||
def run_batch_test():
|
||||
# 1. 选择指令文件
|
||||
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
instr_dir = os.path.join(base_dir, "instructions")
|
||||
files = [f for f in os.listdir(instr_dir) if f.endswith('.txt')]
|
||||
|
||||
if not files:
|
||||
print("❌ 未在 instructions 目录下找到 .txt 文件")
|
||||
return
|
||||
|
||||
print("\n请选择测试指令文件:")
|
||||
for i, f in enumerate(files):
|
||||
print(f"{i+1}. {f}")
|
||||
|
||||
try:
|
||||
idx = int(input("请输入序号: ").strip()) - 1
|
||||
if idx < 0 or idx >= len(files):
|
||||
print("❌ 无效序号")
|
||||
return
|
||||
selected_file = os.path.join(instr_dir, files[idx])
|
||||
except ValueError:
|
||||
print("❌ 输入无效")
|
||||
return
|
||||
|
||||
# 2. 配置参数
|
||||
try:
|
||||
iterations = int(input("请输入每条指令的测试次数 (默认1): ").strip() or "1")
|
||||
except ValueError:
|
||||
iterations = 1
|
||||
|
||||
# 3. 准备输出目录
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
output_dir = os.path.join(base_dir, "validation", timestamp)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 备份指令文件
|
||||
shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))
|
||||
|
||||
# 读取指令
|
||||
with open(selected_file, 'r', encoding='utf-8') as f:
|
||||
instructions = [line.strip() for line in f if line.strip() and not line.startswith('#')]
|
||||
|
||||
print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
|
||||
print(f"📂 输出目录: {output_dir}\n")
|
||||
|
||||
client = APIClient()
|
||||
detailed_results = []
|
||||
summary_stats = {}
|
||||
|
||||
for i, prompt in enumerate(instructions, 1):
|
||||
print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")
|
||||
safe_name = sanitize_filename(prompt)
|
||||
instr_out_dir = os.path.join(output_dir, safe_name)
|
||||
os.makedirs(instr_out_dir, exist_ok=True)
|
||||
|
||||
success_count = 0
|
||||
total_latency = 0
|
||||
|
||||
for k in range(1, iterations + 1):
|
||||
print(f" - 第 {k} 次...", end="", flush=True)
|
||||
result = client.send_request(prompt)
|
||||
|
||||
# 记录详情
|
||||
detailed_results.append({
|
||||
"instruction": prompt,
|
||||
"run_id": k,
|
||||
"success": result['success'],
|
||||
"latency": result['latency'],
|
||||
"error": result['error'] or ""
|
||||
})
|
||||
|
||||
# 保存产物
|
||||
if result['success']:
|
||||
print(f" ✅ ({result['latency']:.2f}s)")
|
||||
success_count += 1
|
||||
total_latency += result['latency']
|
||||
|
||||
# 保存JSON
|
||||
with open(os.path.join(instr_out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
|
||||
json.dump(result['data'], f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 保存图片
|
||||
if result['data'] and 'root' in result['data']:
|
||||
generate_visualization(result['data']['root'], os.path.join(instr_out_dir, f"{k}.png"))
|
||||
else:
|
||||
print(f" ❌ {result['error']}")
|
||||
|
||||
time.sleep(0.5) # 避免过快请求
|
||||
|
||||
# 统计单条指令
|
||||
avg_lat = total_latency / success_count if success_count > 0 else 0
|
||||
summary_stats[prompt] = {
|
||||
"total_runs": iterations,
|
||||
"success_runs": success_count,
|
||||
"success_rate": f"{(success_count/iterations)*100:.1f}%",
|
||||
"avg_latency": f"{avg_lat:.2f}s"
|
||||
}
|
||||
|
||||
# 4. 生成报告
|
||||
# 详细报告
|
||||
with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
|
||||
writer.writeheader()
|
||||
writer.writerows(detailed_results)
|
||||
|
||||
# 摘要报告
|
||||
with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
|
||||
for prompt, stats in summary_stats.items():
|
||||
writer.writerow([
|
||||
prompt,
|
||||
stats["total_runs"],
|
||||
stats["success_runs"],
|
||||
stats["success_rate"],
|
||||
stats["avg_latency"]
|
||||
])
|
||||
|
||||
print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")
|
||||
|
||||
37
tools/test_validate/modules/drone_uploader.py
Normal file
37
tools/test_validate/modules/drone_uploader.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import requests
|
||||
import os
|
||||
import sys
|
||||
|
||||
def upload_mission(drone_ip, file_path):
|
||||
"""上传一个JSON任务文件到无人机"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Error: File not found at {file_path}")
|
||||
return
|
||||
|
||||
url = f"http://{drone_ip}:5000/missions"
|
||||
print(f"正在上传 {file_path} 到 {url} ...")
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
files = {'file': (os.path.basename(file_path), f, 'application/json')}
|
||||
response = requests.post(url, files=files, timeout=10)
|
||||
|
||||
# 检查HTTP响应状态码
|
||||
response.raise_for_status()
|
||||
|
||||
print("上传成功!")
|
||||
print("无人机端响应:", response.json())
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"上传过程中发生错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) < 3:
|
||||
print("用法: python ground_station_client.py [无人机IP地址] [JSON文件路径]")
|
||||
print("示例: python ground_station_client.py 192.168.1.10 ./missions/rescue_mission.json")
|
||||
sys.exit(1)
|
||||
|
||||
drone_ip_address = sys.argv[1]
|
||||
mission_file_path = sys.argv[2]
|
||||
|
||||
upload_mission(drone_ip_address, mission_file_path)
|
||||
87
tools/test_validate/modules/interactive_test.py
Normal file
87
tools/test_validate/modules/interactive_test.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from .api_client import APIClient
|
||||
from .visualizer import generate_visualization, sanitize_filename
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
|
||||
def run_interactive_test():
|
||||
client = APIClient()
|
||||
print("\n🚀 进入交互式测试模式 (输入 'exit' 或 'q' 退出)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input("\n请输入测试指令: ").strip()
|
||||
if prompt.lower() in ['exit', 'q']:
|
||||
break
|
||||
if not prompt:
|
||||
continue
|
||||
|
||||
print("⏳ 正在请求后端 API...")
|
||||
result = client.send_request(prompt)
|
||||
|
||||
if result['success']:
|
||||
print(f"✅ 请求成功 (耗时: {result['latency']:.2f}s)")
|
||||
|
||||
# 创建输出目录
|
||||
sanitized_name = sanitize_filename(prompt)
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"validation",
|
||||
"temporary",
|
||||
sanitized_name
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存 JSON
|
||||
json_path = os.path.join(output_dir, "response.json")
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result['data'], f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 保存日志
|
||||
log_path = os.path.join(output_dir, "process.log")
|
||||
with open(log_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"Prompt: {prompt}\n")
|
||||
f.write(f"Status: {result['http_status']}\n")
|
||||
f.write(f"Latency: {result['latency']}\n")
|
||||
f.write(f"Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
|
||||
|
||||
# 生成图片
|
||||
if result['data'] and 'root' in result['data']:
|
||||
png_path = os.path.join(output_dir, "plan.png")
|
||||
if generate_visualization(result['data']['root'], png_path):
|
||||
print(f"🖼️ 可视化图已生成: {png_path}")
|
||||
else:
|
||||
print("⚠️ 可视化生成失败")
|
||||
|
||||
print(f"📂 结果已保存至: {output_dir}")
|
||||
else:
|
||||
print(f"❌ 请求失败: {result['error']}")
|
||||
|
||||
# 即使失败也保存日志,以便排查
|
||||
sanitized_name = sanitize_filename(prompt)
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"validation",
|
||||
"temporary",
|
||||
sanitized_name
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
log_path = os.path.join(output_dir, "process.log")
|
||||
with open(log_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"Prompt: {prompt}\n")
|
||||
f.write(f"Status: {result['http_status']}\n")
|
||||
f.write(f"Latency: {result['latency']}\n")
|
||||
f.write(f"Error: {result['error']}\n")
|
||||
# 如果有部分数据,也记录下来
|
||||
if result['data']:
|
||||
f.write(f"Partial Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
|
||||
print(f"⚠️ 错误日志已保存至: {log_path}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n已取消")
|
||||
break
|
||||
|
||||
174
tools/test_validate/modules/llm_tester.py
Normal file
174
tools/test_validate/modules/llm_tester.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def build_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="调用本地 llama-server (OpenAI兼容) 进行推理,支持自定义系统/用户提示词"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
|
||||
help="llama-server 的基础URL(默认: http://127.0.0.1:8081/v1,或环境变量 SIMPLE_BASE_URL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=os.getenv("SIMPLE_MODEL", "local-model"),
|
||||
help="模型名称(默认: local-model,或环境变量 SIMPLE_MODEL)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system",
|
||||
default="You are a helpful assistant.",
|
||||
help="系统提示词(system role)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system-file",
|
||||
default=None,
|
||||
help="系统提示词文件路径(txt);若提供,则覆盖 --system 的字符串",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
default=None,
|
||||
help="用户提示词(user role);若不传则从交互式输入读取",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="采样温度(默认: 0.2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="最大生成Token数(默认: 4096)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=120.0,
|
||||
help="HTTP超时时间秒(默认: 120)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="打印完整返回JSON",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def call_llama_server(
|
||||
base_url: str,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
endpoint = base_url.rstrip("/") + "/chat/completions"
|
||||
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
||||
|
||||
# 兼容需要API Key的代理/服务(llama-server通常不强制)
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = build_args()
|
||||
|
||||
user_prompt = args.user
|
||||
if not user_prompt:
|
||||
try:
|
||||
user_prompt = input("请输入用户提示词: ")
|
||||
except KeyboardInterrupt:
|
||||
print("\n已取消。")
|
||||
sys.exit(1)
|
||||
|
||||
# 解析系统提示词:优先使用 --system-file
|
||||
system_prompt = args.system
|
||||
if args.system_file:
|
||||
try:
|
||||
with open(args.system_file, "r", encoding="utf-8") as f:
|
||||
system_prompt = f.read()
|
||||
except Exception as e:
|
||||
print("\n❌ 读取系统提示词文件失败:")
|
||||
print(str(e))
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
print("--- llama-server 推理 ---")
|
||||
print(f"Base URL: {args.base_url}")
|
||||
print(f"Model: {args.model}")
|
||||
if args.system_file:
|
||||
print(f"System(from file): {args.system_file}")
|
||||
else:
|
||||
print(f"System: {system_prompt}")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
data = call_llama_server(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
timeout=args.timeout,
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
print("\n完整返回JSON:")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||||
|
||||
# 尝试按OpenAI兼容格式提取assistant内容
|
||||
content = None
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if content is not None:
|
||||
print("\n模型输出:")
|
||||
print(content)
|
||||
else:
|
||||
# 兜底打印
|
||||
print("\n无法按OpenAI兼容字段解析内容,原始返回如下:")
|
||||
print(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print("\n❌ 请求失败:请确认 llama-server 已在 8081 端口启动并可访问。")
|
||||
print(f"详情: {e}")
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
print("\n❌ 发生未预期的错误:")
|
||||
print(str(e))
|
||||
sys.exit(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
287
tools/test_validate/modules/visualizer.py
Normal file
287
tools/test_validate/modules/visualizer.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
从API测试日志中提取JSON响应并批量可视化
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import platform
|
||||
import random
|
||||
import html
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def sanitize_filename(text: str) -> str:
|
||||
"""将文本转换为安全的文件名"""
|
||||
# 移除或替换不安全的字符
|
||||
text = re.sub(r'[<>:"/\\|?*]', '_', text)
|
||||
# 限制长度
|
||||
if len(text) > 100:
|
||||
text = text[:100]
|
||||
return text
|
||||
|
||||
def _pick_zh_font():
|
||||
"""选择合适的中文字体"""
|
||||
sys = platform.system()
|
||||
if sys == "Windows":
|
||||
return "Microsoft YaHei"
|
||||
elif sys == "Darwin":
|
||||
return "PingFang SC"
|
||||
else:
|
||||
return "Noto Sans CJK SC"
|
||||
|
||||
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
|
||||
"""递归辅助函数,用于添加节点和边。"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return ""
|
||||
|
||||
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
|
||||
|
||||
# 准备节点标签(HTML-like,正确换行与转义)
|
||||
name = html.escape(str(node.get('name', '')))
|
||||
ntype = html.escape(str(node.get('type', '')))
|
||||
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
|
||||
|
||||
# 格式化参数显示
|
||||
params = node.get('params') or {}
|
||||
if params:
|
||||
params_lines = []
|
||||
for key, value in params.items():
|
||||
k = html.escape(str(key))
|
||||
if isinstance(value, float):
|
||||
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
|
||||
else:
|
||||
value_str = str(value)
|
||||
v = html.escape(value_str)
|
||||
params_lines.append(f"{k}: {v}")
|
||||
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
|
||||
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
|
||||
|
||||
node_label = f"<{'<BR/>'.join(label_parts)}>"
|
||||
|
||||
# 根据类型设置节点样式和颜色
|
||||
node_type = (node.get('type') or '').lower()
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e6e6e6' # 默认灰色填充
|
||||
border_color = '#666666' # 默认描边色
|
||||
|
||||
if node_type == 'action':
|
||||
shape = 'box'
|
||||
style = 'rounded,filled'
|
||||
fillcolor = "#cde4ff" # 浅蓝
|
||||
elif node_type == 'condition':
|
||||
shape = 'diamond'
|
||||
style = 'filled'
|
||||
fillcolor = "#fff2cc" # 浅黄
|
||||
elif node_type == 'sequence':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#d5e8d4' # 绿色
|
||||
elif node_type == 'selector':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#ffe6cc' # 橙色
|
||||
elif node_type == 'parallel':
|
||||
shape = 'ellipse'
|
||||
style = 'filled'
|
||||
fillcolor = '#e1d5e7' # 紫色
|
||||
|
||||
# 特别标记安全相关节点
|
||||
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
|
||||
border_color = '#ff0000' # 红色边框突出显示安全节点
|
||||
style = 'filled,bold' # 加粗
|
||||
|
||||
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
|
||||
|
||||
# 连接父节点
|
||||
if parent_id:
|
||||
dot.edge(parent_id, current_id)
|
||||
|
||||
# 递归处理子节点
|
||||
children = node.get("children", [])
|
||||
if not children:
|
||||
return current_id
|
||||
|
||||
# 记录所有子节点的ID
|
||||
child_ids = []
|
||||
|
||||
# 正确的递归连接:每个子节点都连接到当前节点
|
||||
for child in children:
|
||||
child_id = _add_nodes_and_edges(child, dot, current_id)
|
||||
child_ids.append(child_id)
|
||||
|
||||
# 子节点同级排列(横向排布,更直观地表现同层)
|
||||
if len(child_ids) > 1:
|
||||
with dot.subgraph(name=f"rank_{current_id}") as s:
|
||||
s.attr(rank='same')
|
||||
for cid in child_ids:
|
||||
s.node(cid)
|
||||
|
||||
return current_id
|
||||
|
||||
def generate_visualization(node: Dict, file_path: str):
|
||||
"""
|
||||
使用Graphviz将Pytree字典可视化,并保存到指定路径。
|
||||
"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
logging.critical("错误:未安装graphviz库。请运行: pip install graphviz")
|
||||
return False
|
||||
|
||||
fontname = _pick_zh_font()
|
||||
|
||||
dot = Digraph('Pytree', comment='Drone Mission Plan')
|
||||
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
|
||||
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
|
||||
dot.attr('edge', fontname=fontname)
|
||||
|
||||
_add_nodes_and_edges(node, dot)
|
||||
|
||||
try:
|
||||
# 确保输出目录存在
|
||||
base_path, ext = os.path.splitext(file_path)
|
||||
render_path = base_path if ext.lower() == '.png' else file_path
|
||||
|
||||
out_dir = os.path.dirname(render_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存为 .png 文件,并自动删除源码 .gv 文件
|
||||
dot.render(render_path, format='png', cleanup=True, view=False)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"❌ 生成可视化图形失败: {e}")
|
||||
return False
|
||||
|
||||
# 保留旧的函数以兼容(如果有其他脚本引用)
|
||||
def _visualize_pytree(node: Dict, file_path: str):
|
||||
return generate_visualization(node, file_path)
|
||||
|
||||
def parse_log_file(log_file_path: str) -> Dict[str, List[Dict]]:
|
||||
"""
|
||||
解析日志文件,提取原始指令和完整API响应JSON
|
||||
返回: {原始指令: [JSON响应列表]}
|
||||
"""
|
||||
with open(log_file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 按分隔符分割条目
|
||||
entries = re.split(r'={80,}', content)
|
||||
|
||||
results = defaultdict(list)
|
||||
|
||||
for entry in entries:
|
||||
if not entry.strip():
|
||||
continue
|
||||
|
||||
# 提取原始指令
|
||||
instruction_match = re.search(r'原始指令:\s*(.+)', entry)
|
||||
if not instruction_match:
|
||||
continue
|
||||
|
||||
original_instruction = instruction_match.group(1).strip()
|
||||
|
||||
# 提取完整API响应JSON
|
||||
json_match = re.search(r'完整API响应:\s*\n(\{.*\})', entry, re.DOTALL)
|
||||
if not json_match:
|
||||
logging.warning(f"未找到指令 '{original_instruction}' 的JSON响应")
|
||||
continue
|
||||
|
||||
json_str = json_match.group(1).strip()
|
||||
|
||||
try:
|
||||
json_obj = json.loads(json_str)
|
||||
results[original_instruction].append(json_obj)
|
||||
logging.info(f"成功提取指令 '{original_instruction}' 的JSON响应")
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"解析指令 '{original_instruction}' 的JSON失败: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def process_and_visualize(log_file_path: str, output_dir: str):
|
||||
"""
|
||||
处理日志文件并批量可视化
|
||||
"""
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 解析日志文件
|
||||
logging.info(f"开始解析日志文件: {log_file_path}")
|
||||
instruction_responses = parse_log_file(log_file_path)
|
||||
|
||||
logging.info(f"共找到 {len(instruction_responses)} 个不同的原始指令")
|
||||
|
||||
# 处理每个指令的所有响应
|
||||
for instruction, responses in instruction_responses.items():
|
||||
logging.info(f"\n处理指令: {instruction} (共 {len(responses)} 个响应)")
|
||||
|
||||
# 创建指令目录(使用安全的文件名)
|
||||
safe_instruction_name = sanitize_filename(instruction)
|
||||
instruction_dir = os.path.join(output_dir, safe_instruction_name)
|
||||
os.makedirs(instruction_dir, exist_ok=True)
|
||||
|
||||
# 处理每个响应
|
||||
for idx, response in enumerate(responses, 1):
|
||||
try:
|
||||
# 提取root节点
|
||||
root_node = response.get('root')
|
||||
if not root_node:
|
||||
logging.warning(f"响应 #{idx} 没有root节点,跳过")
|
||||
continue
|
||||
|
||||
# 生成文件名
|
||||
json_filename = f"response_{idx}.json"
|
||||
png_filename = f"response_{idx}.png"
|
||||
|
||||
json_path = os.path.join(instruction_dir, json_filename)
|
||||
png_path = os.path.join(instruction_dir, png_filename)
|
||||
|
||||
# 保存JSON文件
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logging.info(f" 保存JSON: {json_filename}")
|
||||
|
||||
# 生成可视化
|
||||
generate_visualization(root_node, png_path)
|
||||
logging.info(f" 生成可视化: {png_filename}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理响应 #{idx} 时出错: {e}")
|
||||
continue
|
||||
|
||||
logging.info(f"\n✅ 所有处理完成!结果保存在: {output_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="批量可视化API测试日志")
|
||||
parser.add_argument("--log", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_test_log.txt"), help="日志文件路径")
|
||||
parser.add_argument("--out", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "validation"), help="输出目录")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
log_file = args.log
|
||||
output_directory = args.out
|
||||
|
||||
print(f"日志文件: {log_file}")
|
||||
print(f"输出目录: {output_directory}")
|
||||
|
||||
if os.path.exists(log_file):
|
||||
process_and_visualize(log_file, output_directory)
|
||||
else:
|
||||
print(f"错误: 找不到日志文件 {log_file}")
|
||||
118
tools/test_validate/run_tests.py
Executable file
118
tools/test_validate/run_tests.py
Executable file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import subprocess
|
||||
|
||||
# Add current directory to path so modules can be imported
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from modules.interactive_test import run_interactive_test
|
||||
from modules.batch_runner import run_batch_test
|
||||
|
||||
def clear_screen():
|
||||
os.system('cls' if os.name == 'nt' else 'clear')
|
||||
|
||||
def print_header():
|
||||
clear_screen()
|
||||
print("=" * 60)
|
||||
print(" Drone Planning 系统测试工具箱 (Unified)")
|
||||
print("=" * 60)
|
||||
print(f"当前工作目录: {os.getcwd()}")
|
||||
print("-" * 60)
|
||||
|
||||
def run_legacy_module(module_name, args=None):
|
||||
"""运行旧的独立 Python 脚本 (用于 LLM 测试和上传工具)"""
|
||||
script_path = os.path.join(os.path.dirname(__file__), "modules", module_name)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
print(f"❌ 错误: 找不到脚本 {script_path}")
|
||||
input("\n按回车键继续...")
|
||||
return
|
||||
|
||||
cmd = [sys.executable, script_path]
|
||||
if args:
|
||||
cmd.extend(args)
|
||||
|
||||
print(f"\n🚀 正在启动: {module_name} ...\n")
|
||||
try:
|
||||
# 保持当前环境变量
|
||||
env = os.environ.copy()
|
||||
# 确保 PYTHONPATH 包含当前目录
|
||||
env["PYTHONPATH"] = os.path.dirname(__file__) + os.pathsep + env.get("PYTHONPATH", "")
|
||||
subprocess.run(cmd, env=env, check=False)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ 操作已取消")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 运行出错: {e}")
|
||||
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
def menu_drone_upload():
|
||||
print("\n[3] 上传任务到无人机")
|
||||
print("说明: 将生成的任务文件上传到无人机 (Ground Station Client)。")
|
||||
|
||||
ip = input("\n请输入无人机 IP 地址 (默认: 127.0.0.1): ").strip() or "127.0.0.1"
|
||||
file_path = input("请输入任务文件路径 (.json): ").strip()
|
||||
|
||||
if not file_path:
|
||||
print("❌ 必须提供文件路径!")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
print(f"❌ 文件不存在: {file_path}")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
run_legacy_module("drone_uploader.py", [ip, file_path])
|
||||
|
||||
def main():
|
||||
# 切换到脚本所在目录
|
||||
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
while True:
|
||||
print_header()
|
||||
print("请选择测试模式:")
|
||||
print("1. 交互式单次测试 (Interactive Single Test)")
|
||||
print(" - 手动输入指令,即时查看结果和可视化")
|
||||
print("2. 批量/场景测试 (Batch/Scenario Test)")
|
||||
print(" - 读取指令文件,支持多轮压力测试,生成统计报告")
|
||||
print("3. 上传任务到无人机 (Drone Uploader)")
|
||||
print("4. LLM 服务连通性测试 (LLM Tester)")
|
||||
print("0. 退出")
|
||||
print("-" * 60)
|
||||
|
||||
choice = input("请输入选项 [0-4]: ").strip()
|
||||
|
||||
if choice == '1':
|
||||
try:
|
||||
run_interactive_test()
|
||||
except Exception as e:
|
||||
print(f"❌ 运行出错: {e}")
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
elif choice == '2':
|
||||
try:
|
||||
run_batch_test()
|
||||
except Exception as e:
|
||||
print(f"❌ 运行出错: {e}")
|
||||
input("\n按回车键返回菜单...")
|
||||
|
||||
elif choice == '3':
|
||||
menu_drone_upload()
|
||||
|
||||
elif choice == '4':
|
||||
run_legacy_module("llm_tester.py")
|
||||
|
||||
elif choice == '0':
|
||||
print("\n👋 再见!")
|
||||
break
|
||||
else:
|
||||
print("\n❌ 无效选项,请重试")
|
||||
time.sleep(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user