Files
DronePlanning/tools/test_validate/test_validity.py
2025-12-03 17:13:47 +08:00

508 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import requests
import json
import csv
import time
from datetime import datetime
import os
import re
# --- Configuration ---
BASE_URL = "http://127.0.0.1:8000"
ENDPOINT = "/generate_plan"
INSTRUCTIONS_FILE = "instructions.txt"
RESULTS_CSV = "test_results.csv"
SUMMARY_CSV = "test_summary.csv"
LOG_FILE = "api_test_log.txt"
# Test parameters
TESTS_PER_INSTRUCTION = 1
# Upper bound on request attempts per run (used by send_api_request's retry loop)
MAX_RETRIES = 3
# Pause passed to time.sleep() between retry attempts
RETRY_DELAY = 2
# Debug mode: when true, debug_print() emits its messages
DEBUG = True
def debug_print(message):
    """Print *message* with a debug prefix; silent when DEBUG is off."""
    if not DEBUG:
        return
    print(f"🐛 DEBUG: {message}")
def check_safety_monitoring(node):
    """Return True when the tree contains any safety-related node.

    The requirement is deliberately relaxed: a single battery condition OR a
    single emergency-style action anywhere in the tree is enough.

    Args:
        node: Root of the behavior tree (a dict with optional 'type', 'name'
            and 'children' keys). Anything that is not a dict is ignored.

    Returns:
        bool: True if a 'condition' node whose name contains "battery", or an
        'action' node whose name contains "emergency", "safe" or "land"
        (case-insensitive), is found; False otherwise.
    """
    emergency_keywords = ('emergency', 'safe', 'land')
    # Explicit stack instead of recursion: no recursion-limit issues on deep
    # trees, and we can stop as soon as one safety node is seen.
    pending = [node]
    while pending:
        current = pending.pop()
        if not isinstance(current, dict):
            # Malformed entry (e.g. null in the JSON) - nothing to inspect.
            continue
        node_type = current.get('type')
        label = str(current.get('name', '')).lower()
        if node_type == 'condition' and 'battery' in label:
            return True
        if node_type == 'action' and any(word in label for word in emergency_keywords):
            return True
        children = current.get('children')
        # 'children' may be present but null in the API response.
        if isinstance(children, (list, tuple)):
            pending.extend(children)
    return False
def check_leaf_nodes(node, depth=0, max_depth=50):
    """Check that leaf and control-flow nodes are structurally well-formed.

    Rules:
      * 'action' and 'condition' nodes must not have children.
      * 'Sequence', 'Selector' and 'Parallel' nodes must have children.
      * Every child is checked recursively with the same rules.

    Args:
        node: The node to check (a dict). A non-dict node is treated as
            invalid structure.
        depth: Current recursion depth (callers normally leave the default).
        max_depth: Depth beyond which checking stops.

    Returns:
        bool: True if the subtree satisfies the rules. Subtrees deeper than
        ``max_depth`` are accepted rather than failed.
    """
    if depth > max_depth:
        # Do not fail a plan merely because it is very deep.
        return True
    if not isinstance(node, dict):
        return False

    node_type = node.get('type')
    # A missing key, null, or an empty list all mean "no children".
    children = node.get('children') or []

    if node_type in ('action', 'condition'):
        return not children
    if node_type in ('Sequence', 'Selector', 'Parallel') and not children:
        return False

    return all(check_leaf_nodes(child, depth + 1, max_depth) for child in children)
# Names the planner may emit; anything else is reported as invalid.
# Tuples (not sets) so that unhashable names from a malformed response
# are still handled by the membership test instead of raising.
_VALID_ACTION_NAMES = (
    'deliver_payload', 'emergency_return', 'fly_to_waypoint',
    'land', 'loiter', 'object_detect', 'preflight_checks',
    'search_pattern', 'strike_target', 'battle_damage_assessment', 'takeoff',
)
_VALID_CONDITION_NAMES = (
    'battery_above', 'at_waypoint', 'object_detected',
    'target_destroyed', 'time_elapsed', 'gps_status',
)


def _failure_result(error_msg, attempts, http_status):
    """Build the result dict returned when no usable response was obtained."""
    return {
        "success": False,
        "data": None,
        "validation_checks": {},
        "response_time": 0,
        "invalid_actions": [],
        "invalid_conditions": [],
        "error": error_msg,
        "attempts": attempts,
        "http_status": http_status,
        "mode_type": "未知"
    }


def _collect_invalid_names(node, invalid_actions, invalid_conditions):
    """Append unknown action/condition names found under *node* to the lists."""
    if not isinstance(node, dict):
        return
    node_type = node.get('type')
    if node_type == 'action':
        action_name = node.get('name', '')
        if action_name not in _VALID_ACTION_NAMES:
            invalid_actions.append(action_name)
    elif node_type == 'condition':
        condition_name = node.get('name', '')
        if condition_name not in _VALID_CONDITION_NAMES:
            invalid_conditions.append(condition_name)
    # 'children' may be missing or null.
    for child in node.get('children') or ():
        _collect_invalid_names(child, invalid_actions, invalid_conditions)


def _validate_plan_response(data):
    """Validate a decoded /generate_plan response.

    Both modes use the 'root' field:
      * simple mode  - root is a single 'action' node without children;
      * complex mode - root is any other node (expected to have children).

    Returns:
        tuple: (success, validation_checks, invalid_actions,
        invalid_conditions, mode_type).
    """
    is_dict = isinstance(data, dict)
    has_root = is_dict and "root" in data
    raw_root = data.get('root') if is_dict else None
    # A null / non-dict root is inspected as an empty node so the checks
    # below report a validation failure instead of raising.
    root_node = raw_root if isinstance(raw_root, dict) else {}
    root_type = root_node.get('type', '')
    root_has_children = bool(root_node.get('children'))

    is_simple_mode = (root_type == 'action' and not root_has_children)
    is_complex_mode = has_root and not is_simple_mode

    # Basic checks shared by both modes.
    validation_checks = {
        "is_dict": is_dict,
        "has_root": has_root,
        "has_plan_id": is_dict and "plan_id" in data,
        "has_visualization_url": is_dict and "visualization_url" in data,
    }

    # Mode-specific checks. Neither mode is allowed to carry a 'mode' field.
    if is_simple_mode:
        validation_checks.update({
            "root_is_action": root_type == 'action',
            "root_no_children": not root_has_children,
            "root_has_name": bool(root_node.get('name')),
        })
        validation_checks["no_mode_field"] = "mode" not in data
    elif is_complex_mode:
        validation_checks.update({
            "root_has_children": root_has_children,
        })
        validation_checks["no_mode_field"] = "mode" not in data
    else:
        validation_checks["valid_mode"] = False
        debug_print(f"⚠️ 响应既不是简单模式也不是复杂模式: root_type={root_type}, has_children={root_has_children}")

    # Optional advanced checks: recorded for the log, not required for success.
    advanced_checks = {}
    invalid_actions = []
    invalid_conditions = []
    if is_complex_mode:
        advanced_checks = {
            "leaf_nodes_valid": check_leaf_nodes(root_node),
            "has_safety": check_safety_monitoring(root_node)
        }
        _collect_invalid_names(root_node, invalid_actions, invalid_conditions)
    elif is_simple_mode:
        action_name = root_node.get('name', '')
        if action_name and action_name not in _VALID_ACTION_NAMES:
            invalid_actions.append(action_name)
    validation_checks.update(advanced_checks)

    # Success depends only on the required checks of the detected mode.
    if is_simple_mode:
        required_checks = ["is_dict", "has_root", "has_plan_id", "has_visualization_url",
                           "root_is_action", "root_no_children", "root_has_name", "no_mode_field"]
        success = all(validation_checks.get(k, False) for k in required_checks)
    elif is_complex_mode:
        required_checks = ["is_dict", "has_root", "has_plan_id", "has_visualization_url",
                           "root_has_children", "no_mode_field"]
        success = all(validation_checks.get(k, False) for k in required_checks)
    else:
        success = False

    mode_type = "简单模式" if is_simple_mode else ("复杂模式" if is_complex_mode else "未知模式")
    return success, validation_checks, invalid_actions, invalid_conditions, mode_type


def send_api_request(prompt, instruction_idx, run_number):
    """POST *prompt* to the plan endpoint and validate the response.

    Request-level failures (connection errors, timeouts, HTTP error status)
    are retried up to MAX_RETRIES times with RETRY_DELAY seconds in between.
    Any other exception ends the run immediately.

    Args:
        prompt: Instruction text sent as ``user_prompt``.
        instruction_idx: 1-based instruction number (used in debug output).
        run_number: 1-based run number (used in debug output).

    Returns:
        dict with keys: success, data, validation_checks, response_time,
        invalid_actions, invalid_conditions, error, attempts, http_status,
        mode_type. A dict is always returned; exceptions are not propagated.
    """
    url = BASE_URL + ENDPOINT
    payload = {"user_prompt": prompt}
    headers = {"Content-Type": "application/json"}
    # Always try at least once, even if MAX_RETRIES is misconfigured to <= 0,
    # so callers never receive None.
    total_attempts = max(1, MAX_RETRIES)

    for attempt in range(total_attempts):
        try:
            debug_print(f"指令 {instruction_idx}-{run_number} 尝试 {attempt + 1}")
            debug_print(f"请求URL: {url}")
            debug_print(f"请求Payload: {json.dumps(payload, ensure_ascii=False)}")
            start_time = time.time()
            response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=60)
            response_time = time.time() - start_time
            debug_print(f"HTTP状态码: {response.status_code}")
            debug_print(f"响应时间: {response_time:.2f}")

            # Check the HTTP status before trying to decode the body.
            response.raise_for_status()

            try:
                data = response.json()
            except ValueError as e:
                # ValueError covers json.JSONDecodeError as well as the
                # simplejson / requests variants of the decode error.
                debug_print(f"JSON解析失败: {e}, 响应文本: {response.text[:200]}")
                raise

            (success, validation_checks, invalid_actions,
             invalid_conditions, mode_type) = _validate_plan_response(data)
            debug_print(f"验证结果: 模式={mode_type}, 成功={success}, 基本验证通过={all(validation_checks.values())}")
            return {
                "success": success,
                "data": data,
                "validation_checks": validation_checks,
                "response_time": response_time,
                "invalid_actions": invalid_actions,
                "invalid_conditions": invalid_conditions,
                "error": None,
                "attempts": attempt + 1,
                "http_status": response.status_code,
                "mode_type": mode_type
            }
        except requests.exceptions.RequestException as e:
            error_msg = f"请求失败: {e}"
            failed_response = getattr(e, 'response', None)
            http_status = getattr(failed_response, 'status_code', None)
            debug_print(f"请求异常: {error_msg}")
            debug_print(f"HTTP状态码: {http_status}")
            if failed_response is not None:
                try:
                    debug_print(f"响应内容: {failed_response.text[:500]}")
                except Exception as text_err:
                    # Best effort only: the body is purely diagnostic.
                    debug_print(f"无法读取响应内容: {text_err}")
            if attempt < total_attempts - 1:
                debug_print(f"等待 {RETRY_DELAY} 秒后重试...")
                time.sleep(RETRY_DELAY)
                continue
            return _failure_result(error_msg, attempt + 1, http_status)
        except Exception as e:
            error_msg = f"未知错误: {e}"
            debug_print(f"未知错误: {error_msg}")
            import traceback
            debug_print(f"错误堆栈: {traceback.format_exc()}")
            return _failure_result(error_msg, attempt + 1, None)
def read_instructions(filename):
    """Read the instruction list from a text file.

    Each non-empty line that does not start with '#' is one instruction;
    surrounding whitespace is stripped.

    Args:
        filename: Path of the UTF-8 encoded instruction file.

    Returns:
        list[str]: The instructions in file order, or an empty list if the
        file cannot be read (the error is printed, not raised).
    """
    try:
        # 'utf-8-sig' also accepts plain UTF-8 but drops a leading BOM, which
        # would otherwise stick to the first line and defeat the '#' check.
        with open(filename, 'r', encoding='utf-8-sig') as file:
            stripped_lines = (line.strip() for line in file)
            return [line for line in stripped_lines
                    if line and not line.startswith('#')]
    except Exception as e:
        print(f"❌ 读取指令文件时出错: {e}")
        return []
def _write_validation_details(f, result):
    """Write each validation check with a pass/fail mark, plus invalid node names."""
    f.write("\n验证结果:\n")
    for check_name, check_result in (result.get('validation_checks') or {}).items():
        f.write(f"  {check_name}: {'✅' if check_result else '❌'}\n")
    if result.get('invalid_actions'):
        f.write(f"⚠️ 无效动作节点: {result['invalid_actions']}\n")
    if result.get('invalid_conditions'):
        f.write(f"⚠️ 无效条件节点: {result['invalid_conditions']}\n")


def write_log_entry(log_file, instruction_idx, run_number, prompt, result):
    """Append a detailed log entry, including the full API response.

    Args:
        log_file: Path of the log file (opened in append mode, UTF-8).
        instruction_idx: Instruction number shown in the entry header.
        run_number: Run number shown in the entry header.
        prompt: The original instruction text.
        result: Result dict as produced by send_api_request. Must contain
            'attempts', 'response_time', 'success' and 'error'; the other
            keys are optional.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    data = result.get('data')
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"\n{'='*80}\n")
        f.write(f"指令 #{instruction_idx} - 运行 #{run_number} - {timestamp}\n")
        f.write(f"HTTP状态: {result.get('http_status', 'N/A')}\n")
        f.write(f"原始指令: {prompt}\n")
        f.write(f"模式类型: {result.get('mode_type', '未知')}\n")
        f.write(f"尝试次数: {result['attempts']}\n")
        f.write(f"响应时间: {result['response_time']:.2f}\n")
        f.write(f"结果: {'✅ 成功' if result['success'] else '❌ 失败'}\n")

        if result['success'] and data:
            # The server may echo its reworked prompt under one of several keys.
            organized_prompt = None
            if isinstance(data, dict):
                organized_prompt = data.get("organized_prompt") or \
                    data.get("processed_prompt") or \
                    data.get("final_prompt") or \
                    data.get("enhanced_prompt") or \
                    data.get("user_prompt_enhanced")
            if organized_prompt:
                f.write("\n📝 组织后的Prompt:\n")
                f.write(f"{organized_prompt}\n")
            else:
                f.write("\n📝 组织后的Prompt: (未在响应中返回)\n")

            _write_validation_details(f, result)

            # Full API response.
            f.write("\n完整API响应:\n")
            try:
                response_json = json.dumps(data, indent=2, ensure_ascii=False)
                f.write(response_json)
                f.write("\n")
            except Exception as e:
                f.write(f"⚠️ 无法序列化响应数据: {e}\n")
                f.write(f"原始数据: {str(data)}\n")
        else:
            f.write(f"错误信息: {result['error']}\n")
            # A response that arrived but failed validation has no error
            # text; the individual checks are what explain the failure.
            if result.get('validation_checks'):
                _write_validation_details(f, result)
            # Record whatever response data exists, even on failure.
            if data:
                f.write("\n部分响应数据:\n")
                try:
                    response_json = json.dumps(data, indent=2, ensure_ascii=False)
                    f.write(response_json)
                    f.write("\n")
                except Exception:
                    f.write(f"原始数据: {str(data)}\n")
def generate_summary_report(instructions, results_summary):
    """Write the per-instruction statistics to SUMMARY_CSV.

    Timing columns are "N/A" for an instruction with no successful run, so
    the average is never computed with a zero divisor. Any error while
    writing is printed rather than raised.

    Args:
        instructions: Instruction texts, in test order.
        results_summary: Stats dicts aligned by index with ``instructions``
            (keys used: success_count, total_response_time,
            min_response_time, max_response_time).
    """
    columns = ['instruction_index', 'instruction', 'total_runs', 'successful_runs',
               'success_rate', 'avg_response_time', 'min_response_time',
               'max_response_time', 'total_response_time']
    try:
        with open(SUMMARY_CSV, 'w', newline='', encoding='utf-8') as out:
            report = csv.DictWriter(out, fieldnames=columns)
            report.writeheader()
            for position, text in enumerate(instructions):
                stats = results_summary[position]
                passed = stats['success_count']
                if passed > 0:
                    avg_time = f"{stats['total_response_time'] / passed:.2f}s"
                    min_time = f"{stats['min_response_time']:.2f}s"
                    max_time = f"{stats['max_response_time']:.2f}s"
                else:
                    # No successful run: nothing meaningful to average.
                    avg_time = min_time = max_time = "N/A"
                row = {
                    'instruction_index': position + 1,
                    'instruction': text,
                    'total_runs': TESTS_PER_INSTRUCTION,
                    'successful_runs': passed,
                    'success_rate': f"{(passed / TESTS_PER_INSTRUCTION * 100):.2f}%",
                    'avg_response_time': avg_time,
                    'min_response_time': min_time,
                    'max_response_time': max_time,
                    'total_response_time': f"{stats['total_response_time']:.2f}s",
                }
                report.writerow(row)
        print(f"📊 统计摘要已保存至: {SUMMARY_CSV}")
    except Exception as e:
        print(f"❌ 保存统计摘要时出错: {e}")
def main():
    """Run every instruction against the API, write both CSV reports and print totals."""
    print("🚀 开始批量API测试")
    print(f"每个指令测试 {TESTS_PER_INSTRUCTION}")

    instructions = read_instructions(INSTRUCTIONS_FILE)
    if not instructions:
        return
    instruction_total = len(instructions)
    print(f"找到 {instruction_total} 条指令")

    # One statistics record per instruction.
    results_summary = []
    for _ in instructions:
        results_summary.append({
            'success_count': 0,
            'total_response_time': 0,
            'min_response_time': float('inf'),
            'max_response_time': 0,
            'http_statuses': [],
        })
    detailed_results = []
    banner = '=' * 60

    # Execute the tests.
    for instruction_idx, prompt in enumerate(instructions, 1):
        stats = results_summary[instruction_idx - 1]
        preview = prompt[:80] + ('...' if len(prompt) > 80 else '')
        print(f"\n{banner}")
        print(f"📋 测试指令 {instruction_idx}/{instruction_total}")
        print(f"指令: {preview}")
        print(banner)

        for run_number in range(1, TESTS_PER_INSTRUCTION + 1):
            print(f" 运行 {run_number}/{TESTS_PER_INSTRUCTION}...", end=" ", flush=True)
            result = send_api_request(prompt, instruction_idx, run_number)
            write_log_entry(LOG_FILE, instruction_idx, run_number, prompt, result)

            # Record the outcome; plan_id is only taken from successful dict responses.
            body = result.get("data")
            plan_id = ""
            if result.get("success") and body and isinstance(body, dict):
                plan_id = body.get("plan_id", "")
            detailed_results.append({
                "instruction_index": instruction_idx,
                "instruction": prompt,
                "run_number": run_number,
                "success": result["success"],
                "attempts": result["attempts"],
                "response_time": result["response_time"],
                "plan_id": plan_id,
                "error": result["error"] or "",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })

            # Update the statistics; timings only count successful runs.
            if not result["success"]:
                print(f"❌ 失败 (HTTP: {result.get('http_status', 'N/A')})")
            else:
                elapsed = result['response_time']
                stats['success_count'] += 1
                stats['total_response_time'] += elapsed
                stats['min_response_time'] = min(stats['min_response_time'], elapsed)
                stats['max_response_time'] = max(stats['max_response_time'], elapsed)
                print(f"✅ 成功 ({elapsed:.1f}s)")

            # Record the HTTP status.
            if 'http_status' in result:
                stats['http_statuses'].append(result['http_status'])

            time.sleep(1)  # avoid overloading the server

    # Detailed results CSV.
    try:
        with open(RESULTS_CSV, 'w', newline='', encoding='utf-8') as csvfile:
            columns = ['instruction_index', 'instruction', 'run_number', 'success',
                       'attempts', 'response_time', 'plan_id', 'error', 'timestamp']
            writer = csv.DictWriter(csvfile, fieldnames=columns)
            writer.writeheader()
            writer.writerows(detailed_results)
        print(f"\n📊 详细结果已保存至: {RESULTS_CSV}")
    except Exception as e:
        print(f"❌ 保存详细结果时出错: {e}")

    # Statistics summary CSV.
    generate_summary_report(instructions, results_summary)

    # Overall totals.
    total_tests = instruction_total * TESTS_PER_INSTRUCTION
    total_successful = 0
    for stats in results_summary:
        total_successful += stats['success_count']

    print(f"\n{banner}")
    print("📈 最终测试统计")
    print(banner)
    print(f"总测试次数: {total_tests}")
    print(f"成功次数: {total_successful}")
    print(f"失败次数: {total_tests - total_successful}")
    if total_tests > 0:
        print(f"总成功率: {(total_successful / total_tests * 100):.2f}%")
    else:
        print("总成功率: N/A")

    # Per-instruction statistics.
    print("\n📋 每个指令的统计:")
    for number, stats in enumerate(results_summary[:instruction_total], 1):
        passed = stats['success_count']
        success_rate = (passed / TESTS_PER_INSTRUCTION * 100)
        if passed > 0:
            avg_time = stats['total_response_time'] / passed
        else:
            avg_time = 0
        print(f" 指令 {number}: {success_rate:.1f}% 成功 ({passed}/{TESTS_PER_INSTRUCTION}), "
              f"平均时间: {avg_time:.2f}s")

    print("\n📁 输出文件:")
    print(f"详细日志: {LOG_FILE}")
    print(f"详细结果: {RESULTS_CSV}")
    print(f"统计摘要: {SUMMARY_CSV}")
if __name__ == "__main__":
main()