代替点云

This commit is contained in:
2025-12-07 20:15:35 +08:00
parent 113e86bda2
commit 8cb86115c2
10 changed files with 527 additions and 45 deletions

184
Env/verify_observation.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
简单验证脚本:检查车辆是否正确获取观测空间
用法python verify_observations.py
"""
import argparse
from scenario_env import MultiAgentScenarioEnv
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
import numpy as np
def verify_observations(data_dir, scenario_id=0):
"""
验证观测空间是否正确获取
Args:
data_dir: 数据目录
scenario_id: 场景ID
"""
print("=" * 60)
print("观测空间验证工具")
print("=" * 60)
# 创建环境
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": 300,
"use_render": False, # 不渲染,加速运行
"sequential_seed": True,
"reactive_traffic": False,
"manual_control": False,
"filter_offroad_vehicles": True,
"lane_tolerance": 3.0,
"replay_mode": True,
"debug": True, # 启用调试以查看详细信息
"specific_scenario_id": scenario_id,
"use_scenario_duration": True,
},
agent2policy=None
)
# 重置环境
print(f"\n加载场景 {scenario_id}...")
obs = env.reset(seed=scenario_id)
# 输出基本信息
print(f"\n场景信息:")
print(f" - 可控车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
print(f" - 场景时长: {env.scenario_max_duration}")
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
print(f" - 当前回合数 (round): {env.round}")
# 检查车辆生成信息
if len(env.car_birth_info_list) > 0:
print(f"\n车辆生成信息分析:")
show_times = [car['show_time'] for car in env.car_birth_info_list]
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
else:
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
print(f" 1. 所有车辆都被车道过滤移除")
print(f" 2. 所有车辆都被类型过滤移除")
print(f" 3. 场景数据中没有有效车辆")
# 验证观测空间
print(f"\n" + "=" * 60)
print("观测空间验证")
print("=" * 60)
if len(obs) == 0:
print("❌ 错误:没有获取到任何观测!")
env.close()
return False
# 检查第一个观测
first_obs = obs[0]
print(f"\n第一个车辆的观测:")
print(f" - 观测类型: {type(first_obs)}")
print(f" - 观测维度: {len(first_obs)}")
# 详细解析观测
if isinstance(first_obs, (list, np.ndarray)):
obs_array = np.array(first_obs)
print(f" - 观测形状: {obs_array.shape}")
print(f" - 数据类型: {obs_array.dtype}")
print(f"\n观测内容分解:")
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
# [x, y] (2) + [vx, vy] (2) + heading (1)
# + nearest vehicles info (10 vehicles * 4 = 40)
# + side_lidar (10) + lane_line_lidar (10)
# + traffic_light (1) + destination (2)
lidar_len = 40
side_len = 10
lane_len = 10
pos_slice = slice(0, 2)
vel_slice = slice(2, 4)
heading_idx = 4
lidar_slice = slice(5, 5 + lidar_len)
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
tl_idx = lane_slice.stop
dest_slice = slice(tl_idx + 1, tl_idx + 3)
print(f" 位置 (x, y): {obs_array[pos_slice]}")
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
print(f" 交通灯状态: {obs_array[tl_idx]}")
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
# 检查数据有效性
print(f"\n数据有效性检查:")
has_nan = np.isnan(obs_array).any()
has_inf = np.isinf(obs_array).any()
if has_nan:
print(f" ❌ 观测包含 NaN 值!")
else:
print(f" ✅ 无 NaN 值")
if has_inf:
print(f" ❌ 观测包含 Inf 值!")
else:
print(f" ✅ 无 Inf 值")
# 检查最近车辆信息数据40维
lidar_data = obs_array[lidar_slice]
lidar_min = np.min(lidar_data)
lidar_max = np.max(lidar_data)
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
if lidar_max > 0:
print(f" ✅ 存在非零最近车辆特征")
else:
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
# 运行几步,验证观测持续有效
print(f"\n" + "=" * 60)
print("多步运行验证(前 5 步)")
print("=" * 60)
for step in range(5):
# 空动作
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
obs, rewards, dones, infos = env.step(actions)
print(f"\nStep {step + 1}:")
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
if len(obs) > 0:
sample_obs = np.array(obs[0])
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
print(f" - 数据有效: {'' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else ''}")
if dones["__all__"]:
print(f" - 场景结束")
break
env.close()
print(f"\n" + "=" * 60)
print("✅ 验证完成!观测空间正常工作")
print("=" * 60)
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="验证观测空间")
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
args = parser.parse_args()
verify_observations(args.data_dir, args.scenario_id)