diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..93997ae
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,6 @@
+{
+    "cursorpyright.analysis.extraPaths": [
+        "/home/huangfukk/mdsn/metadrive",
+        "/home/huangfukk/mdsn/scenarionet"
+    ]
+}
\ No newline at end of file
diff --git a/Env/__pycache__/replay_policy.cpython-310.pyc b/Env/__pycache__/replay_policy.cpython-310.pyc
index fed7d1b..fcf6cca 100644
Binary files a/Env/__pycache__/replay_policy.cpython-310.pyc and b/Env/__pycache__/replay_policy.cpython-310.pyc differ
diff --git a/Env/__pycache__/scenario_env.cpython-310.pyc b/Env/__pycache__/scenario_env.cpython-310.pyc
index 31f5e12..b8ffe4d 100644
Binary files a/Env/__pycache__/scenario_env.cpython-310.pyc and b/Env/__pycache__/scenario_env.cpython-310.pyc differ
diff --git a/Env/__pycache__/scenario_env.cpython-313.pyc b/Env/__pycache__/scenario_env.cpython-313.pyc
index 8928dd9..a14c941 100644
Binary files a/Env/__pycache__/scenario_env.cpython-313.pyc and b/Env/__pycache__/scenario_env.cpython-313.pyc differ
diff --git a/Env/__pycache__/simple_idm_policy.cpython-310.pyc b/Env/__pycache__/simple_idm_policy.cpython-310.pyc
index b0dddc5..0782927 100644
Binary files a/Env/__pycache__/simple_idm_policy.cpython-310.pyc and b/Env/__pycache__/simple_idm_policy.cpython-310.pyc differ
diff --git a/Env/check_dataset.py b/Env/check_dataset.py
new file mode 100644
index 0000000..98ae66b
--- /dev/null
+++ b/Env/check_dataset.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Dataset inspection script: counts the available scenarios.
+"""
+
+import pickle
+import os
+from pathlib import Path
+from metadrive.engine.asset_loader import AssetLoader
+
+def check_dataset(data_dir, subfolder="exp_filtered"):
+    """
+    Count the scenarios in a dataset.
+
+    Args:
+        data_dir: dataset root directory
+        subfolder: dataset sub-directory (exp_filtered or exp_converted)
+    """
+    print("=" * 80)
+    print("Dataset inspection tool")
+    print("=" * 80)
+
+    # Resolve the full path
+    full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
+    print(f"\nDataset path: {full_path}")
+
+    # Check the file structure
+    if not os.path.exists(full_path):
+        print(f"❌ Error: path does not exist: {full_path}")
+        return 0
+
+    print(f"✅ Path exists")
+
+    # Read the dataset mapping
+    mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
+    if os.path.exists(mapping_file):
+        print(f"\nReading dataset_mapping.pkl...")
+        with open(mapping_file, 'rb') as f:
+            dataset_mapping = pickle.load(f)
+
+        print(f"✅ Dataset mapping file exists")
+        print(f"   Mapped scenario count: {len(dataset_mapping)}")
+
+        # Tally the distribution across sub-directories
+        subdirs = {}
+        for filename, subdir in dataset_mapping.items():
+            if subdir not in subdirs:
+                subdirs[subdir] = []
+            subdirs[subdir].append(filename)
+
+        print(f"\nScenario distribution:")
+        for subdir, files in subdirs.items():
+            print(f"   {subdir}: {len(files)} scenarios")
+
+    # Read the dataset summary
+    summary_file = os.path.join(full_path, "dataset_summary.pkl")
+    if os.path.exists(summary_file):
+        print(f"\nReading dataset_summary.pkl...")
+        with open(summary_file, 'rb') as f:
+            dataset_summary = pickle.load(f)
+
+        print(f"✅ Dataset summary file exists")
+        print(f"   Summarized scenario count: {len(dataset_summary)}")
+
+        # Print the first few scenario IDs
+        print(f"\nFirst 10 scenario IDs:")
+        for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
+            if isinstance(info, dict):
+                track_length = info.get('track_length', 'N/A')
+                num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
+                print(f"   {i}: {scenario_id[:16]}... (duration: {track_length}, objects: {num_objects})")
+
+    # Check the actual files
+    print(f"\nChecking actual files...")
+    pkl_files = list(Path(full_path).rglob("*.pkl"))
+    # Exclude dataset_mapping and dataset_summary
+    scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
+    print(f"   Actual scenario file count: {len(scenario_files)}")
+
+    # Check the sub-directories
+    print(f"\nSub-directory structure:")
+    for item in os.listdir(full_path):
+        item_path = os.path.join(full_path, item)
+        if os.path.isdir(item_path):
+            pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
+            print(f"   {item}/: {pkl_count} pkl files")
+
+    print("\n" + "=" * 80)
+    print("Inspection finished")
+    print("=" * 80)
+
+    # Return the scenario count (0 if no mapping file was found)
+    return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
+
+
+if __name__ == "__main__":
+    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
+
+    print("\n[Checking the exp_filtered dataset]")
+    num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
+
+    print("\n\n[Checking the exp_converted dataset]")
+    num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
+
+    print("\n\nSummary:")
+    print(f"   exp_filtered: {num_filtered} scenarios")
+    print(f"   exp_converted: {num_converted} scenarios")
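check_dataset.py treats dataset_mapping.pkl as a {filename: sub-directory} dict, so the absolute path of any scenario file follows directly from the mapping. A minimal sketch of that lookup (the helper name resolve_scenario_path is ours, not part of the repo):

    import os
    import pickle

    def resolve_scenario_path(dataset_root, filename):
        """Resolve a scenario .pkl to its absolute path via dataset_mapping.pkl."""
        with open(os.path.join(dataset_root, "dataset_mapping.pkl"), "rb") as f:
            mapping = pickle.load(f)  # {scenario filename: relative sub-directory}
        return os.path.join(dataset_root, mapping[filename], filename)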
diff --git a/Env/run_multiagent_env.py b/Env/run_multiagent_env.py
index 7f8e70d..4821e45 100644
--- a/Env/run_multiagent_env.py
+++ b/Env/run_multiagent_env.py
@@ -4,7 +4,7 @@ from simple_idm_policy import ConstantVelocityPolicy
 from replay_policy import ReplayPolicy
 from metadrive.engine.asset_loader import AssetLoader
 
-WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
+WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
 
 
 def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
@@ -40,10 +40,10 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
     # ✅ Create the environment outside the loop to avoid rebuilding it every episode
     env = MultiAgentScenarioEnv(
         config={
-            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
+            "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
             "is_multi_agent": True,
             "horizon": horizon,
-            "use_render": render,
+            "use_render": render,  # False disables rendering entirely and avoids the LANE_FREEWAY error
             "sequential_seed": True,
             "reactive_traffic": False,  # replay mode needs no reactive traffic
             "manual_control": False,
@@ -57,20 +57,33 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
             "spawn_vehicles": spawn_vehicles,
             "spawn_pedestrians": spawn_pedestrians,
             "spawn_cyclists": spawn_cyclists,
+            # ✅ Key: set the number of available scenarios
+            #"num_scenarios": 19012,  # actual scenario count taken from dataset_mapping.pkl
         },
         agent2policy=None  # replay mode needs no shared policy
     )
 
     try:
+        # Query the number of available scenarios
+        num_scenarios = env.config.get("num_scenarios", 1)
+        print(f"Available scenarios: {num_scenarios}")
+
         for episode in range(num_episodes):
             print(f"\n{'='*50}")
             print(f"Episode {episode + 1}/{num_episodes}")
             if scenario_id is not None:
                 print(f"Scenario ID: {scenario_id}")
+            else:
+                # Cycle through the scenarios
+                scenario_idx = episode % num_scenarios
+                print(f"Using scenario index: {scenario_idx}")
             print(f"{'='*50}")
 
-            # ✅ If no scenario is fixed, use the seed to walk through different scenarios
-            seed = scenario_id if scenario_id is not None else episode
+            # ✅ If no scenario is fixed, use the cycling scenario index as the seed
+            if scenario_id is not None:
+                seed = scenario_id
+            else:
+                seed = episode % num_scenarios
             obs = env.reset(seed=seed)
 
             # Assign a ReplayPolicy to each vehicle
@@ -180,11 +193,11 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
     env = MultiAgentScenarioEnv(
         config={
-            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
+            "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
             "is_multi_agent": True,
             "num_controlled_agents": 3,
             "horizon": horizon,
-            "use_render": render,
+            "use_render": render,  # False disables rendering entirely and avoids the LANE_FREEWAY error
             "sequential_seed": True,
             "reactive_traffic": True,
             "manual_control": False,
@@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
             "spawn_vehicles": spawn_vehicles,
             "spawn_pedestrians": spawn_pedestrians,
             "spawn_cyclists": spawn_cyclists,
+            # ✅ Key: set the number of available scenarios
+            #"num_scenarios": 19012,  # actual scenario count taken from dataset_mapping.pkl
         },
         agent2policy=ConstantVelocityPolicy(target_speed=50)
     )
 
     try:
+        # Query the number of available scenarios
+        num_scenarios = env.config.get("num_scenarios", 1)
+        print(f"Available scenarios: {num_scenarios}")
+
         for episode in range(num_episodes):
             print(f"\n{'='*50}")
             print(f"Episode {episode + 1}/{num_episodes}")
             if scenario_id is not None:
                 print(f"Scenario ID: {scenario_id}")
+            else:
+                # Cycle through the scenarios
+                scenario_idx = episode % num_scenarios
+                print(f"Using scenario index: {scenario_idx}")
             print(f"{'='*50}")
 
-            seed = scenario_id if scenario_id is not None else episode
+            # ✅ If no scenario is fixed, use the cycling scenario index as the seed
+            if scenario_id is not None:
+                seed = scenario_id
+            else:
+                seed = episode % num_scenarios
             obs = env.reset(seed=seed)
 
             actual_horizon = env.config["horizon"]
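The commented-out "num_scenarios": 19012 hardcodes a count that goes stale whenever the dataset changes. If that key is enabled, the count can be derived from dataset_mapping.pkl at startup instead; a minimal sketch, assuming the mapping format read by check_dataset.py above:

    import os
    import pickle
    from metadrive.engine.asset_loader import AssetLoader

    def count_scenarios(data_dir, subfolder="exp_filtered"):
        """Count the scenarios listed in dataset_mapping.pkl."""
        root = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
        with open(os.path.join(root, "dataset_mapping.pkl"), "rb") as f:
            return len(pickle.load(f))

    # e.g.: config["num_scenarios"] = count_scenarios(WAYMO_DATA_DIR)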
+ "use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误 "sequential_seed": True, "reactive_traffic": True, "manual_control": False, @@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu "spawn_vehicles": spawn_vehicles, "spawn_pedestrians": spawn_pedestrians, "spawn_cyclists": spawn_cyclists, + # ✅ 关键:设置可用场景数量 + #"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数 }, agent2policy=ConstantVelocityPolicy(target_speed=50) ) try: + # 获取可用场景数量 + num_scenarios = env.config.get("num_scenarios", 1) + print(f"可用场景数量: {num_scenarios}") + for episode in range(num_episodes): print(f"\n{'='*50}") print(f"回合 {episode + 1}/{num_episodes}") if scenario_id is not None: print(f"场景ID: {scenario_id}") + else: + # 循环使用场景 + scenario_idx = episode % num_scenarios + print(f"使用场景索引: {scenario_idx}") print(f"{'='*50}") - seed = scenario_id if scenario_id is not None else episode + # ✅ 如果不是指定场景,使用循环的场景索引 + if scenario_id is not None: + seed = scenario_id + else: + seed = episode % num_scenarios obs = env.reset(seed=seed) actual_horizon = env.config["horizon"] diff --git a/Env/scenario_env.py b/Env/scenario_env.py index 85a616d..e3c0b21 100644 --- a/Env/scenario_env.py +++ b/Env/scenario_env.py @@ -306,43 +306,72 @@ class MultiAgentScenarioEnv(ScenarioEnv): return False def _spawn_controlled_agents(self): + """ + 生成应该在当前或之前出现的车辆 + 如果round=0且所有车辆的show_time>0,则生成show_time最小的车辆(保证至少有车辆出现) + """ + vehicles_to_spawn = [] + for car in self.car_birth_info_list: - if car['show_time'] == self.round: - agent_id = f"controlled_{car['id']}" - vehicle_config = {} - vehicle = self.engine.spawn_object( - PolicyVehicle, - vehicle_config=vehicle_config, - position=car['begin'], - heading=car['heading'] + if car['show_time'] <= self.round: + vehicles_to_spawn.append(car) + + # 如果当前round没有车辆应该出现,但车辆列表不为空,则生成最早出现的车辆 + # 这样可以确保在reset时至少有车辆出现 + # if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0: + # if self.config.get("debug", False): + # self.logger.debug( + # f"No vehicles to spawn at round {self.round}, " + # f"spawning earliest vehicle instead" + # ) + # # 找到show_time最小的车辆 + # earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time']) + # vehicles_to_spawn.append(earliest_car) + + for car in vehicles_to_spawn: + agent_id = f"controlled_{car['id']}" + + # 避免重复生成 + if agent_id in self.controlled_agents: + continue + + vehicle_config = {} + vehicle = self.engine.spawn_object( + PolicyVehicle, + vehicle_config=vehicle_config, + position=car['begin'], + heading=car['heading'] + ) + + # 重置车辆状态 + reset_kwargs = { + 'position': car['begin'], + 'heading': car['heading'] + } + + # 如果启用速度继承,设置初始速度 + if car.get('velocity') is not None: + reset_kwargs['velocity'] = car['velocity'] + + vehicle.reset(**reset_kwargs) + + # 设置策略和目的地 + vehicle.set_policy(self.policy) + vehicle.set_destination(car['end']) + vehicle.set_expert_vehicle_id(car['id']) + + self.controlled_agents[agent_id] = vehicle + self.controlled_agent_ids.append(agent_id) + + # 注册到引擎的 active_agents + self.engine.agent_manager.active_agents[agent_id] = vehicle + + if self.config.get("debug", False): + self.logger.debug( + f"Spawned vehicle {agent_id} at round {self.round} " + f"(show_time={car['show_time']}), position {car['begin']}" ) - # 重置车辆状态 - reset_kwargs = { - 'position': car['begin'], - 'heading': car['heading'] - } - - # 如果启用速度继承,设置初始速度 - if car.get('velocity') is not None: - reset_kwargs['velocity'] = car['velocity'] - - vehicle.reset(**reset_kwargs) - - # 设置策略和目的地 - 
diff --git a/Env/verify_observation.py b/Env/verify_observation.py
new file mode 100644
index 0000000..36ab298
--- /dev/null
+++ b/Env/verify_observation.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Quick verification script: checks that vehicles receive their observations correctly.
+Usage: python verify_observation.py
+"""
+
+import argparse
+from scenario_env import MultiAgentScenarioEnv
+from replay_policy import ReplayPolicy
+from metadrive.engine.asset_loader import AssetLoader
+import numpy as np
+
+def verify_observations(data_dir, scenario_id=0):
+    """
+    Verify that the observation space is populated correctly.
+
+    Args:
+        data_dir: data directory
+        scenario_id: scenario ID
+    """
+    print("=" * 60)
+    print("Observation-space verification tool")
+    print("=" * 60)
+
+    # Create the environment
+    env = MultiAgentScenarioEnv(
+        config={
+            "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
+            "is_multi_agent": True,
+            "horizon": 300,
+            "use_render": False,  # no rendering, for speed
+            "sequential_seed": True,
+            "reactive_traffic": False,
+            "manual_control": False,
+            "filter_offroad_vehicles": True,
+            "lane_tolerance": 3.0,
+            "replay_mode": True,
+            "debug": True,  # enable debug output for details
+            "specific_scenario_id": scenario_id,
+            "use_scenario_duration": True,
+        },
+        agent2policy=None
+    )
+
+    # Reset the environment
+    print(f"\nLoading scenario {scenario_id}...")
+    obs = env.reset(seed=scenario_id)
+
+    # Print the basics
+    print(f"\nScenario info:")
+    print(f"  - Controlled vehicles: {len(env.controlled_agents)}")
+    print(f"  - Observations: {len(obs)}")
+    print(f"  - Scenario duration: {env.scenario_max_duration} steps")
+    print(f"  - Vehicle spawn list length: {len(env.car_birth_info_list)}")
+    print(f"  - Current step (round): {env.round}")
+
+    # Inspect the spawn info
+    if len(env.car_birth_info_list) > 0:
+        print(f"\nSpawn info analysis:")
+        show_times = [car['show_time'] for car in env.car_birth_info_list]
+        print(f"  - show_time range: min={min(show_times)}, max={max(show_times)}")
+        print(f"  - Vehicles with show_time == 0: {sum(1 for st in show_times if st == 0)}")
+        print(f"  - show_time of the first 5 vehicles: {show_times[:5]}")
+    else:
+        print(f"\n⚠️ Warning: the vehicle spawn list is empty! Possible causes:")
+        print(f"  1. Every vehicle was removed by the lane filter")
+        print(f"  2. Every vehicle was removed by the type filter")
+        print(f"  3. The scenario data contains no valid vehicles")
+
+    # Verify the observation space
+    print(f"\n" + "=" * 60)
+    print("Observation-space verification")
+    print("=" * 60)
+
+    if len(obs) == 0:
+        print("❌ Error: no observations were returned!")
+        env.close()
+        return False
+
+    # Inspect the first observation
+    first_obs = obs[0]
+    print(f"\nFirst vehicle's observation:")
+    print(f"  - Type: {type(first_obs)}")
+    print(f"  - Length: {len(first_obs)}")
+
+    # Break the observation down
+    if isinstance(first_obs, (list, np.ndarray)):
+        obs_array = np.array(first_obs)
+        print(f"  - Shape: {obs_array.shape}")
+        print(f"  - Dtype: {obs_array.dtype}")
+        print(f"\nObservation breakdown:")
+        # New observation layout (aligned with Env/scenario_env._get_all_obs):
+        #   [x, y] (2) + [vx, vy] (2) + heading (1)
+        #   + nearest-vehicle info (10 vehicles * 4 = 40)
+        #   + side_lidar (10) + lane_line_lidar (10)
+        #   + traffic_light (1) + destination (2)
+        nearest_len = 40
+        side_len = 10
+        lane_len = 10
+        pos_slice = slice(0, 2)
+        vel_slice = slice(2, 4)
+        heading_idx = 4
+        nearest_slice = slice(5, 5 + nearest_len)
+        side_slice = slice(nearest_slice.stop, nearest_slice.stop + side_len)
+        lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
+        tl_idx = lane_slice.stop
+        dest_slice = slice(tl_idx + 1, tl_idx + 3)
+
+        print(f"  Position (x, y): {obs_array[pos_slice]}")
+        print(f"  Velocity (vx, vy): {obs_array[vel_slice]}")
+        print(f"  Heading: {obs_array[heading_idx]:.3f} rad")
+        print(f"  Nearest-vehicle info: {len(obs_array[nearest_slice])} dims (10 vehicles * 4: rel x / rel y / rel vx / rel vy)")
+        print(f"  Side detector: {len(obs_array[side_slice])} points")
+        print(f"  Lane-line detector: {len(obs_array[lane_slice])} points")
+        print(f"  Traffic-light state: {obs_array[tl_idx]}")
+        print(f"  Destination (x, y): {obs_array[dest_slice]}")
+
+        # Validity checks
+        print(f"\nData validity checks:")
+        has_nan = np.isnan(obs_array).any()
+        has_inf = np.isinf(obs_array).any()
+
+        if has_nan:
+            print(f"  ❌ Observation contains NaN values!")
+        else:
+            print(f"  ✅ No NaN values")
+
+        if has_inf:
+            print(f"  ❌ Observation contains Inf values!")
+        else:
+            print(f"  ✅ No Inf values")
+
+        # Check the nearest-vehicle info block (40 dims)
+        nearest_data = obs_array[nearest_slice]
+        nearest_min = np.min(nearest_data)
+        nearest_max = np.max(nearest_data)
+        print(f"\n  Nearest-vehicle feature range: [{nearest_min:.2f}, {nearest_max:.2f}]")
+
+        if nearest_max > 0:
+            print(f"  ✅ Non-zero nearest-vehicle features present")
+        else:
+            print(f"  ⚠️ Nearest-vehicle features are all zero (no neighbours, or none within range)")
+
+    # Run a few steps to confirm the observations stay valid
+    print(f"\n" + "=" * 60)
+    print("Multi-step check (first 5 steps)")
+    print("=" * 60)
+
+    for step in range(5):
+        # Null actions
+        actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
+        obs, rewards, dones, infos = env.step(actions)
+
+        print(f"\nStep {step + 1}:")
+        print(f"  - Active vehicles: {len(env.controlled_agents)}")
+        print(f"  - Observations: {len(obs)}")
+
+        if len(obs) > 0:
+            sample_obs = np.array(obs[0])
+            print(f"  - First vehicle position: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
+            print(f"  - Data valid: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
+
+        if dones["__all__"]:
+            print(f"  - Scenario finished")
+            break
+
+    env.close()
+
+    print(f"\n" + "=" * 60)
+    print("✅ Verification complete! The observation space works as expected")
+    print("=" * 60)
+
+    return True
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Verify the observation space")
+    parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="data directory")
+    parser.add_argument("--scenario_id", type=int, default=0, help="scenario ID")
+
+    args = parser.parse_args()
+
+    verify_observations(args.data_dir, args.scenario_id)
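The script reads the data directory and scenario ID from the command line, so a typical invocation (matching its own defaults) is:

    python Env/verify_observation.py --data_dir /home/huangfukk/mdsn --scenario_id 0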
show_time: {show_times[:5]}") + else: + print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:") + print(f" 1. 所有车辆都被车道过滤移除") + print(f" 2. 所有车辆都被类型过滤移除") + print(f" 3. 场景数据中没有有效车辆") + + # 验证观测空间 + print(f"\n" + "=" * 60) + print("观测空间验证") + print("=" * 60) + + if len(obs) == 0: + print("❌ 错误:没有获取到任何观测!") + env.close() + return False + + # 检查第一个观测 + first_obs = obs[0] + print(f"\n第一个车辆的观测:") + print(f" - 观测类型: {type(first_obs)}") + print(f" - 观测维度: {len(first_obs)}") + + # 详细解析观测 + if isinstance(first_obs, (list, np.ndarray)): + obs_array = np.array(first_obs) + print(f" - 观测形状: {obs_array.shape}") + print(f" - 数据类型: {obs_array.dtype}") + print(f"\n观测内容分解:") + # 新观测划分(与 Env/scenario_env._get_all_obs 对齐): + # [x, y] (2) + [vx, vy] (2) + heading (1) + # + nearest vehicles info (10 vehicles * 4 = 40) + # + side_lidar (10) + lane_line_lidar (10) + # + traffic_light (1) + destination (2) + lidar_len = 40 + side_len = 10 + lane_len = 10 + pos_slice = slice(0, 2) + vel_slice = slice(2, 4) + heading_idx = 4 + lidar_slice = slice(5, 5 + lidar_len) + side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len) + lane_slice = slice(side_slice.stop, side_slice.stop + lane_len) + tl_idx = lane_slice.stop + dest_slice = slice(tl_idx + 1, tl_idx + 3) + + print(f" 位置 (x, y): {obs_array[pos_slice]}") + print(f" 速度 (vx, vy): {obs_array[vel_slice]}") + print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度") + print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)") + print(f" 侧向检测: {len(obs_array[side_slice])} 个点") + print(f" 车道线检测: {len(obs_array[lane_slice])} 个点") + print(f" 交通灯状态: {obs_array[tl_idx]}") + print(f" 目的地 (x, y): {obs_array[dest_slice]}") + + # 检查数据有效性 + print(f"\n数据有效性检查:") + has_nan = np.isnan(obs_array).any() + has_inf = np.isinf(obs_array).any() + + if has_nan: + print(f" ❌ 观测包含 NaN 值!") + else: + print(f" ✅ 无 NaN 值") + + if has_inf: + print(f" ❌ 观测包含 Inf 值!") + else: + print(f" ✅ 无 Inf 值") + + # 检查最近车辆信息数据(40维) + lidar_data = obs_array[lidar_slice] + lidar_min = np.min(lidar_data) + lidar_max = np.max(lidar_data) + print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]") + + if lidar_max > 0: + print(f" ✅ 存在非零最近车辆特征") + else: + print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)") + + # 运行几步,验证观测持续有效 + print(f"\n" + "=" * 60) + print("多步运行验证(前 5 步)") + print("=" * 60) + + for step in range(5): + # 空动作 + actions = {aid: [0.0, 0.0] for aid in env.controlled_agents} + obs, rewards, dones, infos = env.step(actions) + + print(f"\nStep {step + 1}:") + print(f" - 活跃车辆数: {len(env.controlled_agents)}") + print(f" - 观测数量: {len(obs)}") + + if len(obs) > 0: + sample_obs = np.array(obs[0]) + print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})") + print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}") + + if dones["__all__"]: + print(f" - 场景结束") + break + + env.close() + + print(f"\n" + "=" * 60) + print("✅ 验证完成!观测空间正常工作") + print("=" * 60) + + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="验证观测空间") + parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录") + parser.add_argument("--scenario_id", type=int, default=0, help="场景ID") + + args = parser.parse_args() + + verify_observations(args.data_dir, args.scenario_id) diff --git a/train_magail.py b/train_magail.py new file mode 100644 index 0000000..0fa9c58 --- /dev/null +++ b/train_magail.py @@ -0,0 +1,115 @@ +import torch +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from Env.scenario_env import MultiAgentScenarioEnv 
+from Algorithm.magail import MAGAIL
+from Algorithm.buffer import RolloutBuffer  # assuming the Buffer lives here
+# Assuming you have a utility for loading expert data
+# from utils import load_expert_buffer
+
+# --- Configuration ---
+CONFIG = {
+    "data_dir": "/home/huangfukk/mdsn",
+    # ... other settings ...
+    "state_dim": 30,         # your observation dimension (NOTE: the layout in verify_observation.py gives 68 dims)
+    "action_dim": 2,         # [steering, throttle]
+    "rollout_length": 2048,  # buffer size
+}
+
+def train():
+    # 1. Environment
+    env = MultiAgentScenarioEnv(config={...}, agent2policy=None)  # config elided; see Env/run_multiagent_env.py
+
+    # 2. Expert buffer (pseudocode)
+    # expert_buffer = RolloutBuffer(...)
+    # expert_buffer.load(...)
+    # expert_buffer.sample() must return (state, next_state) pairs for discriminator training
+
+    # 3. Initialize MAGAIL
+    magail = MAGAIL(
+        buffer_exp=expert_buffer,              # pass in the expert buffer
+        input_dim=(CONFIG["state_dim"],),
+        action_shape=(CONFIG["action_dim"],),  # PPO's init may need action_shape; the original code apparently did not pass it?
+        device=torch.device("cuda"),
+        rollout_length=CONFIG["rollout_length"]
+    )
+
+    writer = SummaryWriter("./logs")
+
+    # 4. Training loop
+    total_steps = 0
+    obs_dict = env.reset()
+
+    # Convert the dict obs to an array: (N_agents, obs_dim),
+    # assuming every agent's obs has the same dimension
+    agents = list(obs_dict.keys())
+    current_obs = np.stack([obs_dict[a] for a in agents])
+
+    # If a separate state_gail is needed, assume it equals the obs
+    current_state_gail = current_obs.copy()
+
+    while total_steps < 1e7:
+        # --- Collect data (rollout) ---
+        # The PPO buffer holds rollout_length steps.
+        # magail.step could handle explore + buffer append automatically, but note:
+        # magail.step calls env.step internally, which may not fit a multi-agent env
+        # that returns dicts; the original PPO.step seems designed for single-agent
+        # or vectorized envs. So we write the rollout loop by hand here
+        # (or adapt PPO.step to dict returns).
+
+        # --- Approach: manual rollout adapted to multi-agent ---
+        for _ in range(CONFIG["rollout_length"]):
+            # 1. Act
+            actions_list, log_pis_list = magail.explore(current_obs)
+
+            # 2. Assemble the action dict
+            action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
+
+            # 3. Step the environment
+            next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
+
+            # 4. Unpack the returns
+            # Agent death/reset needs handling; we simplify by assuming a fixed agent set
+            next_obs = np.stack([next_obs_dict[a] for a in agents])
+            rewards = np.array([rewards_dict[a] for a in agents])
+            dones = np.array([dones_dict[a] for a in agents])
+            # truncation usually lives in infos or is implied by dones; confirm against your gym version
+            terminated = dones                 # simplified
+            truncated = [False] * len(agents)  # simplified
+
+            # 5. Store in the buffer
+            # Fetch the current actor's mean/std for the PPO update
+            with torch.no_grad():
+                # Either recompute here or have explore() return them.
+                # magail.actor(state) returns the mean (deterministic action)
+                means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
+                stds = magail.actor.log_stds.exp().detach().cpu().numpy()
+                # With a StateIndependentPolicy the std may be a shared parameter that needs broadcasting
+                if stds.shape[0] != len(agents):
+                    stds = np.repeat(stds, len(agents), axis=0)
+
+            magail.buffer.append(
+                current_obs, current_state_gail, actions_list,
+                rewards, dones, terminated, log_pis_list,
+                next_obs, next_obs,  # next_state_gail = next_obs
+                means, stds
+            )
+
+            current_obs = next_obs
+            current_state_gail = next_obs  # update the GAIL state
+            total_steps += len(agents)     # N agent-steps per env step
+
+            # Handle episode end
+            if all(dones):
+                obs_dict = env.reset()
+                agents = list(obs_dict.keys())  # refresh: the new scenario may contain different agents
+                current_obs = np.stack([obs_dict[a] for a in agents])
+                current_state_gail = current_obs
+
+        # --- Update ---
+        # Once the buffer holds rollout_length steps, run the update
+        magail.update(writer, total_steps)
+
+        # Save checkpoints
+        if total_steps % 10000 == 0:
+            magail.save_models("./models")
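train_magail.py leaves the expert buffer as pseudocode but states its contract: sample() must return (state, next_state) pairs for the discriminator. A minimal sketch that satisfies the contract, assuming expert trajectories are available as per-episode arrays of consecutive states (the class name and storage layout are assumptions, not the repo's actual loader):

    import numpy as np
    import torch

    class ExpertTransitionBuffer:
        """Minimal (state, next_state) sampler for discriminator training."""

        def __init__(self, episodes, device):
            # episodes: list of (T_i, obs_dim) arrays of consecutive expert states
            self.states = np.concatenate([ep[:-1] for ep in episodes])
            self.next_states = np.concatenate([ep[1:] for ep in episodes])
            self.device = device

        def sample(self, batch_size):
            """Return a random batch of (state, next_state) tensors."""
            idx = np.random.randint(0, len(self.states), size=batch_size)
            to_t = lambda x: torch.as_tensor(x, dtype=torch.float, device=self.device)
            return to_t(self.states[idx]), to_t(self.next_states[idx])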