437 lines
17 KiB
Python
437 lines
17 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
import requests
|
||
import json
|
||
import csv
|
||
import time
|
||
from datetime import datetime
|
||
import os
|
||
import re
|
||
|
||
# --- Configuration ---
# API server under test and the plan-generation endpoint.
BASE_URL = "http://127.0.0.1:8000"
ENDPOINT = "/generate_plan"

# Input file (one instruction per line) and the generated output artifacts.
INSTRUCTIONS_FILE = "instructions.txt"
RESULTS_CSV = "test_results.csv"
SUMMARY_CSV = "test_summary.csv"
LOG_FILE = "api_test_log.txt"

# Test parameters
TESTS_PER_INSTRUCTION = 1  # runs per instruction
MAX_RETRIES = 3            # request attempts per run
RETRY_DELAY = 2            # seconds to sleep between attempts

# Debug mode: when True, debug_print() emits diagnostic output.
DEBUG = True
|
||
|
||
def debug_print(message):
    """Print *message* with a debug prefix, only when DEBUG is enabled."""
    if not DEBUG:
        return
    print(f"🐛 DEBUG: {message}")
|
||
|
||
def check_safety_monitoring(node):
    """Report whether a behavior tree contains any safety-related node.

    A node counts as safety-related when it is a ``condition`` whose name
    contains "battery", or an ``action`` whose name contains "emergency",
    "safe" or "land" (case-insensitive). The requirement is deliberately
    relaxed: either kind alone is enough.

    Args:
        node: Root node dict with optional ``type``, ``name`` and
            ``children`` keys.

    Returns:
        True if at least one battery condition or emergency action exists.
    """
    emergency_keywords = ('emergency', 'safe', 'land')
    # Explicit stack instead of recursion so deep trees cannot exceed the
    # interpreter's recursion limit.
    stack = [node]
    while stack:
        current = stack.pop()
        # Skip malformed entries (e.g. a null child) instead of crashing.
        if not isinstance(current, dict):
            continue
        node_type = current.get('type')
        name = str(current.get('name', '')).lower()
        if node_type == 'condition' and 'battery' in name:
            return True
        if node_type == 'action' and any(k in name for k in emergency_keywords):
            return True
        # JSON ``"children": null`` decodes to None; treat it as no children.
        children = current.get('children') or []
        stack.extend(children)
    return False
|
||
|
||
def check_leaf_nodes(node, depth=0, max_depth=50):
    """Validate the parent/child structure of a behavior tree.

    Rules:
      * ``action`` and ``condition`` nodes must not have children.
      * ``Sequence``, ``Selector`` and ``Parallel`` nodes must have at least
        one child.
      * Every child must itself satisfy these rules.

    Args:
        node: Node dict with optional ``type`` and ``children`` keys.
        depth: Current depth; callers normally leave the default.
        max_depth: Depth beyond which checking stops. Anything deeper is
            accepted, so the depth limit alone never fails a tree.

    Returns:
        True if the structure is valid, False otherwise. An entry that is not
        a dict (e.g. a null child) is treated as invalid.
    """
    if depth > max_depth:
        return True  # Do not fail a tree just because of the depth limit.

    if not isinstance(node, dict):
        return False  # Malformed entry: cannot be a valid node.

    node_type = node.get('type')
    # JSON ``"children": null`` decodes to None; normalise to an empty list.
    children = node.get('children') or []

    if node_type in ('action', 'condition'):
        # Leaf nodes must not have children.
        return not children

    if node_type in ('Sequence', 'Selector', 'Parallel') and not children:
        # Control-flow nodes need at least one child.
        return False

    return all(check_leaf_nodes(child, depth + 1, max_depth) for child in children)
|
||
|
||
# Action/condition names the planner is expected to use; any other name is
# reported as a warning and does not fail the run.
_KNOWN_ACTIONS = (
    'deliver_payload', 'emergency_return', 'fly_to_waypoint', 'land', 'loiter',
    'object_detect', 'preflight_checks', 'search_pattern', 'strike_target',
    'battle_damage_assessment', 'takeoff',
)
_KNOWN_CONDITIONS = (
    'battery_above', 'at_waypoint', 'object_detected', 'target_destroyed',
    'time_elapsed', 'gps_status',
)
# Checks that must all pass for a run to count as successful.
_BASIC_CHECKS = (
    "is_dict", "has_root", "root_has_children", "has_plan_id",
    "has_visualization_url",
)


def _collect_invalid_nodes(root_node):
    """Return (invalid_actions, invalid_conditions) names found in the tree.

    Nodes are visited depth-first in pre-order. Non-dict entries and
    ``"children": null`` are skipped rather than raising.
    """
    invalid_actions = []
    invalid_conditions = []
    stack = [root_node]
    while stack:
        node = stack.pop()
        if not isinstance(node, dict):
            continue
        node_type = node.get('type')
        name = node.get('name', '')
        if node_type == 'action':
            if name not in _KNOWN_ACTIONS:
                invalid_actions.append(name)
        elif node_type == 'condition':
            if name not in _KNOWN_CONDITIONS:
                invalid_conditions.append(name)
        children = node.get('children')
        if isinstance(children, list):
            # Push in reverse so children are visited in their original order.
            stack.extend(reversed(children))
    return invalid_actions, invalid_conditions


def _validate_plan(data):
    """Run every check on a decoded response body without raising on bad shapes.

    Returns (validation_checks, invalid_actions, invalid_conditions).
    """
    is_dict = isinstance(data, dict)
    # Check the type first: a non-object body (e.g. a JSON list) has no .get().
    body = data if is_dict else {}
    root_node = body.get('root')
    if not isinstance(root_node, dict):
        root_node = {}  # missing or null root

    validation_checks = {
        # Basic checks: these decide success.
        "is_dict": is_dict,
        "has_root": "root" in body,
        "root_has_children": bool(root_node.get('children')),
        "has_plan_id": "plan_id" in body,
        "has_visualization_url": "visualization_url" in body,
        # Advanced checks: recorded as warnings only.
        "leaf_nodes_valid": check_leaf_nodes(root_node),
        "has_safety": check_safety_monitoring(root_node),
    }
    invalid_actions, invalid_conditions = _collect_invalid_nodes(root_node)
    return validation_checks, invalid_actions, invalid_conditions


def _failed_result(error_msg, attempts, http_status=None):
    """Build the result dict for a run that produced no usable response."""
    return {
        "success": False,
        "data": None,
        "validation_checks": {},
        "response_time": 0,
        "invalid_actions": [],
        "invalid_conditions": [],
        "error": error_msg,
        "attempts": attempts,
        "http_status": http_status,
    }


def send_api_request(prompt, instruction_idx, run_number):
    """Send one plan-generation request and validate the response.

    A ``requests`` RequestException is retried up to MAX_RETRIES times,
    sleeping RETRY_DELAY seconds between attempts. This includes HTTP error
    statuses raised by ``raise_for_status``. Any other exception ends the run
    immediately.

    Args:
        prompt: Instruction text, sent as ``user_prompt``.
        instruction_idx: Instruction number, used only in debug output.
        run_number: Run number, used only in debug output.

    Returns:
        A dict with keys ``success``, ``data``, ``validation_checks``,
        ``response_time``, ``invalid_actions``, ``invalid_conditions``,
        ``error``, ``attempts`` and ``http_status``. ``success`` depends only
        on the basic checks. Structure/safety checks and unknown node names
        are informational.
    """
    url = BASE_URL + ENDPOINT
    payload = {"user_prompt": prompt}
    headers = {"Content-Type": "application/json"}

    for attempt in range(MAX_RETRIES):
        attempts = attempt + 1
        try:
            debug_print(f"指令 {instruction_idx}-{run_number} 尝试 {attempts}")
            # Monotonic clock: durations are immune to system clock changes.
            start_time = time.perf_counter()
            response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=60)
            response_time = time.perf_counter() - start_time

            # Check the HTTP status first; error statuses raise HTTPError.
            response.raise_for_status()

            try:
                data = response.json()
            except json.JSONDecodeError as e:
                debug_print(f"JSON解析失败: {e}, 响应文本: {response.text[:200]}")
                raise

            validation_checks, invalid_actions, invalid_conditions = _validate_plan(data)
            success = all(validation_checks[k] for k in _BASIC_CHECKS)
            debug_print(f"验证结果: 成功={success}, 全部验证通过={all(validation_checks.values())}")

            return {
                "success": success,
                "data": data,
                "validation_checks": validation_checks,
                "response_time": response_time,
                "invalid_actions": invalid_actions,
                "invalid_conditions": invalid_conditions,
                "error": None,
                "attempts": attempts,
                "http_status": response.status_code,
            }

        except requests.exceptions.RequestException as e:
            error_msg = f"请求失败: {e}"
            debug_print(f"请求异常: {error_msg}")
            if attempts < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
                continue
            # e.response may be None (e.g. connection errors or timeouts).
            status = getattr(getattr(e, 'response', None), 'status_code', None)
            return _failed_result(error_msg, attempts, status)

        except Exception as e:
            # Unexpected errors are not retried.
            error_msg = f"未知错误: {e}"
            debug_print(f"未知错误: {error_msg}")
            return _failed_result(error_msg, attempts)

    # Only reached when MAX_RETRIES < 1, i.e. no request was attempted.
    return _failed_result("未发送请求: MAX_RETRIES 小于 1", 0)
|
||
|
||
def read_instructions(filename):
    """Read test instructions from a text file, one per line.

    Lines are stripped; blank lines and lines starting with ``#`` are skipped.
    The file is decoded as UTF-8. A leading byte-order mark (as written by some
    Windows editors) is removed. Otherwise it would leak into the first
    instruction, or stop a first-line ``#`` comment from being recognised,
    because ``str.strip`` does not remove U+FEFF.

    Args:
        filename: Path of the instructions file.

    Returns:
        The list of instructions, or an empty list (after printing an error)
        if the file cannot be opened or decoded.
    """
    try:
        with open(filename, 'r', encoding='utf-8-sig') as file:
            stripped = (raw.strip() for raw in file)
            return [text for text in stripped if text and not text.startswith('#')]
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers UnicodeDecodeError.
        print(f"❌ 读取指令文件时出错: {e}")
        return []
|
||
|
||
def write_log_entry(log_file, instruction_idx, run_number, prompt, result):
    """Append a detailed, human-readable entry for one test run to *log_file*.

    Every entry records the HTTP status, prompt, attempt count, response time
    and outcome.

    Successful runs (with data) also record:
      * the server's rewritten prompt, if one is returned;
      * every validation check;
      * any unknown action/condition names;
      * the full JSON response.

    Failed runs record the error message and any response data present.

    Args:
        log_file: Path of the log file; opened in append mode as UTF-8.
        instruction_idx: Instruction number shown in the entry header.
        run_number: Run number shown in the entry header.
        prompt: Original instruction text.
        result: Result dict as produced by ``send_api_request``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # The key may be present with a None value (no HTTP response received);
    # show N/A rather than "None" in that case.
    http_status = result.get('http_status')
    if http_status is None:
        http_status = 'N/A'
    # Response keys that may carry the server's rewritten prompt, by priority.
    prompt_keys = ("organized_prompt", "processed_prompt", "final_prompt",
                   "enhanced_prompt", "user_prompt_enhanced")
    data = result.get('data')

    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"\n{'=' * 80}\n")
        f.write(f"指令 #{instruction_idx} - 运行 #{run_number} - {timestamp}\n")
        f.write(f"HTTP状态: {http_status}\n")
        f.write(f"原始指令: {prompt}\n")
        f.write(f"尝试次数: {result['attempts']}\n")
        f.write(f"响应时间: {result['response_time']:.2f}秒\n")
        f.write(f"结果: {'✅ 成功' if result['success'] else '❌ 失败'}\n")

        if result['success'] and data:
            # First non-empty rewritten prompt returned by the server, if any.
            organized_prompt = None
            if isinstance(data, dict):
                organized_prompt = next(
                    (data[key] for key in prompt_keys if data.get(key)), None
                )
            if organized_prompt:
                f.write("\n📝 组织后的Prompt:\n")
                f.write(f"{organized_prompt}\n")
            else:
                f.write("\n📝 组织后的Prompt: (未在响应中返回)\n")

            f.write("\n验证结果:\n")
            for check_name, check_result in result['validation_checks'].items():
                f.write(f" {check_name}: {'✅' if check_result else '❌'}\n")

            if result['invalid_actions']:
                f.write(f"⚠️ 无效动作节点: {result['invalid_actions']}\n")
            if result['invalid_conditions']:
                f.write(f"⚠️ 无效条件节点: {result['invalid_conditions']}\n")

            f.write("\n完整API响应:\n")
            try:
                # dumps() builds the whole string first, so a failure writes nothing partial.
                f.write(json.dumps(data, indent=2, ensure_ascii=False))
                f.write("\n")
            except (TypeError, ValueError) as e:
                f.write(f"⚠️ 无法序列化响应数据: {e}\n")
                f.write(f"原始数据: {str(data)}\n")
        else:
            f.write(f"错误信息: {result['error']}\n")
            # Record any response data even for failed runs.
            if data:
                f.write("\n部分响应数据:\n")
                try:
                    f.write(json.dumps(data, indent=2, ensure_ascii=False))
                    f.write("\n")
                except (TypeError, ValueError):
                    f.write(f"原始数据: {str(data)}\n")
|
||
|
||
def generate_summary_report(instructions, results_summary):
    """Write per-instruction statistics to SUMMARY_CSV.

    Timing columns are computed over successful runs only. They are "N/A"
    when an instruction had no successful run. The success rate is "N/A" when
    TESTS_PER_INSTRUCTION is 0, so no division by zero can occur. Errors while
    writing are printed, not raised.

    Args:
        instructions: Instruction strings, in test order.
        results_summary: One stats dict per instruction, in the same order,
            with ``success_count``, ``total_response_time``,
            ``min_response_time`` and ``max_response_time``.
    """
    fieldnames = ['instruction_index', 'instruction', 'total_runs', 'successful_runs',
                  'success_rate', 'avg_response_time', 'min_response_time',
                  'max_response_time', 'total_response_time']
    try:
        with open(SUMMARY_CSV, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

            for i, instruction in enumerate(instructions):
                summary = results_summary[i]
                success_count = summary['success_count']

                # Guard both divisions: no successes, or zero configured runs.
                if success_count > 0:
                    avg_time = f"{summary['total_response_time'] / success_count:.2f}s"
                    min_time = f"{summary['min_response_time']:.2f}s"
                    max_time = f"{summary['max_response_time']:.2f}s"
                else:
                    avg_time = min_time = max_time = "N/A"

                if TESTS_PER_INSTRUCTION > 0:
                    success_rate = f"{(success_count / TESTS_PER_INSTRUCTION * 100):.2f}%"
                else:
                    success_rate = "N/A"

                writer.writerow({
                    'instruction_index': i + 1,
                    'instruction': instruction,
                    'total_runs': TESTS_PER_INSTRUCTION,
                    'successful_runs': success_count,
                    'success_rate': success_rate,
                    'avg_response_time': avg_time,
                    'min_response_time': min_time,
                    'max_response_time': max_time,
                    'total_response_time': f"{summary['total_response_time']:.2f}s",
                })

        print(f"📊 统计摘要已保存至: {SUMMARY_CSV}")
    except Exception as e:
        # Top-level reporting boundary: report and let the script continue.
        print(f"❌ 保存统计摘要时出错: {e}")
|
||
|
||
def _format_http_status(status):
    """Return *status* for display, or 'N/A' when no HTTP status is known."""
    return 'N/A' if status is None else status


def _save_detailed_results(detailed_results):
    """Write one CSV row per run to RESULTS_CSV; errors are printed, not raised."""
    fieldnames = ['instruction_index', 'instruction', 'run_number', 'success',
                  'attempts', 'response_time', 'plan_id', 'error', 'timestamp']
    try:
        with open(RESULTS_CSV, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(detailed_results)
        print(f"\n📊 详细结果已保存至: {RESULTS_CSV}")
    except Exception as e:
        # Top-level reporting boundary: report and let the script continue.
        print(f"❌ 保存详细结果时出错: {e}")


def _print_final_statistics(instructions, results_summary):
    """Print overall and per-instruction success rates and average times.

    Rates are shown as "N/A" when TESTS_PER_INSTRUCTION is 0, avoiding a
    division by zero.
    """
    total_tests = len(instructions) * TESTS_PER_INSTRUCTION
    total_successful = sum(summary['success_count'] for summary in results_summary)

    print(f"\n{'=' * 60}")
    print("📈 最终测试统计")
    print(f"{'=' * 60}")
    print(f"总测试次数: {total_tests}")
    print(f"成功次数: {total_successful}")
    print(f"失败次数: {total_tests - total_successful}")
    if total_tests > 0:
        print(f"总成功率: {(total_successful / total_tests * 100):.2f}%")
    else:
        print("总成功率: N/A")

    print("\n📋 每个指令的统计:")
    for i, summary in enumerate(results_summary, 1):
        success_count = summary['success_count']
        if TESTS_PER_INSTRUCTION > 0:
            rate_text = f"{(success_count / TESTS_PER_INSTRUCTION * 100):.1f}%"
        else:
            rate_text = "N/A"
        avg_time = summary['total_response_time'] / success_count if success_count > 0 else 0
        print(f" 指令 {i}: {rate_text} 成功 ({success_count}/{TESTS_PER_INSTRUCTION}), "
              f"平均时间: {avg_time:.2f}s")


def main():
    """Run the batch API test.

    Steps:
      1. Read INSTRUCTIONS_FILE (returns early if it yields no instructions).
      2. Send each instruction TESTS_PER_INSTRUCTION times, logging every run
         to LOG_FILE.
      3. Write RESULTS_CSV and SUMMARY_CSV.
      4. Print overall and per-instruction statistics.
    """
    print("🚀 开始批量API测试")
    print(f"每个指令测试 {TESTS_PER_INSTRUCTION} 次")

    instructions = read_instructions(INSTRUCTIONS_FILE)
    if not instructions:
        return

    print(f"找到 {len(instructions)} 条指令")

    # Per-instruction statistics, indexed like `instructions`.
    results_summary = [{
        'success_count': 0,
        'total_response_time': 0,
        'min_response_time': float('inf'),
        'max_response_time': 0,
        'http_statuses': [],
    } for _ in instructions]
    detailed_results = []

    for instruction_idx, prompt in enumerate(instructions, 1):
        print(f"\n{'=' * 60}")
        print(f"📋 测试指令 {instruction_idx}/{len(instructions)}")
        print(f"指令: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
        print(f"{'=' * 60}")

        stats = results_summary[instruction_idx - 1]
        for run_number in range(1, TESTS_PER_INSTRUCTION + 1):
            print(f" 运行 {run_number}/{TESTS_PER_INSTRUCTION}...", end=" ", flush=True)

            result = send_api_request(prompt, instruction_idx, run_number)
            write_log_entry(LOG_FILE, instruction_idx, run_number, prompt, result)

            data = result.get("data")
            plan_id = ""
            if result.get("success") and isinstance(data, dict):
                plan_id = data.get("plan_id", "")

            detailed_results.append({
                "instruction_index": instruction_idx,
                "instruction": prompt,
                "run_number": run_number,
                "success": result["success"],
                "attempts": result["attempts"],
                "response_time": result["response_time"],
                "plan_id": plan_id,
                "error": result["error"] or "",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })

            if result["success"]:
                response_time = result['response_time']
                stats['success_count'] += 1
                stats['total_response_time'] += response_time
                stats['min_response_time'] = min(stats['min_response_time'], response_time)
                stats['max_response_time'] = max(stats['max_response_time'], response_time)
                print(f"✅ 成功 ({response_time:.1f}s)")
            else:
                print(f"❌ 失败 (HTTP: {_format_http_status(result.get('http_status'))})")

            # Record the status of every run (None when no HTTP response arrived).
            stats['http_statuses'].append(result.get('http_status'))

            time.sleep(1)  # Pause between requests to avoid overloading the server.

    _save_detailed_results(detailed_results)
    generate_summary_report(instructions, results_summary)
    _print_final_statistics(instructions, results_summary)

    print("\n📁 输出文件:")
    print(f"详细日志: {LOG_FILE}")
    print(f"详细结果: {RESULTS_CSV}")
    print(f"统计摘要: {SUMMARY_CSV}")


if __name__ == "__main__":
    main()