Files
DronePlanning/tools/test_validate/modules/batch_runner.py

130 lines
4.6 KiB
Python

import os
import json
import time
import csv
import shutil
from datetime import datetime
from .api_client import APIClient
from .visualizer import generate_visualization, sanitize_filename
def _select_instruction_file(instr_dir):
    """Interactively pick one ``.txt`` instruction file from *instr_dir*.

    Returns:
        The full path of the chosen file, or ``None`` (after printing an
        error) when the directory is missing, contains no ``.txt`` files,
        or the user's input is not a valid menu number.
    """
    try:
        # os.listdir returns entries in arbitrary order; sort so the menu
        # numbering is stable from one run to the next.
        files = sorted(f for f in os.listdir(instr_dir) if f.endswith('.txt'))
    except FileNotFoundError:
        # A missing directory is reported the same way as an empty one.
        files = []
    if not files:
        print("❌ 未在 instructions 目录下找到 .txt 文件")
        return None

    print("\n请选择测试指令文件:")
    for number, name in enumerate(files, 1):
        print(f"{number}. {name}")

    try:
        idx = int(input("请输入序号: ").strip()) - 1
    except ValueError:
        print("❌ 输入无效")
        return None
    if idx < 0 or idx >= len(files):
        print("❌ 无效序号")
        return None
    return os.path.join(instr_dir, files[idx])


def _prompt_iterations():
    """Ask how many times each instruction should be run.

    Returns:
        The number entered by the user; 1 when the input is empty, not an
        integer, or less than 1. A non-positive count would otherwise run
        nothing and then divide by zero when computing the success rate.
    """
    try:
        iterations = int(input("请输入每条指令的测试次数 (默认1): ").strip() or "1")
    except ValueError:
        return 1
    return iterations if iterations >= 1 else 1


def _load_instructions(path):
    """Read instructions from *path*, one per line.

    Blank lines and lines whose first non-whitespace character is ``#`` are
    skipped. The file is decoded as ``utf-8-sig`` so that an optional BOM
    does not end up glued to the first instruction.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        return [text for text in stripped_lines if text and not text.startswith('#')]


def _write_reports(output_dir, detailed_results, summary_stats):
    """Write ``test_details.csv`` and ``test_summary.csv`` into *output_dir*."""
    # Per-run detail report.
    with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
        writer.writeheader()
        writer.writerows(detailed_results)

    # Per-instruction summary report.
    with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
        for prompt, stats in summary_stats.items():
            writer.writerow([
                prompt,
                stats["total_runs"],
                stats["success_runs"],
                stats["success_rate"],
                stats["avg_latency"],
            ])


def run_batch_test():
    """Run every instruction of a user-selected file against the API in batch.

    Interactive workflow:

    1. List the ``.txt`` files in the ``instructions`` directory (located one
       level above this module's directory) and let the user choose one.
    2. Ask how many times each instruction should be sent.
    3. Create ``validation/<YYYYmmddHHMMSS>/`` and copy the chosen file there
       as ``instructions_backup.txt``.
    4. Send each instruction through ``APIClient.send_request``. For every
       successful run ``k`` the returned ``data`` is saved as ``<k>.json`` in
       a sub-directory named by ``sanitize_filename(prompt)``; when ``data``
       contains a ``root`` key it is also passed to ``generate_visualization``
       together with a ``<k>.png`` path.
    5. Write ``test_details.csv`` (one row per run) and ``test_summary.csv``
       (one row per instruction).

    The function returns ``None``; it returns early, after printing an error,
    when no instruction file can be selected.
    """
    # 1. Choose the instruction file.
    base_dir = os.path.dirname(os.path.dirname(__file__))
    instr_dir = os.path.join(base_dir, "instructions")
    selected_file = _select_instruction_file(instr_dir)
    if selected_file is None:
        return

    # 2. Configure the run.
    iterations = _prompt_iterations()

    # 3. Prepare the output directory.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = os.path.join(base_dir, "validation", timestamp)
    os.makedirs(output_dir, exist_ok=True)

    # Keep a copy of the instruction file next to the results.
    shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))

    instructions = _load_instructions(selected_file)

    print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
    print(f"📂 输出目录: {output_dir}\n")

    client = APIClient()
    detailed_results = []
    # NOTE(review): keyed by the prompt text, so a prompt that appears twice
    # in the file keeps only the statistics of its last occurrence (and its
    # artifacts share one sub-directory) — confirm whether that is intended.
    summary_stats = {}

    for i, prompt in enumerate(instructions, 1):
        print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")

        safe_name = sanitize_filename(prompt)
        instr_out_dir = os.path.join(output_dir, safe_name)
        os.makedirs(instr_out_dir, exist_ok=True)

        success_count = 0
        total_latency = 0

        for k in range(1, iterations + 1):
            print(f" - 第 {k} 次...", end="", flush=True)
            result = client.send_request(prompt)

            # Record the outcome of this run.
            detailed_results.append({
                "instruction": prompt,
                "run_id": k,
                "success": result['success'],
                "latency": result['latency'],
                "error": result['error'] or "",
            })

            # Persist artifacts for successful runs only.
            if result['success']:
                print(f" ✅ ({result['latency']:.2f}s)")
                success_count += 1
                total_latency += result['latency']

                # Save the raw response data as JSON.
                with open(os.path.join(instr_out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
                    json.dump(result['data'], f, indent=2, ensure_ascii=False)

                # Render an image when the response carries a 'root' entry.
                if result['data'] and 'root' in result['data']:
                    generate_visualization(result['data']['root'], os.path.join(instr_out_dir, f"{k}.png"))
            else:
                print(f"{result['error']}")

            time.sleep(0.5)  # Throttle so requests are not sent too quickly.

        # Aggregate statistics for this instruction; the average latency only
        # counts successful runs. iterations >= 1 is guaranteed above, so the
        # success-rate division is safe.
        avg_lat = total_latency / success_count if success_count > 0 else 0
        summary_stats[prompt] = {
            "total_runs": iterations,
            "success_runs": success_count,
            "success_rate": f"{(success_count/iterations)*100:.1f}%",
            "avg_latency": f"{avg_lat:.2f}s",
        }

    # 4. Generate the reports.
    _write_reports(output_dir, detailed_results, summary_stats)

    print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")