#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Batch tester for the plan-generation HTTP API.

Reads prompts from ``INSTRUCTIONS_FILE``, posts each one
``TESTS_PER_INSTRUCTION`` times to ``BASE_URL + ENDPOINT``, validates the
returned plan tree, and writes into a timestamped ``test_run_*`` directory:

* a human-readable log (``LOG_FILE``),
* one CSV row per run (``RESULTS_CSV``),
* one CSV row per instruction (``SUMMARY_CSV``),
* the plan JSON and visualization PNG of every successful run.
"""
import csv
import json
import os
import re
import time
from datetime import datetime

import requests

# --- Configuration ---
BASE_URL = "http://127.0.0.1:8000"
ENDPOINT = "/generate_plan"
INSTRUCTIONS_FILE = "instructions.txt"
RESULTS_CSV = "test_results.csv"
SUMMARY_CSV = "test_summary.csv"
LOG_FILE = "api_test_log.txt"

# Test parameters
TESTS_PER_INSTRUCTION = 10
MAX_RETRIES = 3          # total attempts per run (at least one is always made)
RETRY_DELAY = 2          # seconds to wait before retrying a failed request
REQUEST_TIMEOUT = 60     # seconds; plan generation can be slow
IMAGE_TIMEOUT = 30       # seconds for downloading a visualization image
RUN_INTERVAL = 1         # seconds to pause after every run

# Debug mode: print verbose progress information
DEBUG = True

# Node names the planner is expected to emit. Unknown names are reported as
# warnings in the log; they never make a run fail.
VALID_ACTIONS = frozenset({
    'deliver_payload', 'emergency_return', 'fly_to_waypoint', 'land',
    'loiter', 'object_detect', 'preflight_checks', 'search_pattern',
    'strike_target', 'battle_damage_assessment', 'takeoff',
})
VALID_CONDITIONS = frozenset({
    'battery_above', 'at_waypoint', 'object_detected', 'target_destroyed',
    'time_elapsed', 'gps_status',
})

# Control-flow node types that must have at least one child.
CONTROL_FLOW_TYPES = ('Sequence', 'Selector', 'Parallel')
# Leaf node types that must not have children.
LEAF_TYPES = ('action', 'condition')
# Substrings of an action name that mark it as a safety/emergency action.
SAFETY_ACTION_KEYWORDS = ('emergency', 'safe', 'land')

# Checks that decide whether a run succeeds; all other checks are informational.
BASIC_CHECKS = (
    "is_dict",
    "has_root",
    "root_has_children",
    "has_plan_id",
    "has_visualization_url",
)

# Column order of the per-run CSV. It must list every key of a detailed-result
# row: csv.DictWriter raises ValueError on keys missing from fieldnames.
RESULTS_FIELDNAMES = [
    'instruction_index', 'instruction', 'run_number', 'success', 'attempts',
    'response_time', 'plan_id', 'http_status', 'error', 'timestamp',
]
SUMMARY_FIELDNAMES = [
    'instruction_index', 'instruction', 'total_runs', 'successful_runs',
    'success_rate', 'avg_response_time', 'min_response_time',
    'max_response_time', 'total_response_time',
]


def debug_print(message):
    """Print *message* with a debug prefix when DEBUG is enabled."""
    if DEBUG:
        print(f"🐛 DEBUG: {message}")


def _iter_children(node):
    """Return the dict-valued children of *node*, or [] if there are none.

    The server's JSON is untrusted: ``children`` may be missing, ``null``,
    not a list, or contain non-object entries; none of that should crash
    the traversal helpers.
    """
    children = node.get('children') if isinstance(node, dict) else None
    if not isinstance(children, list):
        return []
    return [child for child in children if isinstance(child, dict)]


def check_safety_monitoring(node):
    """Return True if the tree rooted at *node* contains a safety node.

    A node qualifies if it is a ``condition`` whose name mentions "battery",
    or an ``action`` whose name contains one of SAFETY_ACTION_KEYWORDS
    (case-insensitive). Either kind alone is sufficient (relaxed requirement).
    """
    stack = [node] if isinstance(node, dict) else []
    while stack:
        current = stack.pop()
        node_type = current.get('type')
        name = str(current.get('name', '')).lower()
        if node_type == 'condition' and 'battery' in name:
            return True
        if node_type == 'action' and any(k in name for k in SAFETY_ACTION_KEYWORDS):
            return True
        stack.extend(_iter_children(current))
    return False


def check_leaf_nodes(node, depth=0, max_depth=50):
    """Check the structural rules of the plan tree rooted at *node*.

    * ``action``/``condition`` nodes must have no children;
    * ``Sequence``/``Selector``/``Parallel`` nodes must have children;
    * every child must itself be a JSON object satisfying the same rules.

    Subtrees deeper than *max_depth* are accepted without inspection, so
    depth alone never fails the check (and recursion stays bounded).
    """
    if depth > max_depth:
        return True
    if not isinstance(node, dict):
        return False

    children = node.get('children')
    node_type = node.get('type')
    if node_type in LEAF_TYPES:
        return not children
    if node_type in CONTROL_FLOW_TYPES and not children:
        return False
    if not children:
        return True
    if not isinstance(children, list):
        return False
    return all(check_leaf_nodes(child, depth + 1, max_depth) for child in children)


def _collect_invalid_nodes(root_node):
    """Return (invalid_actions, invalid_conditions) found in the tree.

    Names not in VALID_ACTIONS / VALID_CONDITIONS (including non-string
    names) are collected in depth-first pre-order.
    """
    invalid_actions = []
    invalid_conditions = []
    stack = [root_node] if isinstance(root_node, dict) else []
    while stack:
        node = stack.pop()
        node_type = node.get('type')
        name = node.get('name', '')
        known = isinstance(name, str)  # unhashable names cannot be in a frozenset
        if node_type == 'action':
            if not (known and name in VALID_ACTIONS):
                invalid_actions.append(name)
        elif node_type == 'condition':
            if not (known and name in VALID_CONDITIONS):
                invalid_conditions.append(name)
        # Push reversed so children are visited in document order.
        stack.extend(reversed(_iter_children(node)))
    return invalid_actions, invalid_conditions


def _failure_result(error_msg, attempts, http_status):
    """Build the result dict for a run that produced no usable response."""
    return {
        "success": False,
        "data": None,
        "validation_checks": {},
        "response_time": 0,
        "invalid_actions": [],
        "invalid_conditions": [],
        "error": error_msg,
        "attempts": attempts,
        "http_status": http_status,
    }


def _evaluate_response(data, response_time, attempts, http_status):
    """Validate a decoded API response and build the run's result dict.

    The type checks run before any field access, so a non-object body or a
    non-object ``root`` is reported as a validation failure instead of
    raising.
    """
    is_dict = isinstance(data, dict)
    fields = data if is_dict else {}
    root = fields.get('root')
    root_node = root if isinstance(root, dict) else {}

    validation_checks = {
        "is_dict": is_dict,
        "has_root": "root" in fields,
        "root_has_children": bool(root_node.get('children')),
        "has_plan_id": "plan_id" in fields,
        "has_visualization_url": "visualization_url" in fields,
        # Advanced checks: recorded and logged, but never fail the run.
        "leaf_nodes_valid": check_leaf_nodes(root_node),
        "has_safety": check_safety_monitoring(root_node),
    }
    invalid_actions, invalid_conditions = _collect_invalid_nodes(root_node)

    failed = [name for name in BASIC_CHECKS if not validation_checks[name]]
    success = not failed
    debug_print(f"验证结果: 成功={success}, 全部验证通过={all(validation_checks.values())}")

    return {
        "success": success,
        "data": data,
        "validation_checks": validation_checks,
        "response_time": response_time,
        "invalid_actions": invalid_actions,
        "invalid_conditions": invalid_conditions,
        "error": None if success else f"验证失败: {', '.join(failed)}",
        "attempts": attempts,
        "http_status": http_status,
    }


def send_api_request(prompt, instruction_idx, run_number):
    """Send one prompt to the plan API, retrying transport errors.

    ``requests`` exceptions (connection errors, timeouts, non-2xx statuses,
    and — in recent requests versions — undecodable JSON) are retried until
    MAX_RETRIES attempts are used, RETRY_DELAY seconds apart. Any other
    exception ends the run immediately.

    Args:
        prompt: Instruction text sent as ``user_prompt``.
        instruction_idx: 1-based instruction number (debug output only).
        run_number: 1-based run number (debug output only).

    Returns:
        dict with keys ``success``, ``data``, ``validation_checks``,
        ``response_time``, ``invalid_actions``, ``invalid_conditions``,
        ``error``, ``attempts`` and ``http_status``. ``success`` is True only
        when every check listed in BASIC_CHECKS passes.
    """
    url = BASE_URL + ENDPOINT
    payload = {"user_prompt": prompt}
    headers = {"Content-Type": "application/json"}
    max_attempts = max(1, MAX_RETRIES)  # always send at least once

    for attempt in range(1, max_attempts + 1):
        try:
            debug_print(f"指令 {instruction_idx}-{run_number} 尝试 {attempt}")
            start_time = time.perf_counter()
            response = requests.post(url, data=json.dumps(payload),
                                     headers=headers, timeout=REQUEST_TIMEOUT)
            response_time = time.perf_counter() - start_time

            response.raise_for_status()
            try:
                data = response.json()
            except ValueError as e:  # json/simplejson/requests JSONDecodeError
                debug_print(f"JSON解析失败: {e}, 响应文本: {response.text[:200]}")
                raise

            return _evaluate_response(data, response_time, attempt,
                                      response.status_code)

        except requests.exceptions.RequestException as e:
            error_msg = f"请求失败: {e}"
            debug_print(f"请求异常: {error_msg}")
            if attempt < max_attempts:
                time.sleep(RETRY_DELAY)
                continue
            # e.response is None when no response was received (e.g. timeout).
            status = getattr(getattr(e, 'response', None), 'status_code', None)
            return _failure_result(error_msg, attempt, status)
        except Exception as e:
            error_msg = f"未知错误: {e}"
            debug_print(f"未知错误: {error_msg}")
            return _failure_result(error_msg, attempt, None)

    # Unreachable: every loop iteration returns or continues to a later one.
    return _failure_result("未知错误: no attempt made", 0, None)


def read_instructions(filename):
    """Return the non-empty, non-comment (``#``) lines of *filename*.

    'utf-8-sig' also accepts plain UTF-8 and strips a leading BOM, which
    would otherwise hide a ``#`` on the first line. Returns [] (after
    printing the error) if the file cannot be read or decoded.
    """
    try:
        with open(filename, 'r', encoding='utf-8-sig') as file:
            stripped = (line.strip() for line in file)
            return [line for line in stripped if line and not line.startswith('#')]
    except (OSError, UnicodeDecodeError) as e:
        print(f"❌ 读取指令文件时出错: {e}")
        return []


def write_log_entry(log_file, instruction_idx, run_number, prompt, result):
    """Append a human-readable record of one run to *log_file*.

    Validation results and unknown-node warnings are written whenever they
    exist (also for runs that failed validation); the error message is
    written for every failed run.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines = [
        f"\n{'='*80}\n",
        f"指令 #{instruction_idx} - 运行 #{run_number} - {timestamp}\n",
        f"HTTP状态: {result.get('http_status', 'N/A')}\n",
        f"指令: {prompt}\n",
        f"尝试次数: {result['attempts']}\n",
        f"响应时间: {result['response_time']:.2f}秒\n",
        f"结果: {'✅ 成功' if result['success'] else '❌ 失败'}\n",
    ]
    if result['validation_checks']:
        lines.append("验证结果:\n")
        lines.extend(f" {check_name}: {'✅' if check_result else '❌'}\n"
                     for check_name, check_result in result['validation_checks'].items())
    if result['invalid_actions']:
        lines.append(f"⚠️ 无效动作节点: {result['invalid_actions']}\n")
    if result['invalid_conditions']:
        lines.append(f"⚠️ 无效条件节点: {result['invalid_conditions']}\n")
    if not result['success']:
        lines.append(f"错误信息: {result['error']}\n")

    with open(log_file, 'a', encoding='utf-8') as f:
        f.writelines(lines)


def _percent(part, whole):
    """Return part/whole as a percentage, or 0.0 when *whole* is 0."""
    return part / whole * 100 if whole else 0.0


def generate_summary_report(instructions, results_summary, summary_csv_path):
    """Write one CSV row of run statistics per instruction.

    *results_summary* is parallel to *instructions*. Response-time columns
    are "N/A" for instructions with no successful run. Write errors are
    printed, not raised, so the final console statistics still appear.
    """
    try:
        with open(summary_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=SUMMARY_FIELDNAMES)
            writer.writeheader()

            for index, (instruction, summary) in enumerate(zip(instructions, results_summary), 1):
                success_count = summary['success_count']
                avg_time = min_time = max_time = "N/A"
                if success_count > 0:
                    avg_time = f"{summary['total_response_time'] / success_count:.2f}s"
                    min_time = f"{summary['min_response_time']:.2f}s"
                    max_time = f"{summary['max_response_time']:.2f}s"

                writer.writerow({
                    'instruction_index': index,
                    'instruction': instruction,
                    'total_runs': TESTS_PER_INSTRUCTION,
                    'successful_runs': success_count,
                    'success_rate': f"{_percent(success_count, TESTS_PER_INSTRUCTION):.2f}%",
                    'avg_response_time': avg_time,
                    'min_response_time': min_time,
                    'max_response_time': max_time,
                    'total_response_time': f"{summary['total_response_time']:.2f}s",
                })
        print(f"📊 统计摘要已保存至: {summary_csv_path}")
    except (OSError, csv.Error) as e:
        print(f"❌ 保存统计摘要时出错: {e}")


def _new_summary():
    """Return an empty per-instruction statistics record."""
    return {
        'success_count': 0,
        'total_response_time': 0,
        'min_response_time': float('inf'),
        'max_response_time': 0,
        'http_statuses': [],
    }


def _record_success(summary, response_time):
    """Fold one successful run's response time into *summary*."""
    summary['success_count'] += 1
    summary['total_response_time'] += response_time
    summary['min_response_time'] = min(summary['min_response_time'], response_time)
    summary['max_response_time'] = max(summary['max_response_time'], response_time)


def _safe_filename_part(value):
    """Make a server-supplied id safe to embed in a file name.

    Path separators and other characters outside [word . -] are replaced
    with '_', so a hostile or odd ``plan_id`` cannot escape the directory.
    """
    return re.sub(r'[^\w.-]', '_', str(value)) or 'unknown_plan'


def _save_run_artifacts(instruction_dir, run_number, plan_id, data):
    """Best-effort save of a successful run's plan JSON and visualization.

    Failures are printed as warnings and never abort the test batch.
    """
    plan_id_str = _safe_filename_part(plan_id or 'unknown_plan')
    base_path = os.path.join(instruction_dir, f"run_{run_number}_{plan_id_str}")

    # 1. Save the plan JSON.
    try:
        with open(base_path + ".json", 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)
    except (OSError, ValueError) as e:
        print(f" ⚠️ 保存JSON失败: {e}")

    # 2. Download and save the visualization image.
    viz_url = data.get('visualization_url')
    if not viz_url or not isinstance(viz_url, str):
        return
    if viz_url.startswith('/'):  # server-relative path
        viz_url = BASE_URL + viz_url
    try:
        img_response = requests.get(viz_url, timeout=IMAGE_TIMEOUT)
        img_response.raise_for_status()
        with open(base_path + ".png", 'wb') as f:
            f.write(img_response.content)
    except (requests.exceptions.RequestException, OSError) as e:
        print(f" ⚠️ 下载图片失败: {e}")


def _write_results_csv(detailed_results, results_csv_path):
    """Write one CSV row per run; errors are printed, not raised."""
    try:
        with open(results_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=RESULTS_FIELDNAMES)
            writer.writeheader()
            writer.writerows(detailed_results)
        print(f"\n📊 详细结果已保存至: {results_csv_path}")
    except (OSError, csv.Error) as e:
        print(f"❌ 保存详细结果时出错: {e}")


def _print_final_stats(instructions, results_summary, results_dir):
    """Print overall and per-instruction success statistics to stdout."""
    total_tests = len(instructions) * TESTS_PER_INSTRUCTION
    total_successful = sum(s['success_count'] for s in results_summary)

    print(f"\n{'='*60}")
    print("📈 最终测试统计")
    print(f"{'='*60}")

    if total_tests > 0:
        print(f"总测试次数: {total_tests}")
        print(f"成功次数: {total_successful}")
        print(f"失败次数: {total_tests - total_successful}")
        print(f"总成功率: {(total_successful / total_tests * 100):.2f}%")

        print("\n📋 每个指令的统计:")
        for i, summary in enumerate(results_summary, 1):
            success_count = summary['success_count']
            success_rate = _percent(success_count, TESTS_PER_INSTRUCTION)
            avg_time = summary['total_response_time'] / success_count if success_count > 0 else 0
            print(f" 指令 {i}: {success_rate:.1f}% 成功 ({success_count}/{TESTS_PER_INSTRUCTION}), "
                  f"平均时间: {avg_time:.2f}s")

    print(f"\n📁 输出文件见: {results_dir}")


def main():
    """Run every instruction TESTS_PER_INSTRUCTION times and write all reports."""
    # Timestamped output directory so repeated test runs never overwrite each other.
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"test_run_{timestamp_str}"
    os.makedirs(results_dir, exist_ok=True)

    log_file_path = os.path.join(results_dir, LOG_FILE)
    results_csv_path = os.path.join(results_dir, RESULTS_CSV)
    summary_csv_path = os.path.join(results_dir, SUMMARY_CSV)

    print("🚀 开始批量API测试")
    print(f"每个指令测试 {TESTS_PER_INSTRUCTION} 次")
    print(f"所有结果将保存在: {results_dir}")

    instructions = read_instructions(INSTRUCTIONS_FILE)
    if not instructions:
        print("⚠️ 没有可执行的指令, 测试结束")
        return

    print(f"找到 {len(instructions)} 条指令")

    results_summary = [_new_summary() for _ in instructions]
    detailed_results = []

    for instruction_idx, prompt in enumerate(instructions, 1):
        print(f"\n{'='*60}")
        print(f"📋 测试指令 {instruction_idx}/{len(instructions)}")
        print(f"指令: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
        print(f"{'='*60}")

        # One sub-directory per instruction for its JSON/PNG artifacts.
        instruction_dir = os.path.join(results_dir, f"instruction_{instruction_idx}")
        os.makedirs(instruction_dir, exist_ok=True)
        summary = results_summary[instruction_idx - 1]

        for run_number in range(1, TESTS_PER_INSTRUCTION + 1):
            print(f" 运行 {run_number}/{TESTS_PER_INSTRUCTION}...", end=" ", flush=True)

            result = send_api_request(prompt, instruction_idx, run_number)
            write_log_entry(log_file_path, instruction_idx, run_number, prompt, result)

            data = result.get('data')
            plan_id = data.get('plan_id') if isinstance(data, dict) else None

            detailed_results.append({
                "instruction_index": instruction_idx,
                "instruction": prompt,
                "run_number": run_number,
                "success": result["success"],
                "attempts": result["attempts"],
                "response_time": result["response_time"],
                "plan_id": plan_id,
                "http_status": result.get("http_status"),
                "error": result["error"] or "",
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })
            summary['http_statuses'].append(result.get('http_status'))

            if result["success"]:
                _record_success(summary, result['response_time'])
                _save_run_artifacts(instruction_dir, run_number, plan_id, data)
                print(f"✅ 成功 ({result['response_time']:.1f}s)")
            else:
                print(f"❌ 失败 (HTTP: {result.get('http_status', 'N/A')})")

            time.sleep(RUN_INTERVAL)

    _write_results_csv(detailed_results, results_csv_path)
    generate_summary_report(instructions, results_summary, summary_csv_path)
    _print_final_stats(instructions, results_summary, results_dir)


if __name__ == "__main__":
    main()