代替点云
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
109
Env/check_dataset.py
Normal file
109
Env/check_dataset.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据集检查脚本:统计可用的场景数量
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import os
|
||||
from pathlib import Path
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
def check_dataset(data_dir, subfolder="exp_filtered"):
|
||||
"""
|
||||
检查数据集中的场景数量
|
||||
|
||||
Args:
|
||||
data_dir: 数据根目录
|
||||
subfolder: 数据子目录(exp_filtered 或 exp_converted)
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("数据集检查工具")
|
||||
print("=" * 80)
|
||||
|
||||
# 获取完整路径
|
||||
full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
|
||||
print(f"\n数据集路径: {full_path}")
|
||||
|
||||
# 检查文件结构
|
||||
if not os.path.exists(full_path):
|
||||
print(f"❌ 错误:路径不存在: {full_path}")
|
||||
return
|
||||
|
||||
print(f"✅ 路径存在")
|
||||
|
||||
# 读取数据集映射
|
||||
mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
|
||||
if os.path.exists(mapping_file):
|
||||
print(f"\n读取 dataset_mapping.pkl...")
|
||||
with open(mapping_file, 'rb') as f:
|
||||
dataset_mapping = pickle.load(f)
|
||||
|
||||
print(f"✅ 数据集映射文件存在")
|
||||
print(f" 映射的场景数量: {len(dataset_mapping)}")
|
||||
|
||||
# 统计各个子目录的分布
|
||||
subdirs = {}
|
||||
for filename, subdir in dataset_mapping.items():
|
||||
if subdir not in subdirs:
|
||||
subdirs[subdir] = []
|
||||
subdirs[subdir].append(filename)
|
||||
|
||||
print(f"\n场景分布:")
|
||||
for subdir, files in subdirs.items():
|
||||
print(f" {subdir}: {len(files)} 个场景")
|
||||
|
||||
# 读取数据集摘要
|
||||
summary_file = os.path.join(full_path, "dataset_summary.pkl")
|
||||
if os.path.exists(summary_file):
|
||||
print(f"\n读取 dataset_summary.pkl...")
|
||||
with open(summary_file, 'rb') as f:
|
||||
dataset_summary = pickle.load(f)
|
||||
|
||||
print(f"✅ 数据集摘要文件存在")
|
||||
print(f" 摘要的场景数量: {len(dataset_summary)}")
|
||||
|
||||
# 打印前几个场景的ID
|
||||
print(f"\n前10个场景ID:")
|
||||
for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
|
||||
if isinstance(info, dict):
|
||||
track_length = info.get('track_length', 'N/A')
|
||||
num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
|
||||
print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")
|
||||
|
||||
# 检查实际文件
|
||||
print(f"\n检查实际文件...")
|
||||
pkl_files = list(Path(full_path).rglob("*.pkl"))
|
||||
# 排除dataset_mapping和dataset_summary
|
||||
scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
|
||||
print(f" 实际场景文件数量: {len(scenario_files)}")
|
||||
|
||||
# 检查子目录
|
||||
print(f"\n子目录结构:")
|
||||
for item in os.listdir(full_path):
|
||||
item_path = os.path.join(full_path, item)
|
||||
if os.path.isdir(item_path):
|
||||
pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
|
||||
print(f" {item}/: {pkl_count} 个pkl文件")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("检查完成")
|
||||
print("=" * 80)
|
||||
|
||||
# 返回场景数量
|
||||
return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
|
||||
print("\n【检查 exp_filtered 数据集】")
|
||||
num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
|
||||
|
||||
print("\n\n【检查 exp_converted 数据集】")
|
||||
num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
|
||||
|
||||
print("\n\n总结:")
|
||||
print(f" exp_filtered: {num_filtered} 个场景")
|
||||
print(f" exp_converted: {num_converted} 个场景")
|
||||
|
||||
@@ -4,7 +4,7 @@ from simple_idm_policy import ConstantVelocityPolicy
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
|
||||
|
||||
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
@@ -40,10 +40,10 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
|
||||
# ✅ 环境创建移到循环外面,避免重复创建
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"horizon": horizon,
|
||||
"use_render": render,
|
||||
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False, # 回放模式下不需要反应式交通
|
||||
"manual_control": False,
|
||||
@@ -57,20 +57,33 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
# ✅ 关键:设置可用场景数量
|
||||
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||
},
|
||||
agent2policy=None # 回放模式不需要统一策略
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取可用场景数量
|
||||
num_scenarios = env.config.get("num_scenarios", 1)
|
||||
print(f"可用场景数量: {num_scenarios}")
|
||||
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
else:
|
||||
# 循环使用场景
|
||||
scenario_idx = episode % num_scenarios
|
||||
print(f"使用场景索引: {scenario_idx}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
# ✅ 如果不是指定场景,使用seed来遍历不同场景
|
||||
seed = scenario_id if scenario_id is not None else episode
|
||||
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||
if scenario_id is not None:
|
||||
seed = scenario_id
|
||||
else:
|
||||
seed = episode % num_scenarios
|
||||
obs = env.reset(seed=seed)
|
||||
|
||||
# 为每个车辆分配 ReplayPolicy
|
||||
@@ -180,11 +193,11 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": horizon,
|
||||
"use_render": render,
|
||||
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
"manual_control": False,
|
||||
@@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
# ✅ 关键:设置可用场景数量
|
||||
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取可用场景数量
|
||||
num_scenarios = env.config.get("num_scenarios", 1)
|
||||
print(f"可用场景数量: {num_scenarios}")
|
||||
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
else:
|
||||
# 循环使用场景
|
||||
scenario_idx = episode % num_scenarios
|
||||
print(f"使用场景索引: {scenario_idx}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
seed = scenario_id if scenario_id is not None else episode
|
||||
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||
if scenario_id is not None:
|
||||
seed = scenario_id
|
||||
else:
|
||||
seed = episode % num_scenarios
|
||||
obs = env.reset(seed=seed)
|
||||
|
||||
actual_horizon = env.config["horizon"]
|
||||
|
||||
@@ -306,43 +306,72 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
return False
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
"""
|
||||
生成应该在当前或之前出现的车辆
|
||||
如果round=0且所有车辆的show_time>0,则生成show_time最小的车辆(保证至少有车辆出现)
|
||||
"""
|
||||
vehicles_to_spawn = []
|
||||
|
||||
for car in self.car_birth_info_list:
|
||||
if car['show_time'] == self.round:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
vehicle_config = {}
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config=vehicle_config,
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
if car['show_time'] <= self.round:
|
||||
vehicles_to_spawn.append(car)
|
||||
|
||||
# 如果当前round没有车辆应该出现,但车辆列表不为空,则生成最早出现的车辆
|
||||
# 这样可以确保在reset时至少有车辆出现
|
||||
# if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0:
|
||||
# if self.config.get("debug", False):
|
||||
# self.logger.debug(
|
||||
# f"No vehicles to spawn at round {self.round}, "
|
||||
# f"spawning earliest vehicle instead"
|
||||
# )
|
||||
# # 找到show_time最小的车辆
|
||||
# earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time'])
|
||||
# vehicles_to_spawn.append(earliest_car)
|
||||
|
||||
for car in vehicles_to_spawn:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
# 避免重复生成
|
||||
if agent_id in self.controlled_agents:
|
||||
continue
|
||||
|
||||
vehicle_config = {}
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config=vehicle_config,
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
|
||||
# 重置车辆状态
|
||||
reset_kwargs = {
|
||||
'position': car['begin'],
|
||||
'heading': car['heading']
|
||||
}
|
||||
|
||||
# 如果启用速度继承,设置初始速度
|
||||
if car.get('velocity') is not None:
|
||||
reset_kwargs['velocity'] = car['velocity']
|
||||
|
||||
vehicle.reset(**reset_kwargs)
|
||||
|
||||
# 设置策略和目的地
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
vehicle.set_expert_vehicle_id(car['id'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# 注册到引擎的 active_agents
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(
|
||||
f"Spawned vehicle {agent_id} at round {self.round} "
|
||||
f"(show_time={car['show_time']}), position {car['begin']}"
|
||||
)
|
||||
|
||||
# 重置车辆状态
|
||||
reset_kwargs = {
|
||||
'position': car['begin'],
|
||||
'heading': car['heading']
|
||||
}
|
||||
|
||||
# 如果启用速度继承,设置初始速度
|
||||
if car.get('velocity') is not None:
|
||||
reset_kwargs['velocity'] = car['velocity']
|
||||
|
||||
vehicle.reset(**reset_kwargs)
|
||||
|
||||
# 设置策略和目的地
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
vehicle.set_expert_vehicle_id(car['id'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# 注册到引擎的 active_agents
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")
|
||||
|
||||
def _get_all_obs(self):
|
||||
self.obs_list = []
|
||||
|
||||
@@ -364,8 +393,20 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
traffic_light = 0
|
||||
break
|
||||
|
||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
# 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云
|
||||
lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=80,
|
||||
distance=30,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world
|
||||
)
|
||||
nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info(
|
||||
vehicle,
|
||||
detected_objects,
|
||||
perceive_distance=30,
|
||||
num_others=10,
|
||||
add_others_navi=False
|
||||
)
|
||||
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
@@ -374,7 +415,7 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
|
||||
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ list(vehicle.destination))
|
||||
|
||||
self.obs_list.append(obs)
|
||||
|
||||
184
Env/verify_observation.py
Normal file
184
Env/verify_observation.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简单验证脚本:检查车辆是否正确获取观测空间
|
||||
用法:python verify_observations.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import numpy as np
|
||||
|
||||
def verify_observations(data_dir, scenario_id=0):
|
||||
"""
|
||||
验证观测空间是否正确获取
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
scenario_id: 场景ID
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("观测空间验证工具")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建环境
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"horizon": 300,
|
||||
"use_render": False, # 不渲染,加速运行
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False,
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": True,
|
||||
"debug": True, # 启用调试以查看详细信息
|
||||
"specific_scenario_id": scenario_id,
|
||||
"use_scenario_duration": True,
|
||||
},
|
||||
agent2policy=None
|
||||
)
|
||||
|
||||
# 重置环境
|
||||
print(f"\n加载场景 {scenario_id}...")
|
||||
obs = env.reset(seed=scenario_id)
|
||||
|
||||
# 输出基本信息
|
||||
print(f"\n场景信息:")
|
||||
print(f" - 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" - 观测数量: {len(obs)}")
|
||||
print(f" - 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
|
||||
print(f" - 当前回合数 (round): {env.round}")
|
||||
|
||||
# 检查车辆生成信息
|
||||
if len(env.car_birth_info_list) > 0:
|
||||
print(f"\n车辆生成信息分析:")
|
||||
show_times = [car['show_time'] for car in env.car_birth_info_list]
|
||||
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
|
||||
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
|
||||
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
|
||||
else:
|
||||
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
|
||||
print(f" 1. 所有车辆都被车道过滤移除")
|
||||
print(f" 2. 所有车辆都被类型过滤移除")
|
||||
print(f" 3. 场景数据中没有有效车辆")
|
||||
|
||||
# 验证观测空间
|
||||
print(f"\n" + "=" * 60)
|
||||
print("观测空间验证")
|
||||
print("=" * 60)
|
||||
|
||||
if len(obs) == 0:
|
||||
print("❌ 错误:没有获取到任何观测!")
|
||||
env.close()
|
||||
return False
|
||||
|
||||
# 检查第一个观测
|
||||
first_obs = obs[0]
|
||||
print(f"\n第一个车辆的观测:")
|
||||
print(f" - 观测类型: {type(first_obs)}")
|
||||
print(f" - 观测维度: {len(first_obs)}")
|
||||
|
||||
# 详细解析观测
|
||||
if isinstance(first_obs, (list, np.ndarray)):
|
||||
obs_array = np.array(first_obs)
|
||||
print(f" - 观测形状: {obs_array.shape}")
|
||||
print(f" - 数据类型: {obs_array.dtype}")
|
||||
print(f"\n观测内容分解:")
|
||||
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
|
||||
# [x, y] (2) + [vx, vy] (2) + heading (1)
|
||||
# + nearest vehicles info (10 vehicles * 4 = 40)
|
||||
# + side_lidar (10) + lane_line_lidar (10)
|
||||
# + traffic_light (1) + destination (2)
|
||||
lidar_len = 40
|
||||
side_len = 10
|
||||
lane_len = 10
|
||||
pos_slice = slice(0, 2)
|
||||
vel_slice = slice(2, 4)
|
||||
heading_idx = 4
|
||||
lidar_slice = slice(5, 5 + lidar_len)
|
||||
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
|
||||
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
|
||||
tl_idx = lane_slice.stop
|
||||
dest_slice = slice(tl_idx + 1, tl_idx + 3)
|
||||
|
||||
print(f" 位置 (x, y): {obs_array[pos_slice]}")
|
||||
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
|
||||
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
|
||||
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
|
||||
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
|
||||
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
|
||||
print(f" 交通灯状态: {obs_array[tl_idx]}")
|
||||
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
|
||||
|
||||
# 检查数据有效性
|
||||
print(f"\n数据有效性检查:")
|
||||
has_nan = np.isnan(obs_array).any()
|
||||
has_inf = np.isinf(obs_array).any()
|
||||
|
||||
if has_nan:
|
||||
print(f" ❌ 观测包含 NaN 值!")
|
||||
else:
|
||||
print(f" ✅ 无 NaN 值")
|
||||
|
||||
if has_inf:
|
||||
print(f" ❌ 观测包含 Inf 值!")
|
||||
else:
|
||||
print(f" ✅ 无 Inf 值")
|
||||
|
||||
# 检查最近车辆信息数据(40维)
|
||||
lidar_data = obs_array[lidar_slice]
|
||||
lidar_min = np.min(lidar_data)
|
||||
lidar_max = np.max(lidar_data)
|
||||
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
|
||||
|
||||
if lidar_max > 0:
|
||||
print(f" ✅ 存在非零最近车辆特征")
|
||||
else:
|
||||
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
|
||||
|
||||
# 运行几步,验证观测持续有效
|
||||
print(f"\n" + "=" * 60)
|
||||
print("多步运行验证(前 5 步)")
|
||||
print("=" * 60)
|
||||
|
||||
for step in range(5):
|
||||
# 空动作
|
||||
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
print(f"\nStep {step + 1}:")
|
||||
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
|
||||
print(f" - 观测数量: {len(obs)}")
|
||||
|
||||
if len(obs) > 0:
|
||||
sample_obs = np.array(obs[0])
|
||||
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
|
||||
print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f" - 场景结束")
|
||||
break
|
||||
|
||||
env.close()
|
||||
|
||||
print(f"\n" + "=" * 60)
|
||||
print("✅ 验证完成!观测空间正常工作")
|
||||
print("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="验证观测空间")
|
||||
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
|
||||
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
verify_observations(args.data_dir, args.scenario_id)
|
||||
Reference in New Issue
Block a user