Files
MAGAIL4AutoDrive/Env/verify_observation.py
2025-12-07 20:15:35 +08:00

185 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
简单验证脚本:检查车辆是否正确获取观测空间
用法python verify_observations.py
"""
import argparse
from scenario_env import MultiAgentScenarioEnv
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
import numpy as np
def verify_observations(data_dir, scenario_id=0):
"""
验证观测空间是否正确获取
Args:
data_dir: 数据目录
scenario_id: 场景ID
"""
print("=" * 60)
print("观测空间验证工具")
print("=" * 60)
# 创建环境
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": 300,
"use_render": False, # 不渲染,加速运行
"sequential_seed": True,
"reactive_traffic": False,
"manual_control": False,
"filter_offroad_vehicles": True,
"lane_tolerance": 3.0,
"replay_mode": True,
"debug": True, # 启用调试以查看详细信息
"specific_scenario_id": scenario_id,
"use_scenario_duration": True,
},
agent2policy=None
)
# 重置环境
print(f"\n加载场景 {scenario_id}...")
obs = env.reset(seed=scenario_id)
# 输出基本信息
print(f"\n场景信息:")
print(f" - 可控车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
print(f" - 场景时长: {env.scenario_max_duration}")
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
print(f" - 当前回合数 (round): {env.round}")
# 检查车辆生成信息
if len(env.car_birth_info_list) > 0:
print(f"\n车辆生成信息分析:")
show_times = [car['show_time'] for car in env.car_birth_info_list]
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
else:
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
print(f" 1. 所有车辆都被车道过滤移除")
print(f" 2. 所有车辆都被类型过滤移除")
print(f" 3. 场景数据中没有有效车辆")
# 验证观测空间
print(f"\n" + "=" * 60)
print("观测空间验证")
print("=" * 60)
if len(obs) == 0:
print("❌ 错误:没有获取到任何观测!")
env.close()
return False
# 检查第一个观测
first_obs = obs[0]
print(f"\n第一个车辆的观测:")
print(f" - 观测类型: {type(first_obs)}")
print(f" - 观测维度: {len(first_obs)}")
# 详细解析观测
if isinstance(first_obs, (list, np.ndarray)):
obs_array = np.array(first_obs)
print(f" - 观测形状: {obs_array.shape}")
print(f" - 数据类型: {obs_array.dtype}")
print(f"\n观测内容分解:")
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
# [x, y] (2) + [vx, vy] (2) + heading (1)
# + nearest vehicles info (10 vehicles * 4 = 40)
# + side_lidar (10) + lane_line_lidar (10)
# + traffic_light (1) + destination (2)
lidar_len = 40
side_len = 10
lane_len = 10
pos_slice = slice(0, 2)
vel_slice = slice(2, 4)
heading_idx = 4
lidar_slice = slice(5, 5 + lidar_len)
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
tl_idx = lane_slice.stop
dest_slice = slice(tl_idx + 1, tl_idx + 3)
print(f" 位置 (x, y): {obs_array[pos_slice]}")
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
print(f" 交通灯状态: {obs_array[tl_idx]}")
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
# 检查数据有效性
print(f"\n数据有效性检查:")
has_nan = np.isnan(obs_array).any()
has_inf = np.isinf(obs_array).any()
if has_nan:
print(f" ❌ 观测包含 NaN 值!")
else:
print(f" ✅ 无 NaN 值")
if has_inf:
print(f" ❌ 观测包含 Inf 值!")
else:
print(f" ✅ 无 Inf 值")
# 检查最近车辆信息数据40维
lidar_data = obs_array[lidar_slice]
lidar_min = np.min(lidar_data)
lidar_max = np.max(lidar_data)
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
if lidar_max > 0:
print(f" ✅ 存在非零最近车辆特征")
else:
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
# 运行几步,验证观测持续有效
print(f"\n" + "=" * 60)
print("多步运行验证(前 5 步)")
print("=" * 60)
for step in range(5):
# 空动作
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
obs, rewards, dones, infos = env.step(actions)
print(f"\nStep {step + 1}:")
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
if len(obs) > 0:
sample_obs = np.array(obs[0])
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
print(f" - 数据有效: {'' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else ''}")
if dones["__all__"]:
print(f" - 场景结束")
break
env.close()
print(f"\n" + "=" * 60)
print("✅ 验证完成!观测空间正常工作")
print("=" * 60)
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="验证观测空间")
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
args = parser.parse_args()
verify_observations(args.data_dir, args.scenario_id)