优化交互式测试验证脚本,针对场景4修改提示词以及代码
This commit is contained in:
129
tools/test_validate/modules/batch_runner.py
Normal file
129
tools/test_validate/modules/batch_runner.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import csv
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from .api_client import APIClient
|
||||
from .visualizer import generate_visualization, sanitize_filename
|
||||
|
||||
def run_batch_test():
|
||||
# 1. 选择指令文件
|
||||
base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
instr_dir = os.path.join(base_dir, "instructions")
|
||||
files = [f for f in os.listdir(instr_dir) if f.endswith('.txt')]
|
||||
|
||||
if not files:
|
||||
print("❌ 未在 instructions 目录下找到 .txt 文件")
|
||||
return
|
||||
|
||||
print("\n请选择测试指令文件:")
|
||||
for i, f in enumerate(files):
|
||||
print(f"{i+1}. {f}")
|
||||
|
||||
try:
|
||||
idx = int(input("请输入序号: ").strip()) - 1
|
||||
if idx < 0 or idx >= len(files):
|
||||
print("❌ 无效序号")
|
||||
return
|
||||
selected_file = os.path.join(instr_dir, files[idx])
|
||||
except ValueError:
|
||||
print("❌ 输入无效")
|
||||
return
|
||||
|
||||
# 2. 配置参数
|
||||
try:
|
||||
iterations = int(input("请输入每条指令的测试次数 (默认1): ").strip() or "1")
|
||||
except ValueError:
|
||||
iterations = 1
|
||||
|
||||
# 3. 准备输出目录
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
output_dir = os.path.join(base_dir, "validation", timestamp)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 备份指令文件
|
||||
shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))
|
||||
|
||||
# 读取指令
|
||||
with open(selected_file, 'r', encoding='utf-8') as f:
|
||||
instructions = [line.strip() for line in f if line.strip() and not line.startswith('#')]
|
||||
|
||||
print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
|
||||
print(f"📂 输出目录: {output_dir}\n")
|
||||
|
||||
client = APIClient()
|
||||
detailed_results = []
|
||||
summary_stats = {}
|
||||
|
||||
for i, prompt in enumerate(instructions, 1):
|
||||
print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")
|
||||
safe_name = sanitize_filename(prompt)
|
||||
instr_out_dir = os.path.join(output_dir, safe_name)
|
||||
os.makedirs(instr_out_dir, exist_ok=True)
|
||||
|
||||
success_count = 0
|
||||
total_latency = 0
|
||||
|
||||
for k in range(1, iterations + 1):
|
||||
print(f" - 第 {k} 次...", end="", flush=True)
|
||||
result = client.send_request(prompt)
|
||||
|
||||
# 记录详情
|
||||
detailed_results.append({
|
||||
"instruction": prompt,
|
||||
"run_id": k,
|
||||
"success": result['success'],
|
||||
"latency": result['latency'],
|
||||
"error": result['error'] or ""
|
||||
})
|
||||
|
||||
# 保存产物
|
||||
if result['success']:
|
||||
print(f" ✅ ({result['latency']:.2f}s)")
|
||||
success_count += 1
|
||||
total_latency += result['latency']
|
||||
|
||||
# 保存JSON
|
||||
with open(os.path.join(instr_out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
|
||||
json.dump(result['data'], f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 保存图片
|
||||
if result['data'] and 'root' in result['data']:
|
||||
generate_visualization(result['data']['root'], os.path.join(instr_out_dir, f"{k}.png"))
|
||||
else:
|
||||
print(f" ❌ {result['error']}")
|
||||
|
||||
time.sleep(0.5) # 避免过快请求
|
||||
|
||||
# 统计单条指令
|
||||
avg_lat = total_latency / success_count if success_count > 0 else 0
|
||||
summary_stats[prompt] = {
|
||||
"total_runs": iterations,
|
||||
"success_runs": success_count,
|
||||
"success_rate": f"{(success_count/iterations)*100:.1f}%",
|
||||
"avg_latency": f"{avg_lat:.2f}s"
|
||||
}
|
||||
|
||||
# 4. 生成报告
|
||||
# 详细报告
|
||||
with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
|
||||
writer.writeheader()
|
||||
writer.writerows(detailed_results)
|
||||
|
||||
# 摘要报告
|
||||
with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
|
||||
for prompt, stats in summary_stats.items():
|
||||
writer.writerow([
|
||||
prompt,
|
||||
stats["total_runs"],
|
||||
stats["success_runs"],
|
||||
stats["success_rate"],
|
||||
stats["avg_latency"]
|
||||
])
|
||||
|
||||
print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")
|
||||
|
||||
Reference in New Issue
Block a user