#!/usr/bin/env python # -*- coding: utf-8 -*- """ 简单验证脚本:检查车辆是否正确获取观测空间 用法:python verify_observations.py """ import argparse from scenario_env import MultiAgentScenarioEnv from replay_policy import ReplayPolicy from metadrive.engine.asset_loader import AssetLoader import numpy as np def verify_observations(data_dir, scenario_id=0): """ 验证观测空间是否正确获取 Args: data_dir: 数据目录 scenario_id: 场景ID """ print("=" * 60) print("观测空间验证工具") print("=" * 60) # 创建环境 env = MultiAgentScenarioEnv( config={ "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False), "is_multi_agent": True, "horizon": 300, "use_render": False, # 不渲染,加速运行 "sequential_seed": True, "reactive_traffic": False, "manual_control": False, "filter_offroad_vehicles": True, "lane_tolerance": 3.0, "replay_mode": True, "debug": True, # 启用调试以查看详细信息 "specific_scenario_id": scenario_id, "use_scenario_duration": True, }, agent2policy=None ) # 重置环境 print(f"\n加载场景 {scenario_id}...") obs = env.reset(seed=scenario_id) # 输出基本信息 print(f"\n场景信息:") print(f" - 可控车辆数: {len(env.controlled_agents)}") print(f" - 观测数量: {len(obs)}") print(f" - 场景时长: {env.scenario_max_duration} 步") print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}") print(f" - 当前回合数 (round): {env.round}") # 检查车辆生成信息 if len(env.car_birth_info_list) > 0: print(f"\n车辆生成信息分析:") show_times = [car['show_time'] for car in env.car_birth_info_list] print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}") print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}") print(f" - 前5个车辆的 show_time: {show_times[:5]}") else: print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:") print(f" 1. 所有车辆都被车道过滤移除") print(f" 2. 所有车辆都被类型过滤移除") print(f" 3. 场景数据中没有有效车辆") # 验证观测空间 print(f"\n" + "=" * 60) print("观测空间验证") print("=" * 60) if len(obs) == 0: print("❌ 错误:没有获取到任何观测!") env.close() return False # 检查第一个观测 first_obs = obs[0] print(f"\n第一个车辆的观测:") print(f" - 观测类型: {type(first_obs)}") print(f" - 观测维度: {len(first_obs)}") # 详细解析观测 if isinstance(first_obs, (list, np.ndarray)): obs_array = np.array(first_obs) print(f" - 观测形状: {obs_array.shape}") print(f" - 数据类型: {obs_array.dtype}") print(f"\n观测内容分解:") # 新观测划分(与 Env/scenario_env._get_all_obs 对齐): # [x, y] (2) + [vx, vy] (2) + heading (1) # + nearest vehicles info (10 vehicles * 4 = 40) # + side_lidar (10) + lane_line_lidar (10) # + traffic_light (1) + destination (2) lidar_len = 40 side_len = 10 lane_len = 10 pos_slice = slice(0, 2) vel_slice = slice(2, 4) heading_idx = 4 lidar_slice = slice(5, 5 + lidar_len) side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len) lane_slice = slice(side_slice.stop, side_slice.stop + lane_len) tl_idx = lane_slice.stop dest_slice = slice(tl_idx + 1, tl_idx + 3) print(f" 位置 (x, y): {obs_array[pos_slice]}") print(f" 速度 (vx, vy): {obs_array[vel_slice]}") print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度") print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)") print(f" 侧向检测: {len(obs_array[side_slice])} 个点") print(f" 车道线检测: {len(obs_array[lane_slice])} 个点") print(f" 交通灯状态: {obs_array[tl_idx]}") print(f" 目的地 (x, y): {obs_array[dest_slice]}") # 检查数据有效性 print(f"\n数据有效性检查:") has_nan = np.isnan(obs_array).any() has_inf = np.isinf(obs_array).any() if has_nan: print(f" ❌ 观测包含 NaN 值!") else: print(f" ✅ 无 NaN 值") if has_inf: print(f" ❌ 观测包含 Inf 值!") else: print(f" ✅ 无 Inf 值") # 检查最近车辆信息数据(40维) lidar_data = obs_array[lidar_slice] lidar_min = np.min(lidar_data) lidar_max = np.max(lidar_data) print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]") if lidar_max > 0: print(f" ✅ 存在非零最近车辆特征") else: print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)") # 运行几步,验证观测持续有效 print(f"\n" + "=" * 60) print("多步运行验证(前 5 步)") print("=" * 60) for step in range(5): # 空动作 actions = {aid: [0.0, 0.0] for aid in env.controlled_agents} obs, rewards, dones, infos = env.step(actions) print(f"\nStep {step + 1}:") print(f" - 活跃车辆数: {len(env.controlled_agents)}") print(f" - 观测数量: {len(obs)}") if len(obs) > 0: sample_obs = np.array(obs[0]) print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})") print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}") if dones["__all__"]: print(f" - 场景结束") break env.close() print(f"\n" + "=" * 60) print("✅ 验证完成!观测空间正常工作") print("=" * 60) return True if __name__ == "__main__": parser = argparse.ArgumentParser(description="验证观测空间") parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录") parser.add_argument("--scenario_id", type=int, default=0, help="场景ID") args = parser.parse_args() verify_observations(args.data_dir, args.scenario_id)