import torch import numpy as np from torch.utils.tensorboard import SummaryWriter from Env.scenario_env import MultiAgentScenarioEnv from Algorithm.magail import MAGAIL from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里 # 假设你有加载专家数据的工具 # from utils import load_expert_buffer # --- 配置 --- CONFIG = { "data_dir": "/home/huangfukk/mdsn", # ... 其他配置 ... "state_dim": 30, # 你的 Observation 维度 "action_dim": 2, # [steering, throttle] "rollout_length": 2048, # Buffer 大小 } def train(): # 1. 环境 env = MultiAgentScenarioEnv(config={...}, agent2policy=None) # 2. 专家 Buffer (伪代码) # expert_buffer = RolloutBuffer(...) # expert_buffer.load(...) # 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练 # 3. 初始化 MAGAIL magail = MAGAIL( buffer_exp=expert_buffer, # 传入专家 buffer input_dim=(CONFIG["state_dim"],), action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape,原代码好像没传? device=torch.device("cuda"), rollout_length=CONFIG["rollout_length"] ) writer = SummaryWriter("./logs") # 4. 训练循环 total_steps = 0 obs_dict = env.reset() # 将 Dict Obs 转为 Array: (N_agents, Obs_dim) # 假设所有 Agent 的 Obs 维度相同 agents = list(obs_dict.keys()) current_obs = np.stack([obs_dict[a] for a in agents]) # 如果需要 state_gail,假设它就是 obs current_state_gail = current_obs.copy() while total_steps < 1e7: # --- 收集数据 (Rollout) --- # PPO 的 buffer 长度是 rollout_length # 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append # 注意:magail.step 内部调用了 env.step,这可能不适用于 Multi-Agent 环境返回 Dict 的情况 # 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的 # 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值 # --- 方案:手动 Rollout 适配 Multi-Agent --- for _ in range(CONFIG["rollout_length"]): # 1. 决策 actions_list, log_pis_list = magail.explore(current_obs) # 2. 拼装 Action Dict action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)} # 3. 环境步进 next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict) # 4. 处理返回值 # 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变 next_obs = np.stack([next_obs_dict[a] for a in agents]) rewards = np.array([rewards_dict[a] for a in agents]) dones = np.array([dones_dict[a] for a in agents]) # truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认 terminated = dones # 简化 truncated = [False] * len(agents) # 简化 # 5. 存入 Buffer # 获取当前 Actor 的均值方差用于 PPO 更新 with torch.no_grad(): # 重新计算一遍或者在 explore 里返回 # 这里 magail.actor(state) 返回的是 mean (deterministic action) means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy() stds = magail.actor.log_stds.exp().detach().cpu().numpy() # 如果是 StateIndependentPolicy,std 可能是共享的参数,维度需要广播 if stds.shape[0] != len(agents): stds = np.repeat(stds, len(agents), axis=0) magail.buffer.append( current_obs, current_state_gail, actions_list, rewards, dones, terminated, log_pis_list, next_obs, next_obs, # next_state_gail = next_obs means, stds ) current_obs = next_obs current_state_gail = next_obs # 更新 gail state total_steps += len(agents) # 步数增加 N # 处理 Done if all(dones): obs_dict = env.reset() current_obs = np.stack([obs_dict[a] for a in agents]) current_state_gail = current_obs # --- 更新 --- # 当 Buffer 满时 (rollout_length),调用 update magail.update(writer, total_steps) # 保存 if total_steps % 10000 == 0: magail.save_models("./models")