# MAGAIL4AutoDrive/train_magail.py
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from Env.scenario_env import MultiAgentScenarioEnv
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer  # assuming the rollout buffer lives here
# Assuming a utility for loading expert data exists, e.g.:
# from utils import load_expert_buffer
# --- Configuration ---
CONFIG = {
    "data_dir": "/home/huangfukk/mdsn",
    # ... other options ...
    "state_dim": 30,         # observation dimension per agent
    "action_dim": 2,         # [steering, throttle]
    "rollout_length": 2048,  # PPO buffer size
}
def train():
    # 1. Environment
    env = MultiAgentScenarioEnv(config={...}, agent2policy=None)  # placeholder: fill in the real env config

    # 2. Expert buffer (pseudocode)
    # expert_buffer = RolloutBuffer(...)
    # expert_buffer.load(...)
    # expert_buffer.sample() must return (state, next_state) pairs for discriminator training.
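    # --- Hedged sketch (assumption): a minimal stand-in for the expert buffer above ---
    # Neither the on-disk format of the expert data nor RolloutBuffer's loading API is
    # shown in this script, so this class, the "expert_transitions.npz" file name and the
    # "states" / "next_states" keys are placeholders, not the project's actual interface.
    class SimpleExpertBuffer:
        def __init__(self, states, next_states, device):
            self.states = torch.as_tensor(states, dtype=torch.float, device=device)
            self.next_states = torch.as_tensor(next_states, dtype=torch.float, device=device)

        def sample(self, batch_size):
            # Uniformly sample (state, next_state) pairs, as the discriminator expects.
            idx = np.random.randint(0, self.states.shape[0], size=batch_size)
            return self.states[idx], self.next_states[idx]

    data = np.load(f"{CONFIG['data_dir']}/expert_transitions.npz")  # hypothetical file name
    expert_buffer = SimpleExpertBuffer(data["states"], data["next_states"], torch.device("cuda"))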
    # 3. Initialize MAGAIL
    magail = MAGAIL(
        buffer_exp=expert_buffer,              # expert buffer for the discriminator
        input_dim=(CONFIG["state_dim"],),
        action_shape=(CONFIG["action_dim"],),  # PPO init may need action_shape; the original code seems not to pass it
        device=torch.device("cuda"),
        rollout_length=CONFIG["rollout_length"]
    )
    writer = SummaryWriter("./logs")
    # 4. Training loop
    total_steps = 0
    next_save_step = 10_000  # checkpoint threshold; see the save logic at the end of the loop
    obs_dict = env.reset()

    # Convert the Dict observation to an array of shape (n_agents, obs_dim),
    # assuming every agent has the same observation dimension.
    agents = list(obs_dict.keys())
    current_obs = np.stack([obs_dict[a] for a in agents])
    # If a separate state_gail is needed, assume it equals the observation.
    current_state_gail = current_obs.copy()
    while total_steps < 1e7:
        # --- Collect data (rollout) ---
        # The PPO buffer holds rollout_length transitions. We could call magail.step here to
        # handle explore() and the buffer appends automatically, but magail.step calls env.step
        # internally, which may not work when the multi-agent env returns Dicts; the original
        # PPO.step appears to be designed for a single-agent or vectorized env. So we either
        # write the rollout loop by hand or adapt PPO.step to Dict return values.
        # --- Approach: manual rollout adapted to the multi-agent env ---
        for _ in range(CONFIG["rollout_length"]):
            # 1. Act
            actions_list, log_pis_list = magail.explore(current_obs)

            # 2. Assemble the action Dict
            action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}

            # 3. Step the environment
            next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)

            # 4. Process the returns
            # Agent death/reset should be handled here; simplified by assuming the agent set is fixed.
            next_obs = np.stack([next_obs_dict[a] for a in agents])
            rewards = np.array([rewards_dict[a] for a in agents])
            dones = np.array([dones_dict[a] for a in agents])
            # Truncation is usually carried in infos or implied by dones; confirm against your gym version.
            terminated = dones                 # simplification
            truncated = [False] * len(agents)  # simplification
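            # Hedged alternative (assumption): if this env follows the Gym TimeLimit
            # convention, each per-agent info dict may carry a "TimeLimit.truncated"
            # flag; whether MultiAgentScenarioEnv sets it is not confirmed here.
            # truncated = np.array(
            #     [infos_dict[a].get("TimeLimit.truncated", False) for a in agents]
            # )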
            # 5. Store in the buffer
            # Fetch the actor's current mean/std, needed for the PPO update
            with torch.no_grad():
                # Recompute here (or have explore() return them);
                # magail.actor(state) returns the mean, i.e. the deterministic action.
                means = magail.actor(
                    torch.tensor(current_obs, device=magail.device, dtype=torch.float)
                ).cpu().numpy()
                stds = magail.actor.log_stds.exp().detach().cpu().numpy()
                # With a StateIndependentPolicy the std may be a shared parameter,
                # so broadcast it to one row per agent.
                if stds.shape[0] != len(agents):
                    stds = np.repeat(stds, len(agents), axis=0)

            magail.buffer.append(
                current_obs, current_state_gail, actions_list,
                rewards, dones, terminated, log_pis_list,
                next_obs, next_obs,  # next_state_gail = next_obs
                means, stds
            )
            current_obs = next_obs
            current_state_gail = next_obs  # update the GAIL state
            total_steps += len(agents)     # each env step advances the counter by n_agents
            # Handle episode end
            if all(dones):
                obs_dict = env.reset()
                agents = list(obs_dict.keys())  # the agent set may change after a reset
                current_obs = np.stack([obs_dict[a] for a in agents])
                current_state_gail = current_obs
        # --- Update ---
        # The buffer is now full (rollout_length transitions); run the MAGAIL update.
        magail.update(writer, total_steps)
        # Save checkpoints. total_steps advances by len(agents) per env step, so an exact
        # "total_steps % 10000 == 0" check would almost never fire; use a threshold instead.
        if total_steps >= next_save_step:
            magail.save_models("./models")
            next_save_step += 10_000
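
# Assumed entry point (standard convention) so the script can be launched directly.
if __name__ == "__main__":
    train()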