代替点云
This commit is contained in:
6
.vscode/settings.json
vendored
Normal file
6
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"cursorpyright.analysis.extraPaths": [
|
||||||
|
"/home/huangfukk/mdsn/metadrive",
|
||||||
|
"/home/huangfukk/mdsn/scenarionet"
|
||||||
|
]
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
109
Env/check_dataset.py
Normal file
109
Env/check_dataset.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
数据集检查脚本:统计可用的场景数量
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from metadrive.engine.asset_loader import AssetLoader
|
||||||
|
|
||||||
|
def check_dataset(data_dir, subfolder="exp_filtered"):
|
||||||
|
"""
|
||||||
|
检查数据集中的场景数量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir: 数据根目录
|
||||||
|
subfolder: 数据子目录(exp_filtered 或 exp_converted)
|
||||||
|
"""
|
||||||
|
print("=" * 80)
|
||||||
|
print("数据集检查工具")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 获取完整路径
|
||||||
|
full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
|
||||||
|
print(f"\n数据集路径: {full_path}")
|
||||||
|
|
||||||
|
# 检查文件结构
|
||||||
|
if not os.path.exists(full_path):
|
||||||
|
print(f"❌ 错误:路径不存在: {full_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✅ 路径存在")
|
||||||
|
|
||||||
|
# 读取数据集映射
|
||||||
|
mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
|
||||||
|
if os.path.exists(mapping_file):
|
||||||
|
print(f"\n读取 dataset_mapping.pkl...")
|
||||||
|
with open(mapping_file, 'rb') as f:
|
||||||
|
dataset_mapping = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"✅ 数据集映射文件存在")
|
||||||
|
print(f" 映射的场景数量: {len(dataset_mapping)}")
|
||||||
|
|
||||||
|
# 统计各个子目录的分布
|
||||||
|
subdirs = {}
|
||||||
|
for filename, subdir in dataset_mapping.items():
|
||||||
|
if subdir not in subdirs:
|
||||||
|
subdirs[subdir] = []
|
||||||
|
subdirs[subdir].append(filename)
|
||||||
|
|
||||||
|
print(f"\n场景分布:")
|
||||||
|
for subdir, files in subdirs.items():
|
||||||
|
print(f" {subdir}: {len(files)} 个场景")
|
||||||
|
|
||||||
|
# 读取数据集摘要
|
||||||
|
summary_file = os.path.join(full_path, "dataset_summary.pkl")
|
||||||
|
if os.path.exists(summary_file):
|
||||||
|
print(f"\n读取 dataset_summary.pkl...")
|
||||||
|
with open(summary_file, 'rb') as f:
|
||||||
|
dataset_summary = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"✅ 数据集摘要文件存在")
|
||||||
|
print(f" 摘要的场景数量: {len(dataset_summary)}")
|
||||||
|
|
||||||
|
# 打印前几个场景的ID
|
||||||
|
print(f"\n前10个场景ID:")
|
||||||
|
for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
|
||||||
|
if isinstance(info, dict):
|
||||||
|
track_length = info.get('track_length', 'N/A')
|
||||||
|
num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
|
||||||
|
print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")
|
||||||
|
|
||||||
|
# 检查实际文件
|
||||||
|
print(f"\n检查实际文件...")
|
||||||
|
pkl_files = list(Path(full_path).rglob("*.pkl"))
|
||||||
|
# 排除dataset_mapping和dataset_summary
|
||||||
|
scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
|
||||||
|
print(f" 实际场景文件数量: {len(scenario_files)}")
|
||||||
|
|
||||||
|
# 检查子目录
|
||||||
|
print(f"\n子目录结构:")
|
||||||
|
for item in os.listdir(full_path):
|
||||||
|
item_path = os.path.join(full_path, item)
|
||||||
|
if os.path.isdir(item_path):
|
||||||
|
pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
|
||||||
|
print(f" {item}/: {pkl_count} 个pkl文件")
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("检查完成")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 返回场景数量
|
||||||
|
return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||||
|
|
||||||
|
print("\n【检查 exp_filtered 数据集】")
|
||||||
|
num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
|
||||||
|
|
||||||
|
print("\n\n【检查 exp_converted 数据集】")
|
||||||
|
num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
|
||||||
|
|
||||||
|
print("\n\n总结:")
|
||||||
|
print(f" exp_filtered: {num_filtered} 个场景")
|
||||||
|
print(f" exp_converted: {num_converted} 个场景")
|
||||||
|
|
||||||
@@ -4,7 +4,7 @@ from simple_idm_policy import ConstantVelocityPolicy
|
|||||||
from replay_policy import ReplayPolicy
|
from replay_policy import ReplayPolicy
|
||||||
from metadrive.engine.asset_loader import AssetLoader
|
from metadrive.engine.asset_loader import AssetLoader
|
||||||
|
|
||||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||||
|
|
||||||
|
|
||||||
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||||
@@ -40,10 +40,10 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
|
|||||||
# ✅ 环境创建移到循环外面,避免重复创建
|
# ✅ 环境创建移到循环外面,避免重复创建
|
||||||
env = MultiAgentScenarioEnv(
|
env = MultiAgentScenarioEnv(
|
||||||
config={
|
config={
|
||||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||||
"is_multi_agent": True,
|
"is_multi_agent": True,
|
||||||
"horizon": horizon,
|
"horizon": horizon,
|
||||||
"use_render": render,
|
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||||
"sequential_seed": True,
|
"sequential_seed": True,
|
||||||
"reactive_traffic": False, # 回放模式下不需要反应式交通
|
"reactive_traffic": False, # 回放模式下不需要反应式交通
|
||||||
"manual_control": False,
|
"manual_control": False,
|
||||||
@@ -57,20 +57,33 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
|
|||||||
"spawn_vehicles": spawn_vehicles,
|
"spawn_vehicles": spawn_vehicles,
|
||||||
"spawn_pedestrians": spawn_pedestrians,
|
"spawn_pedestrians": spawn_pedestrians,
|
||||||
"spawn_cyclists": spawn_cyclists,
|
"spawn_cyclists": spawn_cyclists,
|
||||||
|
# ✅ 关键:设置可用场景数量
|
||||||
|
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||||
},
|
},
|
||||||
agent2policy=None # 回放模式不需要统一策略
|
agent2policy=None # 回放模式不需要统一策略
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 获取可用场景数量
|
||||||
|
num_scenarios = env.config.get("num_scenarios", 1)
|
||||||
|
print(f"可用场景数量: {num_scenarios}")
|
||||||
|
|
||||||
for episode in range(num_episodes):
|
for episode in range(num_episodes):
|
||||||
print(f"\n{'='*50}")
|
print(f"\n{'='*50}")
|
||||||
print(f"回合 {episode + 1}/{num_episodes}")
|
print(f"回合 {episode + 1}/{num_episodes}")
|
||||||
if scenario_id is not None:
|
if scenario_id is not None:
|
||||||
print(f"场景ID: {scenario_id}")
|
print(f"场景ID: {scenario_id}")
|
||||||
|
else:
|
||||||
|
# 循环使用场景
|
||||||
|
scenario_idx = episode % num_scenarios
|
||||||
|
print(f"使用场景索引: {scenario_idx}")
|
||||||
print(f"{'='*50}")
|
print(f"{'='*50}")
|
||||||
|
|
||||||
# ✅ 如果不是指定场景,使用seed来遍历不同场景
|
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||||
seed = scenario_id if scenario_id is not None else episode
|
if scenario_id is not None:
|
||||||
|
seed = scenario_id
|
||||||
|
else:
|
||||||
|
seed = episode % num_scenarios
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
|
|
||||||
# 为每个车辆分配 ReplayPolicy
|
# 为每个车辆分配 ReplayPolicy
|
||||||
@@ -180,11 +193,11 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
|
|||||||
|
|
||||||
env = MultiAgentScenarioEnv(
|
env = MultiAgentScenarioEnv(
|
||||||
config={
|
config={
|
||||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||||
"is_multi_agent": True,
|
"is_multi_agent": True,
|
||||||
"num_controlled_agents": 3,
|
"num_controlled_agents": 3,
|
||||||
"horizon": horizon,
|
"horizon": horizon,
|
||||||
"use_render": render,
|
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||||
"sequential_seed": True,
|
"sequential_seed": True,
|
||||||
"reactive_traffic": True,
|
"reactive_traffic": True,
|
||||||
"manual_control": False,
|
"manual_control": False,
|
||||||
@@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
|
|||||||
"spawn_vehicles": spawn_vehicles,
|
"spawn_vehicles": spawn_vehicles,
|
||||||
"spawn_pedestrians": spawn_pedestrians,
|
"spawn_pedestrians": spawn_pedestrians,
|
||||||
"spawn_cyclists": spawn_cyclists,
|
"spawn_cyclists": spawn_cyclists,
|
||||||
|
# ✅ 关键:设置可用场景数量
|
||||||
|
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||||
},
|
},
|
||||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 获取可用场景数量
|
||||||
|
num_scenarios = env.config.get("num_scenarios", 1)
|
||||||
|
print(f"可用场景数量: {num_scenarios}")
|
||||||
|
|
||||||
for episode in range(num_episodes):
|
for episode in range(num_episodes):
|
||||||
print(f"\n{'='*50}")
|
print(f"\n{'='*50}")
|
||||||
print(f"回合 {episode + 1}/{num_episodes}")
|
print(f"回合 {episode + 1}/{num_episodes}")
|
||||||
if scenario_id is not None:
|
if scenario_id is not None:
|
||||||
print(f"场景ID: {scenario_id}")
|
print(f"场景ID: {scenario_id}")
|
||||||
|
else:
|
||||||
|
# 循环使用场景
|
||||||
|
scenario_idx = episode % num_scenarios
|
||||||
|
print(f"使用场景索引: {scenario_idx}")
|
||||||
print(f"{'='*50}")
|
print(f"{'='*50}")
|
||||||
|
|
||||||
seed = scenario_id if scenario_id is not None else episode
|
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||||
|
if scenario_id is not None:
|
||||||
|
seed = scenario_id
|
||||||
|
else:
|
||||||
|
seed = episode % num_scenarios
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
|
|
||||||
actual_horizon = env.config["horizon"]
|
actual_horizon = env.config["horizon"]
|
||||||
|
|||||||
@@ -306,43 +306,72 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _spawn_controlled_agents(self):
|
def _spawn_controlled_agents(self):
|
||||||
|
"""
|
||||||
|
生成应该在当前或之前出现的车辆
|
||||||
|
如果round=0且所有车辆的show_time>0,则生成show_time最小的车辆(保证至少有车辆出现)
|
||||||
|
"""
|
||||||
|
vehicles_to_spawn = []
|
||||||
|
|
||||||
for car in self.car_birth_info_list:
|
for car in self.car_birth_info_list:
|
||||||
if car['show_time'] == self.round:
|
if car['show_time'] <= self.round:
|
||||||
agent_id = f"controlled_{car['id']}"
|
vehicles_to_spawn.append(car)
|
||||||
vehicle_config = {}
|
|
||||||
vehicle = self.engine.spawn_object(
|
# 如果当前round没有车辆应该出现,但车辆列表不为空,则生成最早出现的车辆
|
||||||
PolicyVehicle,
|
# 这样可以确保在reset时至少有车辆出现
|
||||||
vehicle_config=vehicle_config,
|
# if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0:
|
||||||
position=car['begin'],
|
# if self.config.get("debug", False):
|
||||||
heading=car['heading']
|
# self.logger.debug(
|
||||||
|
# f"No vehicles to spawn at round {self.round}, "
|
||||||
|
# f"spawning earliest vehicle instead"
|
||||||
|
# )
|
||||||
|
# # 找到show_time最小的车辆
|
||||||
|
# earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time'])
|
||||||
|
# vehicles_to_spawn.append(earliest_car)
|
||||||
|
|
||||||
|
for car in vehicles_to_spawn:
|
||||||
|
agent_id = f"controlled_{car['id']}"
|
||||||
|
|
||||||
|
# 避免重复生成
|
||||||
|
if agent_id in self.controlled_agents:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vehicle_config = {}
|
||||||
|
vehicle = self.engine.spawn_object(
|
||||||
|
PolicyVehicle,
|
||||||
|
vehicle_config=vehicle_config,
|
||||||
|
position=car['begin'],
|
||||||
|
heading=car['heading']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重置车辆状态
|
||||||
|
reset_kwargs = {
|
||||||
|
'position': car['begin'],
|
||||||
|
'heading': car['heading']
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果启用速度继承,设置初始速度
|
||||||
|
if car.get('velocity') is not None:
|
||||||
|
reset_kwargs['velocity'] = car['velocity']
|
||||||
|
|
||||||
|
vehicle.reset(**reset_kwargs)
|
||||||
|
|
||||||
|
# 设置策略和目的地
|
||||||
|
vehicle.set_policy(self.policy)
|
||||||
|
vehicle.set_destination(car['end'])
|
||||||
|
vehicle.set_expert_vehicle_id(car['id'])
|
||||||
|
|
||||||
|
self.controlled_agents[agent_id] = vehicle
|
||||||
|
self.controlled_agent_ids.append(agent_id)
|
||||||
|
|
||||||
|
# 注册到引擎的 active_agents
|
||||||
|
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||||
|
|
||||||
|
if self.config.get("debug", False):
|
||||||
|
self.logger.debug(
|
||||||
|
f"Spawned vehicle {agent_id} at round {self.round} "
|
||||||
|
f"(show_time={car['show_time']}), position {car['begin']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 重置车辆状态
|
|
||||||
reset_kwargs = {
|
|
||||||
'position': car['begin'],
|
|
||||||
'heading': car['heading']
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果启用速度继承,设置初始速度
|
|
||||||
if car.get('velocity') is not None:
|
|
||||||
reset_kwargs['velocity'] = car['velocity']
|
|
||||||
|
|
||||||
vehicle.reset(**reset_kwargs)
|
|
||||||
|
|
||||||
# 设置策略和目的地
|
|
||||||
vehicle.set_policy(self.policy)
|
|
||||||
vehicle.set_destination(car['end'])
|
|
||||||
vehicle.set_expert_vehicle_id(car['id'])
|
|
||||||
|
|
||||||
self.controlled_agents[agent_id] = vehicle
|
|
||||||
self.controlled_agent_ids.append(agent_id)
|
|
||||||
|
|
||||||
# 注册到引擎的 active_agents
|
|
||||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
|
||||||
|
|
||||||
if self.config.get("debug", False):
|
|
||||||
self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")
|
|
||||||
|
|
||||||
def _get_all_obs(self):
|
def _get_all_obs(self):
|
||||||
self.obs_list = []
|
self.obs_list = []
|
||||||
|
|
||||||
@@ -364,8 +393,20 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
|||||||
traffic_light = 0
|
traffic_light = 0
|
||||||
break
|
break
|
||||||
|
|
||||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
# 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云
|
||||||
physics_world=self.engine.physics_world.dynamic_world)
|
lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
|
||||||
|
num_lasers=80,
|
||||||
|
distance=30,
|
||||||
|
base_vehicle=vehicle,
|
||||||
|
physics_world=self.engine.physics_world.dynamic_world
|
||||||
|
)
|
||||||
|
nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info(
|
||||||
|
vehicle,
|
||||||
|
detected_objects,
|
||||||
|
perceive_distance=30,
|
||||||
|
num_others=10,
|
||||||
|
add_others_navi=False
|
||||||
|
)
|
||||||
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
||||||
base_vehicle=vehicle,
|
base_vehicle=vehicle,
|
||||||
physics_world=self.engine.physics_world.static_world)
|
physics_world=self.engine.physics_world.static_world)
|
||||||
@@ -374,7 +415,7 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
|||||||
physics_world=self.engine.physics_world.static_world)
|
physics_world=self.engine.physics_world.static_world)
|
||||||
|
|
||||||
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
||||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
+ nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||||
+ list(vehicle.destination))
|
+ list(vehicle.destination))
|
||||||
|
|
||||||
self.obs_list.append(obs)
|
self.obs_list.append(obs)
|
||||||
|
|||||||
184
Env/verify_observation.py
Normal file
184
Env/verify_observation.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
简单验证脚本:检查车辆是否正确获取观测空间
|
||||||
|
用法:python verify_observations.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from scenario_env import MultiAgentScenarioEnv
|
||||||
|
from replay_policy import ReplayPolicy
|
||||||
|
from metadrive.engine.asset_loader import AssetLoader
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def verify_observations(data_dir, scenario_id=0):
|
||||||
|
"""
|
||||||
|
验证观测空间是否正确获取
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dir: 数据目录
|
||||||
|
scenario_id: 场景ID
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("观测空间验证工具")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 创建环境
|
||||||
|
env = MultiAgentScenarioEnv(
|
||||||
|
config={
|
||||||
|
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||||
|
"is_multi_agent": True,
|
||||||
|
"horizon": 300,
|
||||||
|
"use_render": False, # 不渲染,加速运行
|
||||||
|
"sequential_seed": True,
|
||||||
|
"reactive_traffic": False,
|
||||||
|
"manual_control": False,
|
||||||
|
"filter_offroad_vehicles": True,
|
||||||
|
"lane_tolerance": 3.0,
|
||||||
|
"replay_mode": True,
|
||||||
|
"debug": True, # 启用调试以查看详细信息
|
||||||
|
"specific_scenario_id": scenario_id,
|
||||||
|
"use_scenario_duration": True,
|
||||||
|
},
|
||||||
|
agent2policy=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重置环境
|
||||||
|
print(f"\n加载场景 {scenario_id}...")
|
||||||
|
obs = env.reset(seed=scenario_id)
|
||||||
|
|
||||||
|
# 输出基本信息
|
||||||
|
print(f"\n场景信息:")
|
||||||
|
print(f" - 可控车辆数: {len(env.controlled_agents)}")
|
||||||
|
print(f" - 观测数量: {len(obs)}")
|
||||||
|
print(f" - 场景时长: {env.scenario_max_duration} 步")
|
||||||
|
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
|
||||||
|
print(f" - 当前回合数 (round): {env.round}")
|
||||||
|
|
||||||
|
# 检查车辆生成信息
|
||||||
|
if len(env.car_birth_info_list) > 0:
|
||||||
|
print(f"\n车辆生成信息分析:")
|
||||||
|
show_times = [car['show_time'] for car in env.car_birth_info_list]
|
||||||
|
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
|
||||||
|
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
|
||||||
|
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
|
||||||
|
else:
|
||||||
|
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
|
||||||
|
print(f" 1. 所有车辆都被车道过滤移除")
|
||||||
|
print(f" 2. 所有车辆都被类型过滤移除")
|
||||||
|
print(f" 3. 场景数据中没有有效车辆")
|
||||||
|
|
||||||
|
# 验证观测空间
|
||||||
|
print(f"\n" + "=" * 60)
|
||||||
|
print("观测空间验证")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if len(obs) == 0:
|
||||||
|
print("❌ 错误:没有获取到任何观测!")
|
||||||
|
env.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查第一个观测
|
||||||
|
first_obs = obs[0]
|
||||||
|
print(f"\n第一个车辆的观测:")
|
||||||
|
print(f" - 观测类型: {type(first_obs)}")
|
||||||
|
print(f" - 观测维度: {len(first_obs)}")
|
||||||
|
|
||||||
|
# 详细解析观测
|
||||||
|
if isinstance(first_obs, (list, np.ndarray)):
|
||||||
|
obs_array = np.array(first_obs)
|
||||||
|
print(f" - 观测形状: {obs_array.shape}")
|
||||||
|
print(f" - 数据类型: {obs_array.dtype}")
|
||||||
|
print(f"\n观测内容分解:")
|
||||||
|
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
|
||||||
|
# [x, y] (2) + [vx, vy] (2) + heading (1)
|
||||||
|
# + nearest vehicles info (10 vehicles * 4 = 40)
|
||||||
|
# + side_lidar (10) + lane_line_lidar (10)
|
||||||
|
# + traffic_light (1) + destination (2)
|
||||||
|
lidar_len = 40
|
||||||
|
side_len = 10
|
||||||
|
lane_len = 10
|
||||||
|
pos_slice = slice(0, 2)
|
||||||
|
vel_slice = slice(2, 4)
|
||||||
|
heading_idx = 4
|
||||||
|
lidar_slice = slice(5, 5 + lidar_len)
|
||||||
|
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
|
||||||
|
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
|
||||||
|
tl_idx = lane_slice.stop
|
||||||
|
dest_slice = slice(tl_idx + 1, tl_idx + 3)
|
||||||
|
|
||||||
|
print(f" 位置 (x, y): {obs_array[pos_slice]}")
|
||||||
|
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
|
||||||
|
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
|
||||||
|
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
|
||||||
|
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
|
||||||
|
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
|
||||||
|
print(f" 交通灯状态: {obs_array[tl_idx]}")
|
||||||
|
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
print(f"\n数据有效性检查:")
|
||||||
|
has_nan = np.isnan(obs_array).any()
|
||||||
|
has_inf = np.isinf(obs_array).any()
|
||||||
|
|
||||||
|
if has_nan:
|
||||||
|
print(f" ❌ 观测包含 NaN 值!")
|
||||||
|
else:
|
||||||
|
print(f" ✅ 无 NaN 值")
|
||||||
|
|
||||||
|
if has_inf:
|
||||||
|
print(f" ❌ 观测包含 Inf 值!")
|
||||||
|
else:
|
||||||
|
print(f" ✅ 无 Inf 值")
|
||||||
|
|
||||||
|
# 检查最近车辆信息数据(40维)
|
||||||
|
lidar_data = obs_array[lidar_slice]
|
||||||
|
lidar_min = np.min(lidar_data)
|
||||||
|
lidar_max = np.max(lidar_data)
|
||||||
|
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
|
||||||
|
|
||||||
|
if lidar_max > 0:
|
||||||
|
print(f" ✅ 存在非零最近车辆特征")
|
||||||
|
else:
|
||||||
|
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
|
||||||
|
|
||||||
|
# 运行几步,验证观测持续有效
|
||||||
|
print(f"\n" + "=" * 60)
|
||||||
|
print("多步运行验证(前 5 步)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for step in range(5):
|
||||||
|
# 空动作
|
||||||
|
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||||||
|
obs, rewards, dones, infos = env.step(actions)
|
||||||
|
|
||||||
|
print(f"\nStep {step + 1}:")
|
||||||
|
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
|
||||||
|
print(f" - 观测数量: {len(obs)}")
|
||||||
|
|
||||||
|
if len(obs) > 0:
|
||||||
|
sample_obs = np.array(obs[0])
|
||||||
|
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
|
||||||
|
print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
|
||||||
|
|
||||||
|
if dones["__all__"]:
|
||||||
|
print(f" - 场景结束")
|
||||||
|
break
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
print(f"\n" + "=" * 60)
|
||||||
|
print("✅ 验证完成!观测空间正常工作")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="验证观测空间")
|
||||||
|
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
|
||||||
|
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
verify_observations(args.data_dir, args.scenario_id)
|
||||||
115
train_magail.py
Normal file
115
train_magail.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from Env.scenario_env import MultiAgentScenarioEnv
|
||||||
|
from Algorithm.magail import MAGAIL
|
||||||
|
from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里
|
||||||
|
# 假设你有加载专家数据的工具
|
||||||
|
# from utils import load_expert_buffer
|
||||||
|
|
||||||
|
# --- 配置 ---
|
||||||
|
CONFIG = {
|
||||||
|
"data_dir": "/home/huangfukk/mdsn",
|
||||||
|
# ... 其他配置 ...
|
||||||
|
"state_dim": 30, # 你的 Observation 维度
|
||||||
|
"action_dim": 2, # [steering, throttle]
|
||||||
|
"rollout_length": 2048, # Buffer 大小
|
||||||
|
}
|
||||||
|
|
||||||
|
def train():
|
||||||
|
# 1. 环境
|
||||||
|
env = MultiAgentScenarioEnv(config={...}, agent2policy=None)
|
||||||
|
|
||||||
|
# 2. 专家 Buffer (伪代码)
|
||||||
|
# expert_buffer = RolloutBuffer(...)
|
||||||
|
# expert_buffer.load(...)
|
||||||
|
# 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练
|
||||||
|
|
||||||
|
# 3. 初始化 MAGAIL
|
||||||
|
magail = MAGAIL(
|
||||||
|
buffer_exp=expert_buffer, # 传入专家 buffer
|
||||||
|
input_dim=(CONFIG["state_dim"],),
|
||||||
|
action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape,原代码好像没传?
|
||||||
|
device=torch.device("cuda"),
|
||||||
|
rollout_length=CONFIG["rollout_length"]
|
||||||
|
)
|
||||||
|
|
||||||
|
writer = SummaryWriter("./logs")
|
||||||
|
|
||||||
|
# 4. 训练循环
|
||||||
|
total_steps = 0
|
||||||
|
obs_dict = env.reset()
|
||||||
|
|
||||||
|
# 将 Dict Obs 转为 Array: (N_agents, Obs_dim)
|
||||||
|
# 假设所有 Agent 的 Obs 维度相同
|
||||||
|
agents = list(obs_dict.keys())
|
||||||
|
current_obs = np.stack([obs_dict[a] for a in agents])
|
||||||
|
|
||||||
|
# 如果需要 state_gail,假设它就是 obs
|
||||||
|
current_state_gail = current_obs.copy()
|
||||||
|
|
||||||
|
while total_steps < 1e7:
|
||||||
|
# --- 收集数据 (Rollout) ---
|
||||||
|
# PPO 的 buffer 长度是 rollout_length
|
||||||
|
# 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append
|
||||||
|
|
||||||
|
# 注意:magail.step 内部调用了 env.step,这可能不适用于 Multi-Agent 环境返回 Dict 的情况
|
||||||
|
# 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的
|
||||||
|
# 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值
|
||||||
|
|
||||||
|
# --- 方案:手动 Rollout 适配 Multi-Agent ---
|
||||||
|
for _ in range(CONFIG["rollout_length"]):
|
||||||
|
# 1. 决策
|
||||||
|
actions_list, log_pis_list = magail.explore(current_obs)
|
||||||
|
|
||||||
|
# 2. 拼装 Action Dict
|
||||||
|
action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
|
||||||
|
|
||||||
|
# 3. 环境步进
|
||||||
|
next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
|
||||||
|
|
||||||
|
# 4. 处理返回值
|
||||||
|
# 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变
|
||||||
|
next_obs = np.stack([next_obs_dict[a] for a in agents])
|
||||||
|
rewards = np.array([rewards_dict[a] for a in agents])
|
||||||
|
dones = np.array([dones_dict[a] for a in agents])
|
||||||
|
# truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认
|
||||||
|
terminated = dones # 简化
|
||||||
|
truncated = [False] * len(agents) # 简化
|
||||||
|
|
||||||
|
# 5. 存入 Buffer
|
||||||
|
# 获取当前 Actor 的均值方差用于 PPO 更新
|
||||||
|
with torch.no_grad():
|
||||||
|
# 重新计算一遍或者在 explore 里返回
|
||||||
|
# 这里 magail.actor(state) 返回的是 mean (deterministic action)
|
||||||
|
means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
|
||||||
|
stds = magail.actor.log_stds.exp().detach().cpu().numpy()
|
||||||
|
# 如果是 StateIndependentPolicy,std 可能是共享的参数,维度需要广播
|
||||||
|
if stds.shape[0] != len(agents):
|
||||||
|
stds = np.repeat(stds, len(agents), axis=0)
|
||||||
|
|
||||||
|
magail.buffer.append(
|
||||||
|
current_obs, current_state_gail, actions_list,
|
||||||
|
rewards, dones, terminated, log_pis_list,
|
||||||
|
next_obs, next_obs, # next_state_gail = next_obs
|
||||||
|
means, stds
|
||||||
|
)
|
||||||
|
|
||||||
|
current_obs = next_obs
|
||||||
|
current_state_gail = next_obs # 更新 gail state
|
||||||
|
total_steps += len(agents) # 步数增加 N
|
||||||
|
|
||||||
|
# 处理 Done
|
||||||
|
if all(dones):
|
||||||
|
obs_dict = env.reset()
|
||||||
|
current_obs = np.stack([obs_dict[a] for a in agents])
|
||||||
|
current_state_gail = current_obs
|
||||||
|
|
||||||
|
# --- 更新 ---
|
||||||
|
# 当 Buffer 满时 (rollout_length),调用 update
|
||||||
|
magail.update(writer, total_steps)
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
if total_steps % 10000 == 0:
|
||||||
|
magail.save_models("./models")
|
||||||
|
|
||||||
Reference in New Issue
Block a user