优化交互式测试验证脚本,针对场景4修改提示词以及代码

This commit is contained in:
2026-01-02 16:28:58 +08:00
parent c08cdfb339
commit 6f990e645d
31 changed files with 71855 additions and 184 deletions

View File

@@ -0,0 +1,60 @@
import requests
import time
import json
class APIClient:
def __init__(self, base_url="http://127.0.0.1:8000"):
self.base_url = base_url
self.endpoint = "/generate_plan"
def send_request(self, prompt, timeout=60):
"""
Sends a request to the API and returns a structured result.
Returns:
dict: {
"success": bool,
"data": dict or None,
"latency": float (seconds),
"error": str or None,
"http_status": int or None
}
"""
url = f"{self.base_url}{self.endpoint}"
payload = {"user_prompt": prompt}
headers = {"Content-Type": "application/json"}
start_time = time.time()
try:
response = requests.post(url, json=payload, headers=headers, timeout=timeout)
latency = time.time() - start_time
response.raise_for_status()
try:
data = response.json()
return {
"success": True,
"data": data,
"latency": latency,
"error": None,
"http_status": response.status_code
}
except json.JSONDecodeError:
return {
"success": False,
"data": None,
"latency": latency,
"error": f"Invalid JSON response: {response.text[:200]}",
"http_status": response.status_code
}
except requests.exceptions.RequestException as e:
latency = time.time() - start_time
return {
"success": False,
"data": None,
"latency": latency,
"error": str(e),
"http_status": getattr(e.response, 'status_code', None)
}

View File

@@ -0,0 +1,129 @@
import os
import json
import time
import csv
import shutil
from datetime import datetime
from .api_client import APIClient
from .visualizer import generate_visualization, sanitize_filename
def run_batch_test():
# 1. 选择指令文件
base_dir = os.path.dirname(os.path.dirname(__file__))
instr_dir = os.path.join(base_dir, "instructions")
files = [f for f in os.listdir(instr_dir) if f.endswith('.txt')]
if not files:
print("❌ 未在 instructions 目录下找到 .txt 文件")
return
print("\n请选择测试指令文件:")
for i, f in enumerate(files):
print(f"{i+1}. {f}")
try:
idx = int(input("请输入序号: ").strip()) - 1
if idx < 0 or idx >= len(files):
print("❌ 无效序号")
return
selected_file = os.path.join(instr_dir, files[idx])
except ValueError:
print("❌ 输入无效")
return
# 2. 配置参数
try:
iterations = int(input("请输入每条指令的测试次数 (默认1): ").strip() or "1")
except ValueError:
iterations = 1
# 3. 准备输出目录
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
output_dir = os.path.join(base_dir, "validation", timestamp)
os.makedirs(output_dir, exist_ok=True)
# 备份指令文件
shutil.copy(selected_file, os.path.join(output_dir, "instructions_backup.txt"))
# 读取指令
with open(selected_file, 'r', encoding='utf-8') as f:
instructions = [line.strip() for line in f if line.strip() and not line.startswith('#')]
print(f"\n🚀 开始批量测试 (共 {len(instructions)} 条指令, 每条 {iterations} 次)")
print(f"📂 输出目录: {output_dir}\n")
client = APIClient()
detailed_results = []
summary_stats = {}
for i, prompt in enumerate(instructions, 1):
print(f"[{i}/{len(instructions)}] 测试指令: {prompt[:30]}...")
safe_name = sanitize_filename(prompt)
instr_out_dir = os.path.join(output_dir, safe_name)
os.makedirs(instr_out_dir, exist_ok=True)
success_count = 0
total_latency = 0
for k in range(1, iterations + 1):
print(f" - 第 {k} 次...", end="", flush=True)
result = client.send_request(prompt)
# 记录详情
detailed_results.append({
"instruction": prompt,
"run_id": k,
"success": result['success'],
"latency": result['latency'],
"error": result['error'] or ""
})
# 保存产物
if result['success']:
print(f" ✅ ({result['latency']:.2f}s)")
success_count += 1
total_latency += result['latency']
# 保存JSON
with open(os.path.join(instr_out_dir, f"{k}.json"), 'w', encoding='utf-8') as f:
json.dump(result['data'], f, indent=2, ensure_ascii=False)
# 保存图片
if result['data'] and 'root' in result['data']:
generate_visualization(result['data']['root'], os.path.join(instr_out_dir, f"{k}.png"))
else:
print(f"{result['error']}")
time.sleep(0.5) # 避免过快请求
# 统计单条指令
avg_lat = total_latency / success_count if success_count > 0 else 0
summary_stats[prompt] = {
"total_runs": iterations,
"success_runs": success_count,
"success_rate": f"{(success_count/iterations)*100:.1f}%",
"avg_latency": f"{avg_lat:.2f}s"
}
# 4. 生成报告
# 详细报告
with open(os.path.join(output_dir, "test_details.csv"), 'w', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=["instruction", "run_id", "success", "latency", "error"])
writer.writeheader()
writer.writerows(detailed_results)
# 摘要报告
with open(os.path.join(output_dir, "test_summary.csv"), 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(["Instruction", "Total Runs", "Success Runs", "Success Rate", "Avg Latency"])
for prompt, stats in summary_stats.items():
writer.writerow([
prompt,
stats["total_runs"],
stats["success_runs"],
stats["success_rate"],
stats["avg_latency"]
])
print(f"\n✅ 测试完成! 统计报告已保存至 {output_dir}")

View File

@@ -0,0 +1,37 @@
import requests
import os
import sys
def upload_mission(drone_ip, file_path):
"""上传一个JSON任务文件到无人机"""
if not os.path.exists(file_path):
print(f"Error: File not found at {file_path}")
return
url = f"http://{drone_ip}:5000/missions"
print(f"正在上传 {file_path}{url} ...")
try:
with open(file_path, 'rb') as f:
files = {'file': (os.path.basename(file_path), f, 'application/json')}
response = requests.post(url, files=files, timeout=10)
# 检查HTTP响应状态码
response.raise_for_status()
print("上传成功!")
print("无人机端响应:", response.json())
except requests.exceptions.RequestException as e:
print(f"上传过程中发生错误: {e}")
if __name__ == '__main__':
if len(sys.argv) < 3:
print("用法: python ground_station_client.py [无人机IP地址] [JSON文件路径]")
print("示例: python ground_station_client.py 192.168.1.10 ./missions/rescue_mission.json")
sys.exit(1)
drone_ip_address = sys.argv[1]
mission_file_path = sys.argv[2]
upload_mission(drone_ip_address, mission_file_path)

View File

@@ -0,0 +1,87 @@
import os
import json
import logging
from .api_client import APIClient
from .visualizer import generate_visualization, sanitize_filename
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(message)s')
def run_interactive_test():
client = APIClient()
print("\n🚀 进入交互式测试模式 (输入 'exit''q' 退出)")
while True:
try:
prompt = input("\n请输入测试指令: ").strip()
if prompt.lower() in ['exit', 'q']:
break
if not prompt:
continue
print("⏳ 正在请求后端 API...")
result = client.send_request(prompt)
if result['success']:
print(f"✅ 请求成功 (耗时: {result['latency']:.2f}s)")
# 创建输出目录
sanitized_name = sanitize_filename(prompt)
output_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"validation",
"temporary",
sanitized_name
)
os.makedirs(output_dir, exist_ok=True)
# 保存 JSON
json_path = os.path.join(output_dir, "response.json")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(result['data'], f, indent=2, ensure_ascii=False)
# 保存日志
log_path = os.path.join(output_dir, "process.log")
with open(log_path, 'w', encoding='utf-8') as f:
f.write(f"Prompt: {prompt}\n")
f.write(f"Status: {result['http_status']}\n")
f.write(f"Latency: {result['latency']}\n")
f.write(f"Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
# 生成图片
if result['data'] and 'root' in result['data']:
png_path = os.path.join(output_dir, "plan.png")
if generate_visualization(result['data']['root'], png_path):
print(f"🖼️ 可视化图已生成: {png_path}")
else:
print("⚠️ 可视化生成失败")
print(f"📂 结果已保存至: {output_dir}")
else:
print(f"❌ 请求失败: {result['error']}")
# 即使失败也保存日志,以便排查
sanitized_name = sanitize_filename(prompt)
output_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"validation",
"temporary",
sanitized_name
)
os.makedirs(output_dir, exist_ok=True)
log_path = os.path.join(output_dir, "process.log")
with open(log_path, 'w', encoding='utf-8') as f:
f.write(f"Prompt: {prompt}\n")
f.write(f"Status: {result['http_status']}\n")
f.write(f"Latency: {result['latency']}\n")
f.write(f"Error: {result['error']}\n")
# 如果有部分数据,也记录下来
if result['data']:
f.write(f"Partial Response: {json.dumps(result['data'], ensure_ascii=False)}\n")
print(f"⚠️ 错误日志已保存至: {log_path}")
except KeyboardInterrupt:
print("\n已取消")
break

View File

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import argparse
from typing import Any, Dict
import requests
def build_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="调用本地 llama-server (OpenAI兼容) 进行推理,支持自定义系统/用户提示词"
)
parser.add_argument(
"--base-url",
default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
help="llama-server 的基础URL默认: http://127.0.0.1:8081/v1或环境变量 SIMPLE_BASE_URL",
)
parser.add_argument(
"--model",
default=os.getenv("SIMPLE_MODEL", "local-model"),
help="模型名称(默认: local-model或环境变量 SIMPLE_MODEL",
)
parser.add_argument(
"--system",
default="You are a helpful assistant.",
help="系统提示词system role",
)
parser.add_argument(
"--system-file",
default=None,
help="系统提示词文件路径txt若提供则覆盖 --system 的字符串",
)
parser.add_argument(
"--user",
default=None,
help="用户提示词user role若不传则从交互式输入读取",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="采样温度(默认: 0.2",
)
parser.add_argument(
"--max-tokens",
type=int,
default=4096,
help="最大生成Token数默认: 4096",
)
parser.add_argument(
"--timeout",
type=float,
default=120.0,
help="HTTP超时时间秒默认: 120",
)
parser.add_argument(
"--verbose",
action="store_true",
help="打印完整返回JSON",
)
return parser.parse_args()
def call_llama_server(
base_url: str,
model: str,
system_prompt: str,
user_prompt: str,
temperature: float,
max_tokens: int,
timeout: float,
) -> Dict[str, Any]:
endpoint = base_url.rstrip("/") + "/chat/completions"
headers: Dict[str, str] = {"Content-Type": "application/json"}
# 兼容需要API Key的代理/服务llama-server通常不强制
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
payload: Dict[str, Any] = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
"max_tokens": max_tokens,
}
resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
resp.raise_for_status()
return resp.json()
def main() -> None:
args = build_args()
user_prompt = args.user
if not user_prompt:
try:
user_prompt = input("请输入用户提示词: ")
except KeyboardInterrupt:
print("\n已取消。")
sys.exit(1)
# 解析系统提示词:优先使用 --system-file
system_prompt = args.system
if args.system_file:
try:
with open(args.system_file, "r", encoding="utf-8") as f:
system_prompt = f.read()
except Exception as e:
print("\n❌ 读取系统提示词文件失败:")
print(str(e))
sys.exit(1)
try:
print("--- llama-server 推理 ---")
print(f"Base URL: {args.base_url}")
print(f"Model: {args.model}")
if args.system_file:
print(f"System(from file): {args.system_file}")
else:
print(f"System: {system_prompt}")
print(f"User: {user_prompt}")
data = call_llama_server(
base_url=args.base_url,
model=args.model,
system_prompt=system_prompt,
user_prompt=user_prompt,
temperature=args.temperature,
max_tokens=args.max_tokens,
timeout=args.timeout,
)
if args.verbose:
print("\n完整返回JSON:")
print(json.dumps(data, ensure_ascii=False, indent=2))
# 尝试按OpenAI兼容格式提取assistant内容
content = None
try:
content = data["choices"][0]["message"]["content"]
except Exception:
pass
if content is not None:
print("\n模型输出:")
print(content)
else:
# 兜底打印
print("\n无法按OpenAI兼容字段解析内容原始返回如下")
print(json.dumps(data, ensure_ascii=False))
except requests.exceptions.RequestException as e:
print("\n❌ 请求失败:请确认 llama-server 已在 8081 端口启动并可访问。")
print(f"详情: {e}")
sys.exit(2)
except Exception as e:
print("\n❌ 发生未预期的错误:")
print(str(e))
sys.exit(3)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,287 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
从API测试日志中提取JSON响应并批量可视化
"""
import json
import os
import re
import logging
import platform
import random
import html
from typing import Dict, List, Tuple
from collections import defaultdict
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def sanitize_filename(text: str) -> str:
"""将文本转换为安全的文件名"""
# 移除或替换不安全的字符
text = re.sub(r'[<>:"/\\|?*]', '_', text)
# 限制长度
if len(text) > 100:
text = text[:100]
return text
def _pick_zh_font():
"""选择合适的中文字体"""
sys = platform.system()
if sys == "Windows":
return "Microsoft YaHei"
elif sys == "Darwin":
return "PingFang SC"
else:
return "Noto Sans CJK SC"
def _add_nodes_and_edges(node: dict, dot, parent_id: str | None = None) -> str:
"""递归辅助函数,用于添加节点和边。"""
try:
from graphviz import Digraph
except ImportError:
logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
return ""
current_id = f"{id(node)}_{random.randint(1000, 9999)}"
# 准备节点标签HTML-like正确换行与转义
name = html.escape(str(node.get('name', '')))
ntype = html.escape(str(node.get('type', '')))
label_parts = [f"<B>{name}</B> <FONT POINT-SIZE='10'><I>({ntype})</I></FONT>"]
# 格式化参数显示
params = node.get('params') or {}
if params:
params_lines = []
for key, value in params.items():
k = html.escape(str(key))
if isinstance(value, float):
value_str = f"{value:.2f}".rstrip('0').rstrip('.')
else:
value_str = str(value)
v = html.escape(value_str)
params_lines.append(f"{k}: {v}")
params_text = "<BR ALIGN='LEFT'/>".join(params_lines)
label_parts.append(f"<FONT POINT-SIZE='9' COLOR='#555555'>{params_text}</FONT>")
node_label = f"<{'<BR/>'.join(label_parts)}>"
# 根据类型设置节点样式和颜色
node_type = (node.get('type') or '').lower()
shape = 'ellipse'
style = 'filled'
fillcolor = '#e6e6e6' # 默认灰色填充
border_color = '#666666' # 默认描边色
if node_type == 'action':
shape = 'box'
style = 'rounded,filled'
fillcolor = "#cde4ff" # 浅蓝
elif node_type == 'condition':
shape = 'diamond'
style = 'filled'
fillcolor = "#fff2cc" # 浅黄
elif node_type == 'sequence':
shape = 'ellipse'
style = 'filled'
fillcolor = '#d5e8d4' # 绿色
elif node_type == 'selector':
shape = 'ellipse'
style = 'filled'
fillcolor = '#ffe6cc' # 橙色
elif node_type == 'parallel':
shape = 'ellipse'
style = 'filled'
fillcolor = '#e1d5e7' # 紫色
# 特别标记安全相关节点
if node.get('name') in ['battery_above', 'gps_status', 'SafetyMonitor']:
border_color = '#ff0000' # 红色边框突出显示安全节点
style = 'filled,bold' # 加粗
dot.node(current_id, label=node_label, shape=shape, style=style, fillcolor=fillcolor, color=border_color)
# 连接父节点
if parent_id:
dot.edge(parent_id, current_id)
# 递归处理子节点
children = node.get("children", [])
if not children:
return current_id
# 记录所有子节点的ID
child_ids = []
# 正确的递归连接:每个子节点都连接到当前节点
for child in children:
child_id = _add_nodes_and_edges(child, dot, current_id)
child_ids.append(child_id)
# 子节点同级排列(横向排布,更直观地表现同层)
if len(child_ids) > 1:
with dot.subgraph(name=f"rank_{current_id}") as s:
s.attr(rank='same')
for cid in child_ids:
s.node(cid)
return current_id
def generate_visualization(node: Dict, file_path: str):
"""
使用Graphviz将Pytree字典可视化并保存到指定路径。
"""
try:
from graphviz import Digraph
except ImportError:
logging.critical("错误未安装graphviz库。请运行: pip install graphviz")
return False
fontname = _pick_zh_font()
dot = Digraph('Pytree', comment='Drone Mission Plan')
dot.attr(rankdir='TB', label='Drone Mission Plan', fontsize='20', fontname=fontname)
dot.attr('node', shape='box', style='rounded,filled', fontname=fontname)
dot.attr('edge', fontname=fontname)
_add_nodes_and_edges(node, dot)
try:
# 确保输出目录存在
base_path, ext = os.path.splitext(file_path)
render_path = base_path if ext.lower() == '.png' else file_path
out_dir = os.path.dirname(render_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
# 保存为 .png 文件,并自动删除源码 .gv 文件
dot.render(render_path, format='png', cleanup=True, view=False)
return True
except Exception as e:
logging.error(f"❌ 生成可视化图形失败: {e}")
return False
# 保留旧的函数以兼容(如果有其他脚本引用)
def _visualize_pytree(node: Dict, file_path: str):
return generate_visualization(node, file_path)
def parse_log_file(log_file_path: str) -> Dict[str, List[Dict]]:
"""
解析日志文件提取原始指令和完整API响应JSON
返回: {原始指令: [JSON响应列表]}
"""
with open(log_file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 按分隔符分割条目
entries = re.split(r'={80,}', content)
results = defaultdict(list)
for entry in entries:
if not entry.strip():
continue
# 提取原始指令
instruction_match = re.search(r'原始指令:\s*(.+)', entry)
if not instruction_match:
continue
original_instruction = instruction_match.group(1).strip()
# 提取完整API响应JSON
json_match = re.search(r'完整API响应:\s*\n(\{.*\})', entry, re.DOTALL)
if not json_match:
logging.warning(f"未找到指令 '{original_instruction}' 的JSON响应")
continue
json_str = json_match.group(1).strip()
try:
json_obj = json.loads(json_str)
results[original_instruction].append(json_obj)
logging.info(f"成功提取指令 '{original_instruction}' 的JSON响应")
except json.JSONDecodeError as e:
logging.error(f"解析指令 '{original_instruction}' 的JSON失败: {e}")
continue
return results
def process_and_visualize(log_file_path: str, output_dir: str):
"""
处理日志文件并批量可视化
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 解析日志文件
logging.info(f"开始解析日志文件: {log_file_path}")
instruction_responses = parse_log_file(log_file_path)
logging.info(f"共找到 {len(instruction_responses)} 个不同的原始指令")
# 处理每个指令的所有响应
for instruction, responses in instruction_responses.items():
logging.info(f"\n处理指令: {instruction} (共 {len(responses)} 个响应)")
# 创建指令目录(使用安全的文件名)
safe_instruction_name = sanitize_filename(instruction)
instruction_dir = os.path.join(output_dir, safe_instruction_name)
os.makedirs(instruction_dir, exist_ok=True)
# 处理每个响应
for idx, response in enumerate(responses, 1):
try:
# 提取root节点
root_node = response.get('root')
if not root_node:
logging.warning(f"响应 #{idx} 没有root节点跳过")
continue
# 生成文件名
json_filename = f"response_{idx}.json"
png_filename = f"response_{idx}.png"
json_path = os.path.join(instruction_dir, json_filename)
png_path = os.path.join(instruction_dir, png_filename)
# 保存JSON文件
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(response, f, ensure_ascii=False, indent=2)
logging.info(f" 保存JSON: {json_filename}")
# 生成可视化
generate_visualization(root_node, png_path)
logging.info(f" 生成可视化: {png_filename}")
except Exception as e:
logging.error(f"处理响应 #{idx} 时出错: {e}")
continue
logging.info(f"\n✅ 所有处理完成!结果保存在: {output_dir}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="批量可视化API测试日志")
parser.add_argument("--log", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "api_test_log.txt"), help="日志文件路径")
parser.add_argument("--out", default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "validation"), help="输出目录")
args = parser.parse_args()
log_file = args.log
output_directory = args.out
print(f"日志文件: {log_file}")
print(f"输出目录: {output_directory}")
if os.path.exists(log_file):
process_and_visualize(log_file, output_directory)
else:
print(f"错误: 找不到日志文件 {log_file}")