185 lines
6.5 KiB
Python
185 lines
6.5 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
简单验证脚本:检查车辆是否正确获取观测空间
|
||
用法:python verify_observations.py
|
||
"""
|
||
|
||
import argparse
|
||
from scenario_env import MultiAgentScenarioEnv
|
||
from replay_policy import ReplayPolicy
|
||
from metadrive.engine.asset_loader import AssetLoader
|
||
import numpy as np
|
||
|
||
def verify_observations(data_dir, scenario_id=0):
|
||
"""
|
||
验证观测空间是否正确获取
|
||
|
||
Args:
|
||
data_dir: 数据目录
|
||
scenario_id: 场景ID
|
||
"""
|
||
print("=" * 60)
|
||
print("观测空间验证工具")
|
||
print("=" * 60)
|
||
|
||
# 创建环境
|
||
env = MultiAgentScenarioEnv(
|
||
config={
|
||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||
"is_multi_agent": True,
|
||
"horizon": 300,
|
||
"use_render": False, # 不渲染,加速运行
|
||
"sequential_seed": True,
|
||
"reactive_traffic": False,
|
||
"manual_control": False,
|
||
"filter_offroad_vehicles": True,
|
||
"lane_tolerance": 3.0,
|
||
"replay_mode": True,
|
||
"debug": True, # 启用调试以查看详细信息
|
||
"specific_scenario_id": scenario_id,
|
||
"use_scenario_duration": True,
|
||
},
|
||
agent2policy=None
|
||
)
|
||
|
||
# 重置环境
|
||
print(f"\n加载场景 {scenario_id}...")
|
||
obs = env.reset(seed=scenario_id)
|
||
|
||
# 输出基本信息
|
||
print(f"\n场景信息:")
|
||
print(f" - 可控车辆数: {len(env.controlled_agents)}")
|
||
print(f" - 观测数量: {len(obs)}")
|
||
print(f" - 场景时长: {env.scenario_max_duration} 步")
|
||
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
|
||
print(f" - 当前回合数 (round): {env.round}")
|
||
|
||
# 检查车辆生成信息
|
||
if len(env.car_birth_info_list) > 0:
|
||
print(f"\n车辆生成信息分析:")
|
||
show_times = [car['show_time'] for car in env.car_birth_info_list]
|
||
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
|
||
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
|
||
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
|
||
else:
|
||
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
|
||
print(f" 1. 所有车辆都被车道过滤移除")
|
||
print(f" 2. 所有车辆都被类型过滤移除")
|
||
print(f" 3. 场景数据中没有有效车辆")
|
||
|
||
# 验证观测空间
|
||
print(f"\n" + "=" * 60)
|
||
print("观测空间验证")
|
||
print("=" * 60)
|
||
|
||
if len(obs) == 0:
|
||
print("❌ 错误:没有获取到任何观测!")
|
||
env.close()
|
||
return False
|
||
|
||
# 检查第一个观测
|
||
first_obs = obs[0]
|
||
print(f"\n第一个车辆的观测:")
|
||
print(f" - 观测类型: {type(first_obs)}")
|
||
print(f" - 观测维度: {len(first_obs)}")
|
||
|
||
# 详细解析观测
|
||
if isinstance(first_obs, (list, np.ndarray)):
|
||
obs_array = np.array(first_obs)
|
||
print(f" - 观测形状: {obs_array.shape}")
|
||
print(f" - 数据类型: {obs_array.dtype}")
|
||
print(f"\n观测内容分解:")
|
||
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
|
||
# [x, y] (2) + [vx, vy] (2) + heading (1)
|
||
# + nearest vehicles info (10 vehicles * 4 = 40)
|
||
# + side_lidar (10) + lane_line_lidar (10)
|
||
# + traffic_light (1) + destination (2)
|
||
lidar_len = 40
|
||
side_len = 10
|
||
lane_len = 10
|
||
pos_slice = slice(0, 2)
|
||
vel_slice = slice(2, 4)
|
||
heading_idx = 4
|
||
lidar_slice = slice(5, 5 + lidar_len)
|
||
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
|
||
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
|
||
tl_idx = lane_slice.stop
|
||
dest_slice = slice(tl_idx + 1, tl_idx + 3)
|
||
|
||
print(f" 位置 (x, y): {obs_array[pos_slice]}")
|
||
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
|
||
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
|
||
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
|
||
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
|
||
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
|
||
print(f" 交通灯状态: {obs_array[tl_idx]}")
|
||
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
|
||
|
||
# 检查数据有效性
|
||
print(f"\n数据有效性检查:")
|
||
has_nan = np.isnan(obs_array).any()
|
||
has_inf = np.isinf(obs_array).any()
|
||
|
||
if has_nan:
|
||
print(f" ❌ 观测包含 NaN 值!")
|
||
else:
|
||
print(f" ✅ 无 NaN 值")
|
||
|
||
if has_inf:
|
||
print(f" ❌ 观测包含 Inf 值!")
|
||
else:
|
||
print(f" ✅ 无 Inf 值")
|
||
|
||
# 检查最近车辆信息数据(40维)
|
||
lidar_data = obs_array[lidar_slice]
|
||
lidar_min = np.min(lidar_data)
|
||
lidar_max = np.max(lidar_data)
|
||
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
|
||
|
||
if lidar_max > 0:
|
||
print(f" ✅ 存在非零最近车辆特征")
|
||
else:
|
||
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
|
||
|
||
# 运行几步,验证观测持续有效
|
||
print(f"\n" + "=" * 60)
|
||
print("多步运行验证(前 5 步)")
|
||
print("=" * 60)
|
||
|
||
for step in range(5):
|
||
# 空动作
|
||
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||
obs, rewards, dones, infos = env.step(actions)
|
||
|
||
print(f"\nStep {step + 1}:")
|
||
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
|
||
print(f" - 观测数量: {len(obs)}")
|
||
|
||
if len(obs) > 0:
|
||
sample_obs = np.array(obs[0])
|
||
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
|
||
print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
|
||
|
||
if dones["__all__"]:
|
||
print(f" - 场景结束")
|
||
break
|
||
|
||
env.close()
|
||
|
||
print(f"\n" + "=" * 60)
|
||
print("✅ 验证完成!观测空间正常工作")
|
||
print("=" * 60)
|
||
|
||
return True
|
||
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description="验证观测空间")
|
||
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
|
||
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
|
||
|
||
args = parser.parse_args()
|
||
|
||
verify_observations(args.data_dir, args.scenario_id)
|