代替点云

This commit is contained in:
2025-12-07 20:15:35 +08:00
parent 113e86bda2
commit 8cb86115c2
10 changed files with 527 additions and 45 deletions

6
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,6 @@
{
"cursorpyright.analysis.extraPaths": [
"/home/huangfukk/mdsn/metadrive",
"/home/huangfukk/mdsn/scenarionet"
]
}

109
Env/check_dataset.py Normal file
View File

@@ -0,0 +1,109 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据集检查脚本:统计可用的场景数量
"""
import pickle
import os
from pathlib import Path
from metadrive.engine.asset_loader import AssetLoader
def check_dataset(data_dir, subfolder="exp_filtered"):
"""
检查数据集中的场景数量
Args:
data_dir: 数据根目录
subfolder: 数据子目录exp_filtered 或 exp_converted
"""
print("=" * 80)
print("数据集检查工具")
print("=" * 80)
# 获取完整路径
full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
print(f"\n数据集路径: {full_path}")
# 检查文件结构
if not os.path.exists(full_path):
print(f"❌ 错误:路径不存在: {full_path}")
return
print(f"✅ 路径存在")
# 读取数据集映射
mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
if os.path.exists(mapping_file):
print(f"\n读取 dataset_mapping.pkl...")
with open(mapping_file, 'rb') as f:
dataset_mapping = pickle.load(f)
print(f"✅ 数据集映射文件存在")
print(f" 映射的场景数量: {len(dataset_mapping)}")
# 统计各个子目录的分布
subdirs = {}
for filename, subdir in dataset_mapping.items():
if subdir not in subdirs:
subdirs[subdir] = []
subdirs[subdir].append(filename)
print(f"\n场景分布:")
for subdir, files in subdirs.items():
print(f" {subdir}: {len(files)} 个场景")
# 读取数据集摘要
summary_file = os.path.join(full_path, "dataset_summary.pkl")
if os.path.exists(summary_file):
print(f"\n读取 dataset_summary.pkl...")
with open(summary_file, 'rb') as f:
dataset_summary = pickle.load(f)
print(f"✅ 数据集摘要文件存在")
print(f" 摘要的场景数量: {len(dataset_summary)}")
# 打印前几个场景的ID
print(f"\n前10个场景ID:")
for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
if isinstance(info, dict):
track_length = info.get('track_length', 'N/A')
num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")
# 检查实际文件
print(f"\n检查实际文件...")
pkl_files = list(Path(full_path).rglob("*.pkl"))
# 排除dataset_mapping和dataset_summary
scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
print(f" 实际场景文件数量: {len(scenario_files)}")
# 检查子目录
print(f"\n子目录结构:")
for item in os.listdir(full_path):
item_path = os.path.join(full_path, item)
if os.path.isdir(item_path):
pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
print(f" {item}/: {pkl_count} 个pkl文件")
print("\n" + "=" * 80)
print("检查完成")
print("=" * 80)
# 返回场景数量
return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
if __name__ == "__main__":
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
print("\n【检查 exp_filtered 数据集】")
num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
print("\n\n【检查 exp_converted 数据集】")
num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
print("\n\n总结:")
print(f" exp_filtered: {num_filtered} 个场景")
print(f" exp_converted: {num_converted} 个场景")

View File

@@ -4,7 +4,7 @@ from simple_idm_policy import ConstantVelocityPolicy
from replay_policy import ReplayPolicy from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader from metadrive.engine.asset_loader import AssetLoader
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted" WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False, def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
@@ -40,10 +40,10 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
# ✅ 环境创建移到循环外面,避免重复创建 # ✅ 环境创建移到循环外面,避免重复创建
env = MultiAgentScenarioEnv( env = MultiAgentScenarioEnv(
config={ config={
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False), "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True, "is_multi_agent": True,
"horizon": horizon, "horizon": horizon,
"use_render": render, "use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True, "sequential_seed": True,
"reactive_traffic": False, # 回放模式下不需要反应式交通 "reactive_traffic": False, # 回放模式下不需要反应式交通
"manual_control": False, "manual_control": False,
@@ -57,20 +57,33 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
"spawn_vehicles": spawn_vehicles, "spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians, "spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists, "spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
}, },
agent2policy=None # 回放模式不需要统一策略 agent2policy=None # 回放模式不需要统一策略
) )
try: try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes): for episode in range(num_episodes):
print(f"\n{'='*50}") print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}") print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None: if scenario_id is not None:
print(f"场景ID: {scenario_id}") print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}") print(f"{'='*50}")
# ✅ 如果不是指定场景,使用seed来遍历不同场景 # ✅ 如果不是指定场景,使用循环的场景索引
seed = scenario_id if scenario_id is not None else episode if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed) obs = env.reset(seed=seed)
# 为每个车辆分配 ReplayPolicy # 为每个车辆分配 ReplayPolicy
@@ -180,11 +193,11 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
env = MultiAgentScenarioEnv( env = MultiAgentScenarioEnv(
config={ config={
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False), "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True, "is_multi_agent": True,
"num_controlled_agents": 3, "num_controlled_agents": 3,
"horizon": horizon, "horizon": horizon,
"use_render": render, "use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True, "sequential_seed": True,
"reactive_traffic": True, "reactive_traffic": True,
"manual_control": False, "manual_control": False,
@@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
"spawn_vehicles": spawn_vehicles, "spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians, "spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists, "spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
}, },
agent2policy=ConstantVelocityPolicy(target_speed=50) agent2policy=ConstantVelocityPolicy(target_speed=50)
) )
try: try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes): for episode in range(num_episodes):
print(f"\n{'='*50}") print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}") print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None: if scenario_id is not None:
print(f"场景ID: {scenario_id}") print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}") print(f"{'='*50}")
seed = scenario_id if scenario_id is not None else episode # ✅ 如果不是指定场景,使用循环的场景索引
if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed) obs = env.reset(seed=seed)
actual_horizon = env.config["horizon"] actual_horizon = env.config["horizon"]

View File

@@ -306,9 +306,35 @@ class MultiAgentScenarioEnv(ScenarioEnv):
return False return False
def _spawn_controlled_agents(self): def _spawn_controlled_agents(self):
"""
生成应该在当前或之前出现的车辆
如果round=0且所有车辆的show_time>0则生成show_time最小的车辆保证至少有车辆出现
"""
vehicles_to_spawn = []
for car in self.car_birth_info_list: for car in self.car_birth_info_list:
if car['show_time'] == self.round: if car['show_time'] <= self.round:
vehicles_to_spawn.append(car)
# 如果当前round没有车辆应该出现但车辆列表不为空则生成最早出现的车辆
# 这样可以确保在reset时至少有车辆出现
# if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0:
# if self.config.get("debug", False):
# self.logger.debug(
# f"No vehicles to spawn at round {self.round}, "
# f"spawning earliest vehicle instead"
# )
# # 找到show_time最小的车辆
# earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time'])
# vehicles_to_spawn.append(earliest_car)
for car in vehicles_to_spawn:
agent_id = f"controlled_{car['id']}" agent_id = f"controlled_{car['id']}"
# 避免重复生成
if agent_id in self.controlled_agents:
continue
vehicle_config = {} vehicle_config = {}
vehicle = self.engine.spawn_object( vehicle = self.engine.spawn_object(
PolicyVehicle, PolicyVehicle,
@@ -341,7 +367,10 @@ class MultiAgentScenarioEnv(ScenarioEnv):
self.engine.agent_manager.active_agents[agent_id] = vehicle self.engine.agent_manager.active_agents[agent_id] = vehicle
if self.config.get("debug", False): if self.config.get("debug", False):
self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}") self.logger.debug(
f"Spawned vehicle {agent_id} at round {self.round} "
f"(show_time={car['show_time']}), position {car['begin']}"
)
def _get_all_obs(self): def _get_all_obs(self):
self.obs_list = [] self.obs_list = []
@@ -364,8 +393,20 @@ class MultiAgentScenarioEnv(ScenarioEnv):
traffic_light = 0 traffic_light = 0
break break
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle, # 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云
physics_world=self.engine.physics_world.dynamic_world) lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
num_lasers=80,
distance=30,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.dynamic_world
)
nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info(
vehicle,
detected_objects,
perceive_distance=30,
num_others=10,
add_others_navi=False
)
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8, side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
base_vehicle=vehicle, base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world) physics_world=self.engine.physics_world.static_world)
@@ -374,7 +415,7 @@ class MultiAgentScenarioEnv(ScenarioEnv):
physics_world=self.engine.physics_world.static_world) physics_world=self.engine.physics_world.static_world)
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']] obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light] + nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
+ list(vehicle.destination)) + list(vehicle.destination))
self.obs_list.append(obs) self.obs_list.append(obs)

184
Env/verify_observation.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
简单验证脚本:检查车辆是否正确获取观测空间
用法python verify_observations.py
"""
import argparse
from scenario_env import MultiAgentScenarioEnv
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
import numpy as np
def verify_observations(data_dir, scenario_id=0):
"""
验证观测空间是否正确获取
Args:
data_dir: 数据目录
scenario_id: 场景ID
"""
print("=" * 60)
print("观测空间验证工具")
print("=" * 60)
# 创建环境
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": 300,
"use_render": False, # 不渲染,加速运行
"sequential_seed": True,
"reactive_traffic": False,
"manual_control": False,
"filter_offroad_vehicles": True,
"lane_tolerance": 3.0,
"replay_mode": True,
"debug": True, # 启用调试以查看详细信息
"specific_scenario_id": scenario_id,
"use_scenario_duration": True,
},
agent2policy=None
)
# 重置环境
print(f"\n加载场景 {scenario_id}...")
obs = env.reset(seed=scenario_id)
# 输出基本信息
print(f"\n场景信息:")
print(f" - 可控车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
print(f" - 场景时长: {env.scenario_max_duration}")
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
print(f" - 当前回合数 (round): {env.round}")
# 检查车辆生成信息
if len(env.car_birth_info_list) > 0:
print(f"\n车辆生成信息分析:")
show_times = [car['show_time'] for car in env.car_birth_info_list]
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
else:
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
print(f" 1. 所有车辆都被车道过滤移除")
print(f" 2. 所有车辆都被类型过滤移除")
print(f" 3. 场景数据中没有有效车辆")
# 验证观测空间
print(f"\n" + "=" * 60)
print("观测空间验证")
print("=" * 60)
if len(obs) == 0:
print("❌ 错误:没有获取到任何观测!")
env.close()
return False
# 检查第一个观测
first_obs = obs[0]
print(f"\n第一个车辆的观测:")
print(f" - 观测类型: {type(first_obs)}")
print(f" - 观测维度: {len(first_obs)}")
# 详细解析观测
if isinstance(first_obs, (list, np.ndarray)):
obs_array = np.array(first_obs)
print(f" - 观测形状: {obs_array.shape}")
print(f" - 数据类型: {obs_array.dtype}")
print(f"\n观测内容分解:")
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
# [x, y] (2) + [vx, vy] (2) + heading (1)
# + nearest vehicles info (10 vehicles * 4 = 40)
# + side_lidar (10) + lane_line_lidar (10)
# + traffic_light (1) + destination (2)
lidar_len = 40
side_len = 10
lane_len = 10
pos_slice = slice(0, 2)
vel_slice = slice(2, 4)
heading_idx = 4
lidar_slice = slice(5, 5 + lidar_len)
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
tl_idx = lane_slice.stop
dest_slice = slice(tl_idx + 1, tl_idx + 3)
print(f" 位置 (x, y): {obs_array[pos_slice]}")
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
print(f" 交通灯状态: {obs_array[tl_idx]}")
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
# 检查数据有效性
print(f"\n数据有效性检查:")
has_nan = np.isnan(obs_array).any()
has_inf = np.isinf(obs_array).any()
if has_nan:
print(f" ❌ 观测包含 NaN 值!")
else:
print(f" ✅ 无 NaN 值")
if has_inf:
print(f" ❌ 观测包含 Inf 值!")
else:
print(f" ✅ 无 Inf 值")
# 检查最近车辆信息数据40维
lidar_data = obs_array[lidar_slice]
lidar_min = np.min(lidar_data)
lidar_max = np.max(lidar_data)
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
if lidar_max > 0:
print(f" ✅ 存在非零最近车辆特征")
else:
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
# 运行几步,验证观测持续有效
print(f"\n" + "=" * 60)
print("多步运行验证(前 5 步)")
print("=" * 60)
for step in range(5):
# 空动作
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
obs, rewards, dones, infos = env.step(actions)
print(f"\nStep {step + 1}:")
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
print(f" - 观测数量: {len(obs)}")
if len(obs) > 0:
sample_obs = np.array(obs[0])
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
print(f" - 数据有效: {'' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else ''}")
if dones["__all__"]:
print(f" - 场景结束")
break
env.close()
print(f"\n" + "=" * 60)
print("✅ 验证完成!观测空间正常工作")
print("=" * 60)
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="验证观测空间")
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
args = parser.parse_args()
verify_observations(args.data_dir, args.scenario_id)

115
train_magail.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from Env.scenario_env import MultiAgentScenarioEnv
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里
# 假设你有加载专家数据的工具
# from utils import load_expert_buffer
# --- 配置 ---
CONFIG = {
"data_dir": "/home/huangfukk/mdsn",
# ... 其他配置 ...
"state_dim": 30, # 你的 Observation 维度
"action_dim": 2, # [steering, throttle]
"rollout_length": 2048, # Buffer 大小
}
def train():
# 1. 环境
env = MultiAgentScenarioEnv(config={...}, agent2policy=None)
# 2. 专家 Buffer (伪代码)
# expert_buffer = RolloutBuffer(...)
# expert_buffer.load(...)
# 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练
# 3. 初始化 MAGAIL
magail = MAGAIL(
buffer_exp=expert_buffer, # 传入专家 buffer
input_dim=(CONFIG["state_dim"],),
action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape原代码好像没传
device=torch.device("cuda"),
rollout_length=CONFIG["rollout_length"]
)
writer = SummaryWriter("./logs")
# 4. 训练循环
total_steps = 0
obs_dict = env.reset()
# 将 Dict Obs 转为 Array: (N_agents, Obs_dim)
# 假设所有 Agent 的 Obs 维度相同
agents = list(obs_dict.keys())
current_obs = np.stack([obs_dict[a] for a in agents])
# 如果需要 state_gail假设它就是 obs
current_state_gail = current_obs.copy()
while total_steps < 1e7:
# --- 收集数据 (Rollout) ---
# PPO 的 buffer 长度是 rollout_length
# 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append
# 注意magail.step 内部调用了 env.step这可能不适用于 Multi-Agent 环境返回 Dict 的情况
# 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的
# 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值
# --- 方案:手动 Rollout 适配 Multi-Agent ---
for _ in range(CONFIG["rollout_length"]):
# 1. 决策
actions_list, log_pis_list = magail.explore(current_obs)
# 2. 拼装 Action Dict
action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
# 3. 环境步进
next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
# 4. 处理返回值
# 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变
next_obs = np.stack([next_obs_dict[a] for a in agents])
rewards = np.array([rewards_dict[a] for a in agents])
dones = np.array([dones_dict[a] for a in agents])
# truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认
terminated = dones # 简化
truncated = [False] * len(agents) # 简化
# 5. 存入 Buffer
# 获取当前 Actor 的均值方差用于 PPO 更新
with torch.no_grad():
# 重新计算一遍或者在 explore 里返回
# 这里 magail.actor(state) 返回的是 mean (deterministic action)
means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
stds = magail.actor.log_stds.exp().detach().cpu().numpy()
# 如果是 StateIndependentPolicystd 可能是共享的参数,维度需要广播
if stds.shape[0] != len(agents):
stds = np.repeat(stds, len(agents), axis=0)
magail.buffer.append(
current_obs, current_state_gail, actions_list,
rewards, dones, terminated, log_pis_list,
next_obs, next_obs, # next_state_gail = next_obs
means, stds
)
current_obs = next_obs
current_state_gail = next_obs # 更新 gail state
total_steps += len(agents) # 步数增加 N
# 处理 Done
if all(dones):
obs_dict = env.reset()
current_obs = np.stack([obs_dict[a] for a in agents])
current_state_gail = current_obs
# --- 更新 ---
# 当 Buffer 满时 (rollout_length),调用 update
magail.update(writer, total_steps)
# 保存
if total_steps % 10000 == 0:
magail.save_models("./models")