130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
import os
|
|
import json
|
|
import time
|
|
import csv
|
|
import shutil
|
|
from datetime import datetime
|
|
from .api_client import APIClient
|
|
from .visualizer import generate_visualization, sanitize_filename
|
|
|
|
def _choose_instruction_file(instr_dir):
    """Interactively pick one ``.txt`` instruction file from *instr_dir*.

    Returns the full path of the chosen file, or ``None`` after printing an
    error when the directory is missing, holds no ``.txt`` files, or the
    user's input is not a valid menu number.
    """
    try:
        # Sorted so the numbered menu is stable: os.listdir order is arbitrary.
        files = sorted(f for f in os.listdir(instr_dir) if f.endswith('.txt'))
    except (FileNotFoundError, NotADirectoryError):
        files = []

    if not files:
        print("❌ 未在 instructions 目录下找到 .txt 文件")
        return None

    print("\n请选择测试指令文件:")
    for number, name in enumerate(files, 1):
        print(f"{number}. {name}")

    try:
        idx = int(input("请输入序号: ").strip()) - 1
    except ValueError:
        print("❌ 输入无效")
        return None

    if idx < 0 or idx >= len(files):
        print("❌ 无效序号")
        return None
    return os.path.join(instr_dir, files[idx])


def _parse_iterations(raw):
    """Parse the user-entered per-instruction repeat count.

    Blank input means 1. Non-integer input and values below 1 also fall back
    to 1: a count of 0 would otherwise make the success-rate computation
    divide by zero.
    """
    try:
        value = int(raw.strip() or "1")
    except ValueError:
        return 1
    return max(1, value)


def _load_instructions(path):
    """Return the stripped, non-empty, non-comment lines of *path*.

    A line is a comment when its first non-blank character is ``#``. The file
    is decoded as ``utf-8-sig`` so a leading BOM does not stick to the first
    line (which would hide a ``#`` comment there); BOM-less UTF-8 reads the
    same as before.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith('#')]


def _run_instruction(client, prompt, iterations, out_dir):
    """Send *prompt* ``iterations`` times and save artifacts of successful runs.

    ``client.send_request(prompt)`` is read here as returning a dict with the
    keys ``success``, ``latency``, ``error`` and ``data``. For each successful
    run, ``{k}.json`` (the ``data`` payload) is written to *out_dir*. When the
    payload has a ``root`` key, ``{k}.png`` is rendered as well.

    Returns ``(run_rows, success_count, total_latency)``:
    - ``run_rows`` holds one ``test_details.csv`` row dict per attempt.
    - ``total_latency`` sums only the successful runs.
    """
    run_rows = []
    success_count = 0
    total_latency = 0

    for k in range(1, iterations + 1):
        print(f" - 第 {k} 次...", end="", flush=True)
        result = client.send_request(prompt)

        # Record every attempt, successful or not.
        run_rows.append({
            "instruction": prompt,
            "run_id": k,
            "success": result['success'],
            "latency": result['latency'],
            "error": result['error'] or "",
        })

        if result['success']:
            print(f" ✅ ({result['latency']:.2f}s)")
            success_count += 1
            total_latency += result['latency']

            # Save the raw response payload as JSON.
            with open(os.path.join(out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
                json.dump(result['data'], f, indent=2, ensure_ascii=False)

            # Render an image only when the payload carries a 'root' entry.
            if result['data'] and 'root' in result['data']:
                generate_visualization(result['data']['root'], os.path.join(out_dir, f"{k}.png"))
        else:
            print(f" ❌ {result['error']}")

        time.sleep(0.5)  # Throttle so requests are not fired too quickly.

    return run_rows, success_count, total_latency


def _write_reports(output_dir, detailed_results, summary_rows):
    """Write ``test_details.csv`` and ``test_summary.csv`` into *output_dir*.

    - ``test_details.csv`` has one row per run, taken from the row dicts in
      *detailed_results*.
    - ``test_summary.csv`` has one row per instruction occurrence, taken from
      the lists in *summary_rows*.
    """
    with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
        writer.writeheader()
        writer.writerows(detailed_results)

    with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
        writer.writerows(summary_rows)


def run_batch_test():
    """Interactively run a batch of API test instructions and write reports.

    Steps:
    1. The user picks a ``.txt`` file from ``<base_dir>/instructions`` and a
       per-instruction repeat count. ``<base_dir>`` is the parent of this
       module's directory.
    2. A backup of the file is copied into
       ``<base_dir>/validation/<YYYYmmddHHMMSS>/``.
    3. Each instruction is sent through ``APIClient``, with per-run
       artifacts saved under a per-instruction subdirectory.
    4. Detail and summary CSV reports are written to the same directory.
    """
    # 1. Choose the instruction file.
    base_dir = os.path.dirname(os.path.dirname(__file__))
    instr_dir = os.path.join(base_dir, "instructions")
    selected_file = _choose_instruction_file(instr_dir)
    if selected_file is None:
        return

    # 2. Configure parameters.
    iterations = _parse_iterations(input("请输入每条指令的测试次数 (默认1): "))

    # 3. Prepare the timestamped output directory.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = os.path.join(base_dir, "validation", timestamp)
    os.makedirs(output_dir, exist_ok=True)

    # Keep a copy of the exact instruction file used for this run.
    shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))

    instructions = _load_instructions(selected_file)

    print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
    print(f"📂 输出目录: {output_dir}\n")

    client = APIClient()
    detailed_results = []
    # A list rather than a dict keyed by prompt, so repeated instructions each
    # keep their own summary row instead of overwriting one another.
    summary_rows = []
    used_dir_names = set()

    for i, prompt in enumerate(instructions, 1):
        print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")

        # An empty sanitized name would otherwise dump artifacts into output_dir.
        safe_name = sanitize_filename(prompt) or f"instruction_{i}"
        if safe_name in used_dir_names:
            # Repeated (or identically-sanitized) prompts would otherwise
            # overwrite each other's {k}.json / {k}.png files.
            safe_name = f"{safe_name}_{i}"
        used_dir_names.add(safe_name)

        instr_out_dir = os.path.join(output_dir, safe_name)
        os.makedirs(instr_out_dir, exist_ok=True)

        run_rows, success_count, total_latency = _run_instruction(
            client, prompt, iterations, instr_out_dir
        )
        detailed_results.extend(run_rows)

        # Per-instruction statistics; average latency is over successful runs only.
        avg_lat = total_latency / success_count if success_count > 0 else 0
        summary_rows.append([
            prompt,
            iterations,
            success_count,
            f"{(success_count / iterations) * 100:.1f}%",
            f"{avg_lat:.2f}s",
        ])

    # 4. Generate reports.
    _write_reports(output_dir, detailed_results, summary_rows)

    print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")
|
|
|