Files
DronePlanning/tools/test_validate/test_validity.py

437 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import requests
import json
import csv
import time
from datetime import datetime
import os
import re
# --- Configuration -----------------------------------------------------------
# Planner service under test; requests are sent to BASE_URL + ENDPOINT.
BASE_URL: str = "http://127.0.0.1:8000"
ENDPOINT: str = "/generate_plan"

# Input/output files (plain relative paths, resolved against the working directory).
INSTRUCTIONS_FILE: str = "instructions.txt"
LOG_FILE: str = "api_test_log.txt"
RESULTS_CSV: str = "test_results.csv"
SUMMARY_CSV: str = "test_summary.csv"

# Test parameters: runs per instruction, attempts per run, seconds slept between attempts.
TESTS_PER_INSTRUCTION: int = 1
MAX_RETRIES: int = 3
RETRY_DELAY: int = 2

# Debug mode: when True, debug_print() writes diagnostic lines to stdout.
DEBUG: bool = True
def debug_print(message):
    """Print *message* to stdout with a debug prefix, but only when DEBUG is truthy."""
    if not DEBUG:
        return
    print(f"🐛 DEBUG: {message}")
def check_safety_monitoring(node):
    """Return True if the plan tree contains any safety-related node.

    This is a deliberately lenient check. It passes when the tree has either:
      * a ``condition`` node whose name contains "battery", or
      * an ``action`` node whose name contains "emergency", "safe" or "land".
    Name matching is case-insensitive.

    The tree is walked with an explicit stack, so very deep API responses
    cannot raise RecursionError. Entries that are not dicts, and ``children``
    values that are not lists (e.g. JSON null), are skipped instead of raising.

    Args:
        node: Root node of the plan tree (a dict decoded from the API's JSON).

    Returns:
        bool: True as soon as a battery condition or an emergency-style action
        is found, otherwise False.
    """
    emergency_keywords = ('emergency', 'safe', 'land')
    pending = [node]
    while pending:
        current = pending.pop()
        if not isinstance(current, dict):
            continue
        node_type = current.get('type')
        name = str(current.get('name', '')).lower()
        # Only one of the two indicators is required, so stop at the first hit.
        if node_type == 'condition' and 'battery' in name:
            return True
        if node_type == 'action' and any(k in name for k in emergency_keywords):
            return True
        children = current.get('children')
        if isinstance(children, list):
            pending.extend(children)
    return False
def check_leaf_nodes(node, depth=0, max_depth=50):
    """Check that a plan tree has a well-formed leaf/branch structure.

    Rules, applied recursively:
      * ``action`` and ``condition`` nodes are leaves and must not have children;
      * ``Sequence``, ``Selector`` and ``Parallel`` nodes must have at least one child;
      * every child must itself satisfy these rules.

    Malformed input is reported as invalid instead of raising:
      * a node that is not a dict returns False;
      * a non-empty ``children`` value that is not a list returns False;
      * a null/empty ``children`` is treated as "no children".

    Args:
        node: Plan-tree node (dict decoded from the API response).
        depth: Current recursion depth; callers normally leave the default.
        max_depth: Subtrees deeper than this are not inspected and count as
            valid, so the check never fails on depth alone.

    Returns:
        bool: True if the inspected part of the tree is structurally valid.
    """
    if depth > max_depth:
        return True  # Do not fail just because of the depth limit.
    if not isinstance(node, dict):
        return False
    node_type = node.get('type')
    children = node.get('children')
    # Leaf nodes must not carry children.
    if node_type in ('action', 'condition'):
        return not children
    # Control-flow nodes must have something to control.
    if node_type in ('Sequence', 'Selector', 'Parallel') and not children:
        return False
    if not children:
        return True
    if not isinstance(children, list):
        return False  # children must be a JSON array
    return all(check_leaf_nodes(child, depth + 1, max_depth) for child in children)
# Action / condition names the planner is expected to emit. Anything else is
# reported in invalid_actions / invalid_conditions but does not fail the run.
_KNOWN_ACTIONS = frozenset({
    'deliver_payload', 'emergency_return', 'fly_to_waypoint',
    'land', 'loiter', 'object_detect', 'preflight_checks',
    'search_pattern', 'strike_target', 'battle_damage_assessment', 'takeoff',
})
_KNOWN_CONDITIONS = frozenset({
    'battery_above', 'at_waypoint', 'object_detected',
    'target_destroyed', 'time_elapsed', 'gps_status',
})
# Basic checks that must all pass for a run to count as successful; the
# remaining (advanced) checks are only recorded as warnings.
_REQUIRED_CHECKS = ("is_dict", "has_root", "root_has_children",
                    "has_plan_id", "has_visualization_url")


def _collect_unknown_nodes(root_node):
    """Return (unknown_action_names, unknown_condition_names) found in a plan tree.

    Walks the tree depth-first in pre-order (the same order as a recursive
    walk) using an explicit stack, so deep trees cannot raise RecursionError.
    Non-dict entries and non-list ``children`` values are skipped.
    """
    unknown_actions = []
    unknown_conditions = []
    stack = [root_node]
    while stack:
        node = stack.pop()
        if not isinstance(node, dict):
            continue
        node_type = node.get('type')
        name = node.get('name', '')
        # The isinstance guard keeps unhashable names (e.g. lists) from
        # raising TypeError on frozenset lookup; they are reported as unknown.
        if node_type == 'action':
            if not (isinstance(name, str) and name in _KNOWN_ACTIONS):
                unknown_actions.append(name)
        elif node_type == 'condition':
            if not (isinstance(name, str) and name in _KNOWN_CONDITIONS):
                unknown_conditions.append(name)
        children = node.get('children')
        if isinstance(children, list):
            # Reverse so the leftmost child is popped (visited) first.
            stack.extend(reversed(children))
    return unknown_actions, unknown_conditions


def _validate_plan(data):
    """Run the structural checks on a decoded /generate_plan response.

    Returns:
        tuple: (validation_checks, invalid_actions, invalid_conditions), where
        validation_checks maps check name -> bool. A non-dict response, or a
        missing/non-dict ``root``, fails the corresponding checks instead of
        raising.
    """
    is_dict = isinstance(data, dict)
    root = data.get('root') if is_dict else None
    root_node = root if isinstance(root, dict) else {}
    validation_checks = {
        "is_dict": is_dict,
        "has_root": isinstance(root, dict),
        "root_has_children": bool(root_node.get('children')),
        "has_plan_id": is_dict and "plan_id" in data,
        "has_visualization_url": is_dict and "visualization_url" in data,
        # Advanced checks: recorded, but not part of _REQUIRED_CHECKS.
        "leaf_nodes_valid": check_leaf_nodes(root_node),
        "has_safety": check_safety_monitoring(root_node),
    }
    invalid_actions, invalid_conditions = _collect_unknown_nodes(root_node)
    return validation_checks, invalid_actions, invalid_conditions


def _failure_result(error_msg, attempts, http_status):
    """Build the result dict for a run that never produced a usable response."""
    return {
        "success": False,
        "data": None,
        "validation_checks": {},
        "response_time": 0,
        "invalid_actions": [],
        "invalid_conditions": [],
        "error": error_msg,
        "attempts": attempts,
        "http_status": http_status,
    }


def send_api_request(prompt, instruction_idx, run_number):
    """POST *prompt* to the plan endpoint and validate the returned plan.

    Retry behavior:
      * Any ``requests`` RequestException is retried, sleeping RETRY_DELAY
        seconds between attempts. This includes non-2xx statuses raised by
        ``raise_for_status``.
      * At most MAX_RETRIES attempts are made in total, and always at least one.
      * Any other exception ends the run immediately without retrying.

    Args:
        prompt: User instruction, sent as ``{"user_prompt": prompt}``.
        instruction_idx: 1-based instruction number (used in debug output only).
        run_number: 1-based run number (used in debug output only).

    Returns:
        dict with these keys:
          * ``success``: True only if every check in _REQUIRED_CHECKS passes.
          * ``data``: the decoded response, or None on failure.
          * ``validation_checks``: check name -> bool.
          * ``response_time``: seconds; 0 on failure.
          * ``invalid_actions`` and ``invalid_conditions``: unknown node names.
          * ``error``: None unless the request itself failed.
          * ``attempts`` and ``http_status``.
    """
    url = BASE_URL + ENDPOINT
    payload = {"user_prompt": prompt}
    headers = {"Content-Type": "application/json"}
    # Guarantee at least one attempt so the function always returns a dict.
    max_attempts = max(1, MAX_RETRIES)
    for attempt in range(1, max_attempts + 1):
        try:
            debug_print(f"指令 {instruction_idx}-{run_number} 尝试 {attempt}")
            # perf_counter is monotonic, unlike time.time(), so clock
            # adjustments cannot corrupt the measured duration.
            start_time = time.perf_counter()
            response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=60)
            response_time = time.perf_counter() - start_time
            # Check the HTTP status first; non-2xx raises HTTPError (a RequestException).
            response.raise_for_status()
            try:
                data = response.json()
            except json.JSONDecodeError as e:
                debug_print(f"JSON解析失败: {e}, 响应文本: {response.text[:200]}")
                raise
            validation_checks, invalid_actions, invalid_conditions = _validate_plan(data)
            success = all(validation_checks[k] for k in _REQUIRED_CHECKS)
            debug_print(f"验证结果: 成功={success}, 全部验证通过={all(validation_checks.values())}")
            return {
                "success": success,
                "data": data,
                "validation_checks": validation_checks,
                "response_time": response_time,
                "invalid_actions": invalid_actions,
                "invalid_conditions": invalid_conditions,
                "error": None,
                "attempts": attempt,
                "http_status": response.status_code,
            }
        except requests.exceptions.RequestException as e:
            error_msg = f"请求失败: {e}"
            debug_print(f"请求异常: {error_msg}")
            if attempt < max_attempts:
                time.sleep(RETRY_DELAY)
                continue
            http_status = getattr(getattr(e, 'response', None), 'status_code', None)
            return _failure_result(error_msg, attempt, http_status)
        except Exception as e:
            error_msg = f"未知错误: {e}"
            debug_print(f"未知错误: {error_msg}")
            return _failure_result(error_msg, attempt, None)
def read_instructions(filename):
    """Read test instructions from a text file, one per line.

    Surrounding whitespace is stripped from each line. Blank lines and lines
    starting with ``#`` are skipped.

    The file is decoded as UTF-8. A leading byte-order mark (written by some
    Windows editors) is dropped, so it neither ends up glued to the first
    instruction nor hides a leading ``#`` comment.

    Args:
        filename: Path of the instructions file.

    Returns:
        list[str]: The instructions in file order. If the file cannot be
        opened or decoded, the error is printed and an empty list is returned.
    """
    try:
        with open(filename, 'r', encoding='utf-8-sig') as file:
            stripped = (line.strip() for line in file)
            return [line for line in stripped if line and not line.startswith('#')]
    except (OSError, UnicodeError) as e:
        print(f"❌ 读取指令文件时出错: {e}")
        return []
# Response keys that may carry the server's rewritten prompt, in priority order.
_ORGANIZED_PROMPT_KEYS = ("organized_prompt", "processed_prompt", "final_prompt",
                          "enhanced_prompt", "user_prompt_enhanced")


def _write_validation_details(f, result):
    """Write per-check ✅/❌ marks and unknown node names, if *result* has any."""
    checks = result.get('validation_checks') or {}
    if checks:
        f.write("\n验证结果:\n")
        for check_name, check_result in checks.items():
            f.write(f" {check_name}: {'✅' if check_result else '❌'}\n")
    if result.get('invalid_actions'):
        f.write(f"⚠️ 无效动作节点: {result['invalid_actions']}\n")
    if result.get('invalid_conditions'):
        f.write(f"⚠️ 无效条件节点: {result['invalid_conditions']}\n")


def write_log_entry(log_file, instruction_idx, run_number, prompt, result):
    """Append a detailed, human-readable record of one test run to *log_file*.

    Always written: a header with timestamp, HTTP status, prompt, attempts,
    response time and the overall result.

    For a successful run with data, also written:
      * the organized prompt, if the response contains one;
      * the per-check validation marks and unknown node names;
      * the full JSON response.

    For a failed run, also written:
      * the error message;
      * any validation marks (a run can fail validation with ``error`` None);
      * any partial response data.

    Args:
        log_file: Path of the log file; it is opened in append mode as UTF-8.
        instruction_idx: 1-based instruction number.
        run_number: 1-based run number for this instruction.
        prompt: The original instruction text.
        result: Result dict as returned by send_api_request.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    data = result.get('data')
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"\n{'=' * 80}\n")
        f.write(f"指令 #{instruction_idx} - 运行 #{run_number} - {timestamp}\n")
        f.write(f"HTTP状态: {result.get('http_status', 'N/A')}\n")
        f.write(f"原始指令: {prompt}\n")
        f.write(f"尝试次数: {result['attempts']}\n")
        f.write(f"响应时间: {result['response_time']:.2f}\n")
        f.write(f"结果: {'✅ 成功' if result['success'] else '❌ 失败'}\n")
        if result['success'] and data:
            # Record the server-side organized prompt, if the response carries one.
            organized_prompt = None
            if isinstance(data, dict):
                organized_prompt = next(
                    (data[key] for key in _ORGANIZED_PROMPT_KEYS if data.get(key)), None)
            if organized_prompt:
                f.write("\n📝 组织后的Prompt:\n")
                f.write(f"{organized_prompt}\n")
            else:
                f.write("\n📝 组织后的Prompt: (未在响应中返回)\n")
            _write_validation_details(f, result)
            f.write("\n完整API响应:\n")
            try:
                f.write(json.dumps(data, indent=2, ensure_ascii=False))
                f.write("\n")
            except (TypeError, ValueError) as e:
                f.write(f"⚠️ 无法序列化响应数据: {e}\n")
                f.write(f"原始数据: {str(data)}\n")
        else:
            f.write(f"错误信息: {result['error']}\n")
            # A run that reached the server but failed basic validation has
            # error=None and a populated validation_checks dict; log the checks
            # so the failing one is visible.
            _write_validation_details(f, result)
            # Record any response data even though the run failed.
            if data:
                f.write("\n部分响应数据:\n")
                try:
                    f.write(json.dumps(data, indent=2, ensure_ascii=False))
                    f.write("\n")
                except (TypeError, ValueError):
                    f.write(f"原始数据: {str(data)}\n")
def generate_summary_report(instructions, results_summary, output_path=None,
                            runs_per_instruction=None):
    """Write the per-instruction statistics CSV.

    Args:
        instructions: Instruction texts, in the order they were tested.
        results_summary: One stats dict per instruction, paired with
            *instructions* by position. Each dict has keys ``success_count``,
            ``total_response_time``, ``min_response_time`` and
            ``max_response_time``; the times cover successful runs only.
        output_path: CSV path to write; defaults to SUMMARY_CSV.
        runs_per_instruction: Runs attempted per instruction; defaults to
            TESTS_PER_INSTRUCTION.

    Special cases:
      * The response-time columns are "N/A" for an instruction with no
        successful run.
      * The success rate is "N/A" when no runs were configured, which avoids
        a division by zero.
      * Errors while writing are printed, not raised.
    """
    if output_path is None:
        output_path = SUMMARY_CSV
    if runs_per_instruction is None:
        runs_per_instruction = TESTS_PER_INSTRUCTION
    fieldnames = ['instruction_index', 'instruction', 'total_runs', 'successful_runs',
                  'success_rate', 'avg_response_time', 'min_response_time',
                  'max_response_time', 'total_response_time']
    try:
        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for index, (instruction, summary) in enumerate(zip(instructions, results_summary), 1):
                success_count = summary['success_count']
                total_time = summary['total_response_time']
                if success_count > 0:
                    avg_time = f"{total_time / success_count:.2f}s"
                    min_time = f"{summary['min_response_time']:.2f}s"
                    max_time = f"{summary['max_response_time']:.2f}s"
                else:
                    avg_time = min_time = max_time = "N/A"
                if runs_per_instruction > 0:
                    success_rate = f"{success_count / runs_per_instruction * 100:.2f}%"
                else:
                    success_rate = "N/A"
                writer.writerow({
                    'instruction_index': index,
                    'instruction': instruction,
                    'total_runs': runs_per_instruction,
                    'successful_runs': success_count,
                    'success_rate': success_rate,
                    'avg_response_time': avg_time,
                    'min_response_time': min_time,
                    'max_response_time': max_time,
                    'total_response_time': f"{total_time:.2f}s",
                })
        print(f"📊 统计摘要已保存至: {output_path}")
    except Exception as e:
        # Reporting boundary: a failed summary must not abort the test run.
        print(f"❌ 保存统计摘要时出错: {e}")
def _write_results_csv(detailed_results):
    """Write one CSV row per run to RESULTS_CSV; errors are printed, not raised."""
    fieldnames = ['instruction_index', 'instruction', 'run_number', 'success',
                  'attempts', 'response_time', 'plan_id', 'error', 'timestamp']
    try:
        with open(RESULTS_CSV, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(detailed_results)
        print(f"\n📊 详细结果已保存至: {RESULTS_CSV}")
    except Exception as e:
        # Reporting boundary: still print the statistics even if the CSV fails.
        print(f"❌ 保存详细结果时出错: {e}")


def _print_final_statistics(instructions, results_summary):
    """Print overall and per-instruction success rates and average response times."""
    total_tests = len(instructions) * TESTS_PER_INSTRUCTION
    total_successful = sum(summary['success_count'] for summary in results_summary)
    print(f"\n{'=' * 60}")
    print("📈 最终测试统计")
    print(f"{'=' * 60}")
    print(f"总测试次数: {total_tests}")
    print(f"成功次数: {total_successful}")
    print(f"失败次数: {total_tests - total_successful}")
    if total_tests > 0:
        print(f"总成功率: {(total_successful / total_tests * 100):.2f}%")
    else:
        print("总成功率: N/A")
    print("\n📋 每个指令的统计:")
    for i, summary in enumerate(results_summary, 1):
        count = summary['success_count']
        # Guard like the overall rate above: zero configured runs has no rate.
        if TESTS_PER_INSTRUCTION > 0:
            rate_text = f"{count / TESTS_PER_INSTRUCTION * 100:.1f}%"
        else:
            rate_text = "N/A"
        avg_time = summary['total_response_time'] / count if count > 0 else 0
        print(f" 指令 {i}: {rate_text} 成功 ({count}/{TESTS_PER_INSTRUCTION}), "
              f"平均时间: {avg_time:.2f}s")


def main():
    """Run every instruction TESTS_PER_INSTRUCTION times against the plan API.

    Per run:
      * call send_api_request;
      * append a record to LOG_FILE;
      * update the per-instruction statistics.

    Afterwards:
      * write the per-run rows to RESULTS_CSV;
      * write the per-instruction summary via generate_summary_report;
      * print the overall statistics.
    """
    print("🚀 开始批量API测试")
    print(f"每个指令测试 {TESTS_PER_INSTRUCTION}")
    instructions = read_instructions(INSTRUCTIONS_FILE)
    if not instructions:
        # read_instructions only reports I/O errors; also explain why an
        # empty (or comment-only) file ends the run.
        print(f"⚠️ 未找到可执行的指令: {INSTRUCTIONS_FILE}")
        return
    print(f"找到 {len(instructions)} 条指令")

    # Per-instruction aggregates; time fields only include successful runs.
    results_summary = [{
        'success_count': 0,
        'total_response_time': 0,
        'min_response_time': float('inf'),
        'max_response_time': 0,
        'http_statuses': [],
    } for _ in instructions]
    detailed_results = []

    for instruction_idx, prompt in enumerate(instructions, 1):
        print(f"\n{'=' * 60}")
        print(f"📋 测试指令 {instruction_idx}/{len(instructions)}")
        print(f"指令: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
        print(f"{'=' * 60}")
        stats = results_summary[instruction_idx - 1]
        for run_number in range(1, TESTS_PER_INSTRUCTION + 1):
            print(f" 运行 {run_number}/{TESTS_PER_INSTRUCTION}...", end=" ", flush=True)
            result = send_api_request(prompt, instruction_idx, run_number)
            write_log_entry(LOG_FILE, instruction_idx, run_number, prompt, result)

            data = result.get("data")
            plan_id = ""
            if result.get("success") and isinstance(data, dict):
                plan_id = data.get("plan_id", "")
            detailed_results.append({
                "instruction_index": instruction_idx,
                "instruction": prompt,
                "run_number": run_number,
                "success": result["success"],
                "attempts": result["attempts"],
                "response_time": result["response_time"],
                "plan_id": plan_id,
                "error": result["error"] or "",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })

            if result["success"]:
                response_time = result['response_time']
                stats['success_count'] += 1
                stats['total_response_time'] += response_time
                stats['min_response_time'] = min(stats['min_response_time'], response_time)
                stats['max_response_time'] = max(stats['max_response_time'], response_time)
                print(f"✅ 成功 ({response_time:.1f}s)")
            else:
                print(f"❌ 失败 (HTTP: {result.get('http_status', 'N/A')})")
            stats['http_statuses'].append(result.get('http_status'))
            time.sleep(1)  # Pause between runs to avoid overloading the server.

    _write_results_csv(detailed_results)
    generate_summary_report(instructions, results_summary)
    _print_final_statistics(instructions, results_summary)

    print("\n📁 输出文件:")
    print(f"详细日志: {LOG_FILE}")
    print(f"详细结果: {RESULTS_CSV}")
    print(f"统计摘要: {SUMMARY_CSV}")
if __name__ == '__main__':
    # Entry point: run the batch test only when executed directly, not on import.
    main()