2 Commits
main ... dev

12 changed files with 1562 additions and 94 deletions

6
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,6 @@
{
"cursorpyright.analysis.extraPaths": [
"/home/huangfukk/mdsn/metadrive",
"/home/huangfukk/mdsn/scenarionet"
]
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

109
Env/check_dataset.py Normal file
View File

@@ -0,0 +1,109 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据集检查脚本:统计可用的场景数量
"""
import pickle
import os
from pathlib import Path
from metadrive.engine.asset_loader import AssetLoader
def check_dataset(data_dir, subfolder="exp_filtered"):
"""
检查数据集中的场景数量
Args:
data_dir: 数据根目录
subfolder: 数据子目录exp_filtered 或 exp_converted
"""
print("=" * 80)
print("数据集检查工具")
print("=" * 80)
# 获取完整路径
full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
print(f"\n数据集路径: {full_path}")
# 检查文件结构
if not os.path.exists(full_path):
print(f"❌ 错误:路径不存在: {full_path}")
return
print(f"✅ 路径存在")
# 读取数据集映射
mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
if os.path.exists(mapping_file):
print(f"\n读取 dataset_mapping.pkl...")
with open(mapping_file, 'rb') as f:
dataset_mapping = pickle.load(f)
print(f"✅ 数据集映射文件存在")
print(f" 映射的场景数量: {len(dataset_mapping)}")
# 统计各个子目录的分布
subdirs = {}
for filename, subdir in dataset_mapping.items():
if subdir not in subdirs:
subdirs[subdir] = []
subdirs[subdir].append(filename)
print(f"\n场景分布:")
for subdir, files in subdirs.items():
print(f" {subdir}: {len(files)} 个场景")
# 读取数据集摘要
summary_file = os.path.join(full_path, "dataset_summary.pkl")
if os.path.exists(summary_file):
print(f"\n读取 dataset_summary.pkl...")
with open(summary_file, 'rb') as f:
dataset_summary = pickle.load(f)
print(f"✅ 数据集摘要文件存在")
print(f" 摘要的场景数量: {len(dataset_summary)}")
# 打印前几个场景的ID
print(f"\n前10个场景ID:")
for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
if isinstance(info, dict):
track_length = info.get('track_length', 'N/A')
num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")
# 检查实际文件
print(f"\n检查实际文件...")
pkl_files = list(Path(full_path).rglob("*.pkl"))
# 排除dataset_mapping和dataset_summary
scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
print(f" 实际场景文件数量: {len(scenario_files)}")
# 检查子目录
print(f"\n子目录结构:")
for item in os.listdir(full_path):
item_path = os.path.join(full_path, item)
if os.path.isdir(item_path):
pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
print(f" {item}/: {pkl_count} 个pkl文件")
print("\n" + "=" * 80)
print("检查完成")
print("=" * 80)
# 返回场景数量
return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
if __name__ == "__main__":
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
print("\n【检查 exp_filtered 数据集】")
num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
print("\n\n【检查 exp_converted 数据集】")
num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
print("\n\n总结:")
print(f" exp_filtered: {num_filtered} 个场景")
print(f" exp_converted: {num_converted} 个场景")

62
Env/replay_policy.py Normal file
View File

@@ -0,0 +1,62 @@
import numpy as np
class ReplayPolicy:
"""
严格回放策略:根据专家轨迹数据,逐帧回放车辆状态
"""
def __init__(self, expert_trajectory, vehicle_id):
"""
Args:
expert_trajectory: 专家轨迹字典,包含 positions, headings, velocities, valid
vehicle_id: 车辆ID用于调试
"""
self.trajectory = expert_trajectory
self.vehicle_id = vehicle_id
self.current_step = 0
def act(self, observation=None):
"""
返回动作:在回放模式下返回空动作
实际状态由环境直接设置
"""
return [0.0, 0.0]
def get_target_state(self, step):
"""
获取指定时间步的目标状态
Args:
step: 时间步
Returns:
dict: 包含 position, heading, velocity 的字典,如果无效则返回 None
"""
if step >= len(self.trajectory['valid']):
return None
if not self.trajectory['valid'][step]:
return None
return {
'position': self.trajectory['positions'][step],
'heading': self.trajectory['headings'][step],
'velocity': self.trajectory['velocities'][step]
}
def is_finished(self, step):
"""
判断轨迹是否已经播放完毕
Args:
step: 当前时间步
Returns:
bool: 如果轨迹已播放完或当前步无效,返回 True
"""
# 超出轨迹长度
if step >= len(self.trajectory['valid']):
return True
# 当前步及之后都无效
return not any(self.trajectory['valid'][step:])

View File

@@ -1,40 +1,389 @@
import argparse
from scenario_env import MultiAgentScenarioEnv
from Env.simple_idm_policy import ConstantVelocityPolicy
from simple_idm_policy import ConstantVelocityPolicy
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
def main():
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
scenario_id=None, use_scenario_duration=False,
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
"""
回放模式:严格按照专家轨迹回放
Args:
data_dir: 数据目录
num_episodes: 回合数如果指定scenario_id则忽略
horizon: 最大步数如果use_scenario_duration=True则自动设置
render: 是否渲染
debug: 是否调试模式
scenario_id: 指定场景ID可选
use_scenario_duration: 是否使用场景原始时长
spawn_vehicles: 是否生成车辆默认True
spawn_pedestrians: 是否生成行人默认True
spawn_cyclists: 是否生成自行车默认True
"""
print("=" * 50)
print("运行模式: 专家轨迹回放 (Replay Mode)")
if scenario_id is not None:
print(f"指定场景ID: {scenario_id}")
if use_scenario_duration:
print("使用场景原始时长")
print("=" * 50)
# 如果指定了场景ID只运行1个回合
if scenario_id is not None:
num_episodes = 1
# ✅ 环境创建移到循环外面,避免重复创建
env = MultiAgentScenarioEnv(
config={
# "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": horizon,
"use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True,
"reactive_traffic": False, # 回放模式下不需要反应式交通
"manual_control": False,
"filter_offroad_vehicles": True, # 启用车道过滤
"lane_tolerance": 3.0,
"replay_mode": True, # 标记为回放模式
"debug": debug,
"specific_scenario_id": scenario_id, # 指定场景ID
"use_scenario_duration": use_scenario_duration, # 使用场景时长
# 对象类型过滤
"spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
},
agent2policy=None # 回放模式不需要统一策略
)
try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes):
print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None:
print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}")
# ✅ 如果不是指定场景,使用循环的场景索引
if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed)
# 为每个车辆分配 ReplayPolicy
replay_policies = {}
for agent_id, vehicle in env.controlled_agents.items():
vehicle_id = vehicle.expert_vehicle_id
if vehicle_id in env.expert_trajectories:
replay_policy = ReplayPolicy(
env.expert_trajectories[vehicle_id],
vehicle_id
)
vehicle.set_policy(replay_policy)
replay_policies[agent_id] = replay_policy
# 输出场景信息
actual_horizon = env.config["horizon"]
print(f"初始化完成:")
print(f" 可控车辆数: {len(env.controlled_agents)}")
print(f" 专家轨迹数: {len(env.expert_trajectories)}")
print(f" 场景时长: {env.scenario_max_duration}")
print(f" 实际Horizon: {actual_horizon}")
step_count = 0
active_vehicles_count = []
while True:
# 在回放模式下,直接使用专家轨迹设置车辆状态
for agent_id, vehicle in list(env.controlled_agents.items()):
vehicle_id = vehicle.expert_vehicle_id
if vehicle_id in env.expert_trajectories and agent_id in replay_policies:
target_state = replay_policies[agent_id].get_target_state(env.round)
if target_state is not None:
# 直接设置车辆状态(绕过物理引擎)
# 只使用xy坐标保持车辆在地面上
position_2d = target_state['position'][:2]
vehicle.set_position(position_2d)
vehicle.set_heading_theta(target_state['heading'])
vehicle.set_velocity(target_state['velocity'][:2] if len(target_state['velocity']) > 2 else target_state['velocity'])
# 使用空动作进行步进
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
obs, rewards, dones, infos = env.step(actions)
if render:
env.render(mode="topdown")
step_count += 1
active_vehicles_count.append(len(env.controlled_agents))
# 每50步打印一次状态
if step_count % 50 == 0:
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
# 调试模式下打印车辆高度信息
if debug and len(env.controlled_agents) > 0:
sample_vehicle = list(env.controlled_agents.values())[0]
z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")
if dones["__all__"]:
print(f"\n回合结束统计:")
print(f" 总步数: {step_count}")
print(f" 最大同时车辆数: {max(active_vehicles_count) if active_vehicles_count else 0}")
print(f" 平均车辆数: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
if use_scenario_duration:
print(f" 场景完整回放: {'' if step_count >= env.scenario_max_duration else ''}")
break
finally:
# ✅ 确保环境被正确关闭
env.close()
print("\n" + "=" * 50)
print("回放完成!")
print("=" * 50)
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
scenario_id=None, use_scenario_duration=False,
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
"""
仿真模式:使用自定义策略控制车辆
车辆根据专家数据的初始位姿生成,然后由策略控制
Args:
data_dir: 数据目录
num_episodes: 回合数
horizon: 最大步数
render: 是否渲染
debug: 是否调试模式
scenario_id: 指定场景ID可选
use_scenario_duration: 是否使用场景原始时长
spawn_vehicles: 是否生成车辆默认True
spawn_pedestrians: 是否生成行人默认True
spawn_cyclists: 是否生成自行车默认True
"""
print("=" * 50)
print("运行模式: 策略仿真 (Simulation Mode)")
if scenario_id is not None:
print(f"指定场景ID: {scenario_id}")
if use_scenario_duration:
print("使用场景原始时长")
print("=" * 50)
# 如果指定了场景ID只运行1个回合
if scenario_id is not None:
num_episodes = 1
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"num_controlled_agents": 3,
"horizon": 300,
"use_render": True,
"horizon": horizon,
"use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True,
"reactive_traffic": True,
"manual_control": True,
"manual_control": False,
"filter_offroad_vehicles": True, # 启用车道过滤
"lane_tolerance": 3.0,
"replay_mode": False, # 仿真模式
"debug": debug,
"specific_scenario_id": scenario_id, # 指定场景ID
"use_scenario_duration": use_scenario_duration, # 使用场景时长
# 对象类型过滤
"spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
},
agent2policy=ConstantVelocityPolicy(target_speed=50)
)
obs = env.reset(0
)
for step in range(10000):
actions = {
aid: env.controlled_agents[aid].policy.act()
for aid in env.controlled_agents
}
try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes):
print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None:
print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}")
obs, rewards, dones, infos = env.step(actions)
env.render(mode="topdown")
# ✅ 如果不是指定场景,使用循环的场景索引
if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed)
if dones["__all__"]:
break
actual_horizon = env.config["horizon"]
print(f"初始化完成:")
print(f" 可控车辆数: {len(env.controlled_agents)}")
print(f" 场景时长: {env.scenario_max_duration}")
print(f" 实际Horizon: {actual_horizon}")
env.close()
step_count = 0
total_reward = 0.0
while True:
# 使用策略生成动作
actions = {
aid: env.controlled_agents[aid].policy.act()
for aid in env.controlled_agents
}
obs, rewards, dones, infos = env.step(actions)
if render:
env.render(mode="topdown")
step_count += 1
total_reward += sum(rewards.values())
# 每50步打印一次状态
if step_count % 50 == 0:
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
if dones["__all__"]:
print(f"\n回合结束统计:")
print(f" 总步数: {step_count}")
print(f" 总奖励: {total_reward:.2f}")
break
finally:
env.close()
print("\n" + "=" * 50)
print("仿真完成!")
print("=" * 50)
def main():
parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
parser.add_argument(
"--mode",
type=str,
choices=["replay", "simulation"],
default="simulation",
help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)"
)
parser.add_argument(
"--data_dir",
type=str,
default=WAYMO_DATA_DIR,
help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})"
)
parser.add_argument(
"--episodes",
type=int,
default=1,
help="运行回合数 (默认: 1)"
)
parser.add_argument(
"--horizon",
type=int,
default=300,
help="每回合最大步数 (默认: 300如果启用 --use_scenario_duration 则自动设置)"
)
parser.add_argument(
"--no_render",
action="store_true",
help="禁用渲染(加速运行)"
)
parser.add_argument(
"--debug",
action="store_true",
help="启用调试模式(显示详细日志)"
)
parser.add_argument(
"--scenario_id",
type=int,
default=None,
help="指定场景ID可选如指定则只运行该场景"
)
parser.add_argument(
"--use_scenario_duration",
action="store_true",
help="使用场景原始时长作为horizon自动停止"
)
parser.add_argument(
"--no_vehicles",
action="store_true",
help="禁止生成车辆"
)
parser.add_argument(
"--no_pedestrians",
action="store_true",
help="禁止生成行人"
)
parser.add_argument(
"--no_cyclists",
action="store_true",
help="禁止生成自行车"
)
args = parser.parse_args()
if args.mode == "replay":
run_replay_mode(
data_dir=args.data_dir,
num_episodes=args.episodes,
horizon=args.horizon,
render=not args.no_render,
debug=args.debug,
scenario_id=args.scenario_id,
use_scenario_duration=args.use_scenario_duration,
spawn_vehicles=not args.no_vehicles,
spawn_pedestrians=not args.no_pedestrians,
spawn_cyclists=not args.no_cyclists
)
else:
run_simulation_mode(
data_dir=args.data_dir,
num_episodes=args.episodes,
horizon=args.horizon,
render=not args.no_render,
debug=args.debug,
scenario_id=args.scenario_id,
use_scenario_duration=args.use_scenario_duration,
spawn_vehicles=not args.no_vehicles,
spawn_pedestrians=not args.no_pedestrians,
spawn_cyclists=not args.no_cyclists
)
if __name__ == "__main__":

View File

@@ -15,6 +15,7 @@ class PolicyVehicle(DefaultVehicle):
super().__init__(*args, **kwargs)
self.policy = None
self.destination = None
self.expert_vehicle_id = None # 关联专家车辆ID
def set_policy(self, policy):
self.policy = policy
@@ -22,6 +23,9 @@ class PolicyVehicle(DefaultVehicle):
def set_destination(self, des):
self.destination = des
def set_expert_vehicle_id(self, vid):
self.expert_vehicle_id = vid
def act(self, observation, policy=None):
if self.policy is not None:
return self.policy.act(observation)
@@ -53,6 +57,15 @@ class MultiAgentScenarioEnv(ScenarioEnv):
data_directory=None,
num_controlled_agents=3,
horizon=1000,
filter_offroad_vehicles=True, # 车道过滤开关
lane_tolerance=3.0, # 车道检测容差(米)
replay_mode=False, # 回放模式开关
specific_scenario_id=None, # 新增指定场景ID仅回放模式
use_scenario_duration=False, # 新增使用场景原始时长作为horizon
# 对象类型过滤选项
spawn_vehicles=True, # 是否生成车辆
spawn_pedestrians=True, # 是否生成行人
spawn_cyclists=True, # 是否生成自行车
))
return config
@@ -62,50 +75,180 @@ class MultiAgentScenarioEnv(ScenarioEnv):
self.controlled_agent_ids = []
self.obs_list = []
self.round = 0
self.expert_trajectories = {} # 存储完整专家轨迹
self.replay_mode = config.get("replay_mode", False)
self.scenario_max_duration = 0 # 场景实际最大时长
super().__init__(config)
def reset(self, seed: Union[None, int] = None):
self.round = 0
if self.logger is None:
self.logger = get_logger()
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
set_log_level(log_level)
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
set_log_level(log_level)
# ✅ 关键修复:在每次 reset 前清理所有自定义生成的对象
if hasattr(self, 'engine') and self.engine is not None:
if hasattr(self, 'controlled_agents') and self.controlled_agents:
# 先从 agent_manager 中移除
if hasattr(self.engine, 'agent_manager'):
for agent_id in list(self.controlled_agents.keys()):
if agent_id in self.engine.agent_manager.active_agents:
self.engine.agent_manager.active_agents.pop(agent_id)
# 然后清理对象
for agent_id, vehicle in list(self.controlled_agents.items()):
try:
self.engine.clear_objects([vehicle.id])
except:
pass
self.controlled_agents.clear()
self.controlled_agent_ids.clear()
self.lazy_init()
self._reset_global_seed(seed)
if self.engine is None:
raise ValueError("Broken MetaDrive instance.")
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
_obj_to_clean_this_frame = []
self.car_birth_info_list = []
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
continue
else:
if track["type"] == MetaDriveType.VEHICLE:
_obj_to_clean_this_frame.append(scenario_id)
valid = track['state']['valid']
first_show = np.argmax(valid) if valid.any() else -1
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
# id出现时间出生点坐标出生朝向目的地
self.car_birth_info_list.append({
'id': track['metadata']['object_id'],
'show_time': first_show,
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
'heading': track['state']['heading'][first_show],
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
})
for scenario_id in _obj_to_clean_this_frame:
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
# 如果指定了场景ID修改start_scenario_index
if self.config.get("specific_scenario_id") is not None:
scenario_id = self.config.get("specific_scenario_id")
self.config["start_scenario_index"] = scenario_id
if self.config.get("debug", False):
self.logger.info(f"Using specific scenario ID: {scenario_id}")
# ✅ 先初始化引擎和 lanes
self.engine.reset()
self.reset_sensors()
self.engine.taskMgr.step()
self.lanes = self.engine.map_manager.current_map.road_network.graph
# 记录专家数据(现在 self.lanes 已经初始化)
_obj_to_clean_this_frame = []
self.car_birth_info_list = []
self.expert_trajectories.clear()
total_vehicles = 0
total_pedestrians = 0
total_cyclists = 0
filtered_vehicles = 0
filtered_by_type = 0
self.scenario_max_duration = 0 # 重置场景时长
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
continue
# 对象类型过滤
obj_type = track["type"]
# 统计对象类型
if obj_type == MetaDriveType.VEHICLE:
total_vehicles += 1
elif obj_type == MetaDriveType.PEDESTRIAN:
total_pedestrians += 1
elif obj_type == MetaDriveType.CYCLIST:
total_cyclists += 1
# 根据配置过滤对象类型
if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
_obj_to_clean_this_frame.append(scenario_id)
filtered_by_type += 1
if self.config.get("debug", False):
self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
continue
if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
_obj_to_clean_this_frame.append(scenario_id)
filtered_by_type += 1
if self.config.get("debug", False):
self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
continue
if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
_obj_to_clean_this_frame.append(scenario_id)
filtered_by_type += 1
if self.config.get("debug", False):
self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
continue
# 只处理车辆类型(行人和自行车暂时只做过滤)
if track["type"] == MetaDriveType.VEHICLE:
valid = track['state']['valid']
first_show = np.argmax(valid) if valid.any() else -1
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
if first_show == -1 or last_show == -1:
continue
# 更新场景最大时长
self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
# 获取车辆初始位置
initial_position = (
track['state']['position'][first_show, 0],
track['state']['position'][first_show, 1]
)
# 车道过滤
if self.config.get("filter_offroad_vehicles", True):
if not self._is_position_on_lane(initial_position):
filtered_vehicles += 1
_obj_to_clean_this_frame.append(scenario_id)
if self.config.get("debug", False):
self.logger.debug(
f"Filtering vehicle {track['metadata']['object_id']} - "
f"not on lane at position {initial_position}"
)
continue
# 存储完整专家轨迹只使用2D位置避免高度问题
object_id = track['metadata']['object_id']
positions_2d = track['state']['position'].copy()
positions_2d[:, 2] = 0 # 将z坐标设为0让MetaDrive自动处理高度
self.expert_trajectories[object_id] = {
'positions': positions_2d,
'headings': track['state']['heading'].copy(),
'velocities': track['state']['velocity'].copy(),
'valid': track['state']['valid'].copy(),
}
# 保存车辆生成信息
self.car_birth_info_list.append({
'id': object_id,
'show_time': first_show,
'begin': initial_position,
'heading': track['state']['heading'][first_show],
'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
})
# 在回放和仿真模式下都清除原始专家车辆
_obj_to_clean_this_frame.append(scenario_id)
# 清除专家车辆和过滤的对象
for scenario_id in _obj_to_clean_this_frame:
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
# 输出统计信息
if self.config.get("debug", False):
self.logger.info(f"=== 对象统计 ===")
self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
self.logger.info(f"场景时长: {self.scenario_max_duration}")
# 如果启用场景时长控制更新horizon
if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
original_horizon = self.config["horizon"]
self.config["horizon"] = self.scenario_max_duration
if self.config.get("debug", False):
self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")
if self.top_down_renderer is not None:
self.top_down_renderer.clear()
self.engine.top_down_renderer = None
@@ -113,7 +256,6 @@ class MultiAgentScenarioEnv(ScenarioEnv):
self.dones = {}
self.episode_rewards = defaultdict(float)
self.episode_lengths = defaultdict(int)
self.controlled_agents.clear()
self.controlled_agent_ids.clear()
@@ -122,37 +264,121 @@ class MultiAgentScenarioEnv(ScenarioEnv):
return self._get_all_obs()
def _is_position_on_lane(self, position, tolerance=None):
if tolerance is None:
tolerance = self.config.get("lane_tolerance", 3.0)
# 确保 self.lanes 已初始化
if not hasattr(self, 'lanes') or self.lanes is None:
if self.config.get("debug", False):
self.logger.warning("Lanes not initialized, skipping lane check")
return True
position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)
try:
for lane in self.lanes.values():
if lane.lane.point_on_lane(position_2d):
return True
lane_start = np.array(lane.lane.start)[:2]
lane_end = np.array(lane.lane.end)[:2]
lane_vec = lane_end - lane_start
lane_length = np.linalg.norm(lane_vec)
if lane_length < 1e-6:
continue
lane_vec_normalized = lane_vec / lane_length
point_vec = position_2d - lane_start
projection = np.dot(point_vec, lane_vec_normalized)
if 0 <= projection <= lane_length:
closest_point = lane_start + projection * lane_vec_normalized
distance = np.linalg.norm(position_2d - closest_point)
if distance <= tolerance:
return True
except Exception as e:
if self.config.get("debug", False):
self.logger.warning(f"Lane check error: {e}")
return False
return False
def _spawn_controlled_agents(self):
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
"""
生成应该在当前或之前出现的车辆
如果round=0且所有车辆的show_time>0则生成show_time最小的车辆保证至少有车辆出现
"""
vehicles_to_spawn = []
for car in self.car_birth_info_list:
if car['show_time'] == self.round:
agent_id = f"controlled_{car['id']}"
if car['show_time'] <= self.round:
vehicles_to_spawn.append(car)
# 如果当前round没有车辆应该出现但车辆列表不为空则生成最早出现的车辆
# 这样可以确保在reset时至少有车辆出现
# if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0:
# if self.config.get("debug", False):
# self.logger.debug(
# f"No vehicles to spawn at round {self.round}, "
# f"spawning earliest vehicle instead"
# )
# # 找到show_time最小的车辆
# earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time'])
# vehicles_to_spawn.append(earliest_car)
for car in vehicles_to_spawn:
agent_id = f"controlled_{car['id']}"
# 避免重复生成
if agent_id in self.controlled_agents:
continue
vehicle_config = {}
vehicle = self.engine.spawn_object(
PolicyVehicle,
vehicle_config=vehicle_config,
position=car['begin'],
heading=car['heading']
)
vehicle = self.engine.spawn_object(
PolicyVehicle,
vehicle_config={},
position=car['begin'],
heading=car['heading']
# 重置车辆状态
reset_kwargs = {
'position': car['begin'],
'heading': car['heading']
}
# 如果启用速度继承,设置初始速度
if car.get('velocity') is not None:
reset_kwargs['velocity'] = car['velocity']
vehicle.reset(**reset_kwargs)
# 设置策略和目的地
vehicle.set_policy(self.policy)
vehicle.set_destination(car['end'])
vehicle.set_expert_vehicle_id(car['id'])
self.controlled_agents[agent_id] = vehicle
self.controlled_agent_ids.append(agent_id)
# 注册到引擎的 active_agents
self.engine.agent_manager.active_agents[agent_id] = vehicle
if self.config.get("debug", False):
self.logger.debug(
f"Spawned vehicle {agent_id} at round {self.round} "
f"(show_time={car['show_time']}), position {car['begin']}"
)
vehicle.reset(position=car['begin'], heading=car['heading'])
vehicle.set_policy(self.policy)
vehicle.set_destination(car['end'])
self.controlled_agents[agent_id] = vehicle
self.controlled_agent_ids.append(agent_id)
# ✅ 关键:注册到引擎的 active_agents才能参与物理更新
self.engine.agent_manager.active_agents[agent_id] = vehicle
def _get_all_obs(self):
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
self.obs_list = []
for agent_id, vehicle in self.controlled_agents.items():
state = vehicle.get_state()
traffic_light = 0
for lane in self.lanes.values():
if lane.lane.point_on_lane(state['position'][:2]):
if self.engine.light_manager.has_traffic_light(lane.lane.index):
@@ -167,38 +393,82 @@ class MultiAgentScenarioEnv(ScenarioEnv):
traffic_light = 0
break
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
physics_world=self.engine.physics_world.dynamic_world)
# 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云
lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
num_lasers=80,
distance=30,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.dynamic_world
)
nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info(
vehicle,
detected_objects,
perceive_distance=30,
num_others=10,
add_others_navi=False
)
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
+ nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
+ list(vehicle.destination))
self.obs_list.append(obs)
return self.obs_list
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
self.round += 1
# 应用动作
for agent_id, action in action_dict.items():
if agent_id in self.controlled_agents:
self.controlled_agents[agent_id].before_step(action)
# 物理引擎步进
self.engine.step()
# 后处理
for agent_id in action_dict:
if agent_id in self.controlled_agents:
self.controlled_agents[agent_id].after_step()
# 生成新车辆
self._spawn_controlled_agents()
# 获取观测
obs = self._get_all_obs()
rewards = {aid: 0.0 for aid in self.controlled_agents}
dones = {aid: False for aid in self.controlled_agents}
dones["__all__"] = self.episode_step >= self.config["horizon"]
# ✅ 修复:添加回放模式的完成检查
replay_finished = False
if self.replay_mode and self.config.get("use_scenario_duration", False):
# 检查是否所有专家轨迹都已播放完毕
if self.round >= self.scenario_max_duration:
replay_finished = True
if self.config.get("debug", False):
self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")
dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished
infos = {aid: {} for aid in self.controlled_agents}
return obs, rewards, dones, infos
def close(self):
# ✅ 清理所有生成的车辆
if hasattr(self, 'controlled_agents') and self.controlled_agents:
for agent_id, vehicle in list(self.controlled_agents.items()):
if vehicle in self.engine.get_objects():
self.engine.clear_objects([vehicle.id])
self.controlled_agents.clear()
self.controlled_agent_ids.clear()
super().close()

184
Env/verify_observation.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
简单验证脚本:检查车辆是否正确获取观测空间
用法python verify_observations.py
"""
import argparse
from scenario_env import MultiAgentScenarioEnv
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
import numpy as np
def verify_observations(data_dir, scenario_id=0):
"""
验证观测空间是否正确获取
Args:
data_dir: 数据目录
scenario_id: 场景ID
"""
print("=" * 60)
print("观测空间验证工具")
print("=" * 60)
# 创建环境
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": 300,
"use_render": False, # 不渲染,加速运行
"sequential_seed": True,
"reactive_traffic": False,
"manual_control": False,
"filter_offroad_vehicles": True,
"lane_tolerance": 3.0,
"replay_mode": True,
"debug": True, # 启用调试以查看详细信息
"specific_scenario_id": scenario_id,
"use_scenario_duration": True,
},
agent2policy=None
)
# 重置环境
print(f"\n加载场景 {scenario_id}...")
obs = env.reset(seed=scenario_id)
# 输出基本信息
print(f"\n场景信息:")
print(f" - 可控车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
print(f" - 场景时长: {env.scenario_max_duration}")
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
print(f" - 当前回合数 (round): {env.round}")
# 检查车辆生成信息
if len(env.car_birth_info_list) > 0:
print(f"\n车辆生成信息分析:")
show_times = [car['show_time'] for car in env.car_birth_info_list]
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
else:
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
print(f" 1. 所有车辆都被车道过滤移除")
print(f" 2. 所有车辆都被类型过滤移除")
print(f" 3. 场景数据中没有有效车辆")
# 验证观测空间
print(f"\n" + "=" * 60)
print("观测空间验证")
print("=" * 60)
if len(obs) == 0:
print("❌ 错误:没有获取到任何观测!")
env.close()
return False
# 检查第一个观测
first_obs = obs[0]
print(f"\n第一个车辆的观测:")
print(f" - 观测类型: {type(first_obs)}")
print(f" - 观测维度: {len(first_obs)}")
# 详细解析观测
if isinstance(first_obs, (list, np.ndarray)):
obs_array = np.array(first_obs)
print(f" - 观测形状: {obs_array.shape}")
print(f" - 数据类型: {obs_array.dtype}")
print(f"\n观测内容分解:")
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
# [x, y] (2) + [vx, vy] (2) + heading (1)
# + nearest vehicles info (10 vehicles * 4 = 40)
# + side_lidar (10) + lane_line_lidar (10)
# + traffic_light (1) + destination (2)
lidar_len = 40
side_len = 10
lane_len = 10
pos_slice = slice(0, 2)
vel_slice = slice(2, 4)
heading_idx = 4
lidar_slice = slice(5, 5 + lidar_len)
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
tl_idx = lane_slice.stop
dest_slice = slice(tl_idx + 1, tl_idx + 3)
print(f" 位置 (x, y): {obs_array[pos_slice]}")
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
print(f" 交通灯状态: {obs_array[tl_idx]}")
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
# 检查数据有效性
print(f"\n数据有效性检查:")
has_nan = np.isnan(obs_array).any()
has_inf = np.isinf(obs_array).any()
if has_nan:
print(f" ❌ 观测包含 NaN 值!")
else:
print(f" ✅ 无 NaN 值")
if has_inf:
print(f" ❌ 观测包含 Inf 值!")
else:
print(f" ✅ 无 Inf 值")
# 检查最近车辆信息数据40维
lidar_data = obs_array[lidar_slice]
lidar_min = np.min(lidar_data)
lidar_max = np.max(lidar_data)
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
if lidar_max > 0:
print(f" ✅ 存在非零最近车辆特征")
else:
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
# 运行几步,验证观测持续有效
print(f"\n" + "=" * 60)
print("多步运行验证(前 5 步)")
print("=" * 60)
for step in range(5):
# 空动作
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
obs, rewards, dones, infos = env.step(actions)
print(f"\nStep {step + 1}:")
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
if len(obs) > 0:
sample_obs = np.array(obs[0])
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
print(f" - 数据有效: {'' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else ''}")
if dones["__all__"]:
print(f" - 场景结束")
break
env.close()
print(f"\n" + "=" * 60)
print("✅ 验证完成!观测空间正常工作")
print("=" * 60)
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="验证观测空间")
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
args = parser.parse_args()
verify_observations(args.data_dir, args.scenario_id)

407
README.md
View File

@@ -1,28 +1,401 @@
# MAGAIL4AutoDrive
### 1.1 环境搭建
环境核心代码封装于`Env`文件夹,通过运行`run_multiagent_env.py`即可启动多智能体交互环境,该脚本的核心功能为读取各智能体(车辆)的动作指令,并将其传入`env.step()`方法中完成仿真执行。
# MAGAIL4AutoDrive - 多智能体自动驾驶环境
当前已初步实现`Env.senario_env.MultiAgentScenarioEnv.reset()`车辆生成函数,具体逻辑如下:首先读取专家数据集中各车辆的初始位姿信息;随后对原始数据进行清洗,剔除车辆 Agent 实例信息,记录核心参数(车辆 ID、初始生成位置、朝向角、生成时间戳、目标终点坐标最后调用`_spawn_controlled_agents()`函数,依据清洗后的参数在指定时间、指定位置生成搭载自动驾驶算法的可控车辆
基于 MetaDrive 的多智能体自动驾驶仿真与回放环境,支持 Waymo Open Dataset 的专家轨迹回放和自定义策略仿真
需解决的关键问题:部分车辆存在生成位置偏差(如生成于草坪区域),推测成因可能为专家数据记录误差或场景中模拟停车场区域的特殊标注。后续计划引入车道区域检测机制,通过判断车辆初始生成位置是否位于有效车道范围内,对非车道区域生成的车辆进行过滤,确保环境初始化的合理性。
## 📋 目录
- [项目简介](#项目简介)
- [功能特性](#功能特性)
- [环境要求](#环境要求)
- [安装步骤](#安装步骤)
- [快速开始](#快速开始)
- [使用指南](#使用指南)
- [项目结构](#项目结构)
- [配置说明](#配置说明)
- [常见问题](#常见问题)
## 项目简介
MAGAIL4AutoDrive 是一个基于 MetaDrive 0.4.3 的多智能体自动驾驶环境专为模仿学习Imitation Learning和强化学习Reinforcement Learning研究设计。项目支持从真实世界数据集如 Waymo Open Dataset中加载场景并提供两种核心运行模式
- **回放模式Replay Mode**:严格按照专家轨迹回放,用于数据可视化和验证
- **仿真模式Simulation Mode**:使用自定义策略控制车辆,用于算法训练和测试
## 功能特性
### 核心功能
-**多智能体支持**:同时控制多辆车辆进行协同仿真
-**专家轨迹回放**:精确回放 Waymo 数据集中的专家驾驶行为
-**自定义策略接口**灵活接入各种控制策略IDM、RL 等)
-**智能车道过滤**:自动过滤不在车道上的异常车辆
-**场景时长控制**:支持使用数据集原始场景时长或自定义 horizon
-**丰富的传感器**LiDAR、侧向检测器、车道线检测器、相机、仪表盘
### 高级特性
- 🎯 指定场景 ID 运行
- 🔄 自动场景切换(修复版)
- 📊 详细的调试日志输出
- 🚗 车辆动态生成与管理
- 🎮 支持可视化渲染和无头运行
## 环境要求
### 系统要求
- **操作系统**Ubuntu 18.04+ / macOS 10.14+ / Windows 10+
- **Python 版本**3.8 - 3.10
- **GPU**:可选,但推荐使用(用于加速渲染)
### 依赖库
```
metadrive-simulator==0.4.3
numpy>=1.19.0
pygame>=2.0.0
```
## 安装步骤
### 1. 创建 Conda 环境
```
conda create -n metadrive python=3.10
conda activate metadrive
```
### 2. 安装 MetaDrive
```
pip install metadrive-simulator==0.4.3
```
### 3. 克隆项目
```
git clone https://github.com/your-username/MAGAIL4AutoDrive.git
cd MAGAIL4AutoDrive/Env
```
### 4. 准备数据集
将 Waymo 数据集转换为 MetaDrive 格式并放置在项目目录下:
```
MAGAIL4AutoDrive/Env/
├── exp_converted/
│ ├── scenario_0/
│ ├── scenario_1/
│ └── ...
```
## 快速开始
### 回放模式(推荐先尝试)
```
### 1.2 观测获取
观测信息采集功能通过`Env.senario_env.MultiAgentScenarioEnv._get_all_obs()`函数实现,该函数支持遍历所有可控车辆并采集多维度观测数据,当前已实现的观测维度包括:车辆实时位置坐标、朝向角、行驶速度、雷达扫描点云(含障碍物与车道线特征)、导航信息(因场景复杂度较低,暂采用目标终点坐标直接作为导航输入)。
# 使用场景原始时长回放第一个场景
红绿灯信息采集机制需改进:当前方案通过 “车辆所属车道序号匹配对应红绿灯实例” 的逻辑获取信号灯状态,但存在两类问题:一是部分红绿灯实例的状态值为`None`;二是当单条车道存在分段设计时,部分区域的车辆会无法获取红绿灯状态。
python run_multiagent_env.py --mode replay --episodes 1 --use_scenario_duration
# 回放指定场景
python run_multiagent_env.py --mode replay --scenario_id 0 --use_scenario_duration
# 回放多个场景
python run_multiagent_env.py --mode replay --episodes 3 --use_scenario_duration
```
### 仿真模式
```
### 1.3 算法模块
本方案的核心创新点在于对 GAIL 算法的判别器进行改进,使其适配多智能体场景下 “输入长度动态变化”(车辆数量不固定)的特性,实现对整体交互场景的分类判断,进而满足多智能体自动驾驶环境的训练需求。算法核心代码封装于`Algorithm.bert.Bert`类,具体实现逻辑如下:
# 使用默认策略运行仿真
1. 输入层处理:输入数据为维度`(N, input_dim)`的矩阵(其中`N`为当前场景车辆数量,`input_dim`为单车辆固定观测维度),初始化`Bert`类时需设置`input_dim`,确保输入维度匹配;
2. 嵌入层与位置编码:通过`projection`线性投影层将单车辆观测维度映射至预设的嵌入维度(`embed_dim`),随后叠加可学习的位置编码(`pos_embed`),以捕捉观测序列的时序与空间关联信息;
3. Transformer 特征提取:嵌入后的特征向量输入至多层`Transformer`网络(层数由`num_layers`参数控制),完成高阶特征交互与抽象;
4. 分类头设计:提供两种特征聚合与分类方案:若开启`CLS`模式,在嵌入层前拼接 1 个可学习的`CLS`标记,最终取`CLS`标记对应的特征向量输入全连接层完成分类;若关闭`CLS`模式,则对`Transformer`输出的所有车辆特征向量进行序列维度均值池化,再将池化后的全局特征输入全连接层。分类器支持可选的`Tanh`激活函数,以适配不同场景下的输出分布需求。
python run_multiagent_env.py --mode simulation --episodes 1
# 无渲染运行(加速训练)
### 1.4 动作执行
在当前环境测试阶段,暂沿用腾达的动作执行框架:为每辆可控车辆分配独立的`policy`模型,将单车辆观测数据输入对应`policy`得到动作指令后,传入`env.step()`完成仿真;同时在`before_step`阶段调用`_set_action()`函数,将动作指令绑定至车辆实例,最终由 MetaDrive 仿真系统完成物理动力学计算与场景渲染。
python run_multiagent_env.py --mode simulation --episodes 5 --no_render
后续优化方向为构建 “参数共享式统一模型框架”,具体设计如下:所有车辆共用 1 个`policy`模型,通过参数共享机制实现模型的全局统一维护。该框架具备三重优势:一是避免多车辆独立模型带来的训练偏差(如不同模型训练程度不一致);二是解决车辆数量动态变化时的模型管理问题(车辆新增无需额外初始化模型,车辆减少不丢失模型训练信息);三是支持动作指令的并行计算,可显著提升每一步决策的迭代效率,适配大规模多智能体交互场景的训练需求。
```
## 使用指南
### 命令行参数
| 参数 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
| `--mode` | str | simulation | 运行模式:`replay``simulation` |
| `--data_dir` | str | 当前目录 | Waymo 数据目录路径 |
| `--episodes` | int | 1 | 运行回合数 |
| `--horizon` | int | 300 | 每回合最大步数 |
| `--no_render` | flag | False | 禁用渲染(加速运行) |
| `--debug` | flag | False | 启用调试模式 |
| `--scenario_id` | int | None | 指定场景 ID |
| `--use_scenario_duration` | flag | False | 使用场景原始时长 |
| `--no_vehicles` | flag | False | 禁止生成车辆 |
| `--no_pedestrians` | flag | False | 禁止生成行人 |
| `--no_cyclists` | flag | False | 禁止生成自行车 |
### 回放模式详解
回放模式严格按照专家轨迹回放车辆状态,不涉及物理引擎控制。主要用途:
- 数据集可视化
- 验证数据质量
- 生成演示视频
```bash
# 完整参数示例
python run_multiagent_env.py \
--mode replay \
--episodes 1 \
--use_scenario_duration \
--debug
# 仅回放车辆,禁止行人和自行车
python run_multiagent_env.py \
--mode replay \
--use_scenario_duration \
--no_pedestrians \
--no_cyclists
```
**重要提示**:回放模式建议始终启用 `--use_scenario_duration`,否则会出现场景播放完后继续运行的问题。
### 仿真模式详解
仿真模式使用自定义策略控制车辆,适合算法开发和测试:
```bash
# 基础仿真
python run_multiagent_env.py --mode simulation
# 长时间训练(无渲染)
python run_multiagent_env.py \
--mode simulation \
--episodes 100 \
--horizon 500 \
--no_render
# 仅车辆仿真(用于专注车车交互场景)
python run_multiagent_env.py \
--mode simulation \
--no_pedestrians \
--no_cyclists
```
### 自定义策略
修改 `simple_idm_policy.py` 或创建新的策略类:
```python
class CustomPolicy:
def __init__(self, **kwargs):
# 初始化策略参数
pass
def act(self, observation=None):
# 返回动作 [steering, acceleration]
# steering: [-1, 1]
# acceleration: [-1, 1]
return [0.0, 0.5]
```
`run_multiagent_env.py` 中使用:
```
from custom_policy import CustomPolicy
env = MultiAgentScenarioEnv(
config={...},
agent2policy=CustomPolicy()
)
```
## 项目结构
```
MAGAIL4AutoDrive/Env/
├── run_multiagent_env.py \# 主运行脚本
├── scenario_env.py \# 多智能体场景环境
├── replay_policy.py \# 专家轨迹回放策略
├── simple_idm_policy.py \# IDM 策略实现
├── utils.py \# 工具函数
├── ENHANCED_USAGE_GUIDE.md \# 详细使用指南
├── README.md \# 本文档
└── exp_converted/ \# Waymo 数据集(需自行准备)
├── scenario_0/
├── scenario_1/
└── ...
```
### 核心文件说明
**run_multiagent_env.py**
- 主入口脚本
- 处理命令行参数
- 管理回放和仿真两种模式的运行逻辑
**scenario_env.py**
- 自定义多智能体环境类
- 车辆生成与管理
- 车道过滤逻辑
- 观测空间定义
**replay_policy.py**
- 专家轨迹回放策略
- 逐帧状态查询
- 轨迹完成判断
**simple_idm_policy.py**
- 简单的恒速策略示例
- 可作为自定义策略的模板
## 配置说明
### 环境配置参数
`scenario_env.py``default_config()` 中可修改:
```python
config.update(dict(
data_directory=None, # 数据目录
num_controlled_agents=3, # 可控车辆数量(仅仿真模式)
horizon=1000, # 最大步数
filter_offroad_vehicles=True, # 是否过滤车道外车辆
lane_tolerance=3.0, # 车道容差(米)
replay_mode=False, # 是否为回放模式
specific_scenario_id=None, # 指定场景 ID
use_scenario_duration=False, # 使用场景原始时长
# 对象类型过滤选项
spawn_vehicles=True, # 是否生成车辆
spawn_pedestrians=True, # 是否生成行人
spawn_cyclists=True, # 是否生成自行车
))
```
### 传感器配置
默认启用的传感器(可在环境初始化时修改):
- **LiDAR**80 条激光,探测距离 30 米
- **侧向检测器**10 条激光,探测距离 8 米
- **车道线检测器**10 条激光,探测距离 3 米
- **主相机**:分辨率 1200x900
- **仪表盘**:车辆状态信息
## 常见问题
### Q1: 回放模式为什么超出数据集的最大帧数还在继续?
**A**: 需要添加 `--use_scenario_duration` 参数。修复版本已在 `scenario_env.py` 中添加了自动检测机制。
### Q2: 如何切换不同的场景?
**A**:
- 方法一:使用 `--scenario_id` 指定场景
- 方法二:使用 `--episodes N` 自动遍历 N 个场景
### Q3: 为什么有些车辆没有出现?
**A**: 启用了车道过滤功能(`filter_offroad_vehicles=True`),不在车道上的车辆会被过滤。可以通过设置 `lane_tolerance` 调整容差或关闭此功能。
### Q4: 如何提高运行速度?
**A**:
- 使用 `--no_render` 禁用可视化
- 减少 `num_controlled_agents` 数量
- 使用 GPU 加速
### Q5: 如何控制场景中的对象类型?
**A**: 使用对象过滤参数:
```bash
# 仅车辆,无行人和自行车
python run_multiagent_env.py --mode replay --no_pedestrians --no_cyclists
# 仅行人和自行车,无车辆(特殊场景)
python run_multiagent_env.py --mode replay --no_vehicles
# 调试模式查看过滤统计
python run_multiagent_env.py --mode replay --debug --no_pedestrians
```
### Q6: 为什么有些车辆生成在空中?
**A**: 已在 v1.2.0 中修复。现在所有车辆位置都只使用 2D 坐标x, yz 坐标设为 0让 MetaDrive 自动处理高度,确保车辆贴在地面上。
### Q7: 如何导出观测数据?
**A**: 在 `run_multiagent_env.py` 中添加数据保存逻辑:
```python
import pickle
obs_data = []
while True:
obs, rewards, dones, infos = env.step(actions)
obs_data.append(obs)
if dones["__all__"]:
break
with open('observations.pkl', 'wb') as f:
pickle.dump(obs_data, f)
```
## 更新日志
### v1.2.0 (2025-10-26)
- ✅ 修复车辆生成高度问题(车辆悬空)
- ✅ 添加对象类型过滤功能(车辆/行人/自行车)
- ✅ 新增命令行参数:`--no_vehicles``--no_pedestrians``--no_cyclists`
- ✅ 改进调试信息输出,显示各类型对象统计
- ✅ 优化位置处理逻辑,只使用 2D 坐标避免高度问题
### v1.1.0 (2025-10-26)
- ✅ 修复回放模式超出场景时长问题
- ✅ 添加场景自动切换功能
- ✅ 改进 `replay_policy.py`,新增 `is_finished()` 方法
- ✅ 优化 `scenario_env.py` 的 done 判断逻辑
- ✅ 修复多回合运行时的对象清理问题
### v1.0.0 (初始版本)
- 基础多智能体环境实现
- 回放和仿真两种模式
- 车道过滤功能
- Waymo 数据集支持
## 贡献指南
欢迎提交 Issue 和 Pull Request
### 提交 Issue
- 请详细描述问题和复现步骤
- 附上运行日志和错误信息
- 说明运行环境OS、Python 版本等)
### 提交 PR
- Fork 本项目
- 创建特性分支:`git checkout -b feature/your-feature`
- 提交更改:`git commit -m 'Add some feature'`
- 推送分支:`git push origin feature/your-feature`
- 提交 Pull Request
## 许可证
本项目基于 MIT 许可证开源。
## 致谢
- [MetaDrive](https://github.com/metadriverse/metadrive) - 优秀的驾驶仿真平台
- [Waymo Open Dataset](https://waymo.com/open/) - 高质量的自动驾驶数据集
## 联系方式
如有问题或建议,请通过以下方式联系:
- GitHub Issues: [项目 Issues 页面]
- Email: huangfukk@xxx.com
---
**Happy Driving! 🚗💨**

115
train_magail.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from Env.scenario_env import MultiAgentScenarioEnv
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里
# 假设你有加载专家数据的工具
# from utils import load_expert_buffer
# --- 配置 ---
CONFIG = {
"data_dir": "/home/huangfukk/mdsn",
# ... 其他配置 ...
"state_dim": 30, # 你的 Observation 维度
"action_dim": 2, # [steering, throttle]
"rollout_length": 2048, # Buffer 大小
}
def train():
# 1. 环境
env = MultiAgentScenarioEnv(config={...}, agent2policy=None)
# 2. 专家 Buffer (伪代码)
# expert_buffer = RolloutBuffer(...)
# expert_buffer.load(...)
# 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练
# 3. 初始化 MAGAIL
magail = MAGAIL(
buffer_exp=expert_buffer, # 传入专家 buffer
input_dim=(CONFIG["state_dim"],),
action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape原代码好像没传
device=torch.device("cuda"),
rollout_length=CONFIG["rollout_length"]
)
writer = SummaryWriter("./logs")
# 4. 训练循环
total_steps = 0
obs_dict = env.reset()
# 将 Dict Obs 转为 Array: (N_agents, Obs_dim)
# 假设所有 Agent 的 Obs 维度相同
agents = list(obs_dict.keys())
current_obs = np.stack([obs_dict[a] for a in agents])
# 如果需要 state_gail假设它就是 obs
current_state_gail = current_obs.copy()
while total_steps < 1e7:
# --- 收集数据 (Rollout) ---
# PPO 的 buffer 长度是 rollout_length
# 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append
# 注意magail.step 内部调用了 env.step这可能不适用于 Multi-Agent 环境返回 Dict 的情况
# 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的
# 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值
# --- 方案:手动 Rollout 适配 Multi-Agent ---
for _ in range(CONFIG["rollout_length"]):
# 1. 决策
actions_list, log_pis_list = magail.explore(current_obs)
# 2. 拼装 Action Dict
action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
# 3. 环境步进
next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
# 4. 处理返回值
# 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变
next_obs = np.stack([next_obs_dict[a] for a in agents])
rewards = np.array([rewards_dict[a] for a in agents])
dones = np.array([dones_dict[a] for a in agents])
# truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认
terminated = dones # 简化
truncated = [False] * len(agents) # 简化
# 5. 存入 Buffer
# 获取当前 Actor 的均值方差用于 PPO 更新
with torch.no_grad():
# 重新计算一遍或者在 explore 里返回
# 这里 magail.actor(state) 返回的是 mean (deterministic action)
means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
stds = magail.actor.log_stds.exp().detach().cpu().numpy()
# 如果是 StateIndependentPolicystd 可能是共享的参数,维度需要广播
if stds.shape[0] != len(agents):
stds = np.repeat(stds, len(agents), axis=0)
magail.buffer.append(
current_obs, current_state_gail, actions_list,
rewards, dones, terminated, log_pis_list,
next_obs, next_obs, # next_state_gail = next_obs
means, stds
)
current_obs = next_obs
current_state_gail = next_obs # 更新 gail state
total_steps += len(agents) # 步数增加 N
# 处理 Done
if all(dones):
obs_dict = env.reset()
current_obs = np.stack([obs_dict[a] for a in agents])
current_state_gail = current_obs
# --- 更新 ---
# 当 Buffer 满时 (rollout_length),调用 update
magail.update(writer, total_steps)
# 保存
if total_steps % 10000 == 0:
magail.save_models("./models")