import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from Env.scenario_env import MultiAgentScenarioEnv
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer  # assumes the rollout buffer lives here
# Assumes a utility for loading expert data exists, e.g.:
# from utils import load_expert_buffer
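
# The discriminator needs expert (state, next_state) pairs to train against. The class
# below is a minimal, hedged sketch of such a buffer, assuming expert trajectories are
# stored as "states" / "next_states" arrays in an .npz file; the real project may load
# them through RolloutBuffer or a load_expert_buffer utility instead.
class ExpertTransitionBuffer:
    """Holds expert (state, next_state) pairs and samples random mini-batches."""

    def __init__(self, states, next_states, device):
        assert len(states) == len(next_states)
        self.states = torch.as_tensor(states, dtype=torch.float32, device=device)
        self.next_states = torch.as_tensor(next_states, dtype=torch.float32, device=device)

    @classmethod
    def from_npz(cls, path, device):
        data = np.load(path)  # expects arrays named "states" and "next_states" (assumption)
        return cls(data["states"], data["next_states"], device)

    def sample(self, batch_size):
        idx = torch.randint(0, len(self.states), (batch_size,), device=self.states.device)
        return self.states[idx], self.next_states[idx]
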
# --- Configuration ---
CONFIG = {
    "data_dir": "/home/huangfukk/mdsn",
    # ... other settings ...
    "state_dim": 30,         # your observation dimension
    "action_dim": 2,         # [steering, throttle]
    "rollout_length": 2048,  # on-policy buffer size
}

def train():
    # 1. Environment
    env = MultiAgentScenarioEnv(config={...}, agent2policy=None)  # fill in your scenario config

    # 2. Expert buffer (pseudocode)
    # expert_buffer = RolloutBuffer(...)
    # expert_buffer.load(...)
    # expert_buffer.sample() must return (state, next_state) pairs for discriminator training.
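
    # Hedged example using the ExpertTransitionBuffer sketch above; the file name
    # "expert_transitions.npz" under data_dir is an assumption.
    expert_buffer = ExpertTransitionBuffer.from_npz(
        f"{CONFIG['data_dir']}/expert_transitions.npz", device=torch.device("cuda")
    )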

    # 3. Initialize MAGAIL
    magail = MAGAIL(
        buffer_exp=expert_buffer,              # expert buffer for the discriminator
        input_dim=(CONFIG["state_dim"],),
        action_shape=(CONFIG["action_dim"],),  # PPO init may need action_shape; the original code did not seem to pass it?
        device=torch.device("cuda"),
        rollout_length=CONFIG["rollout_length"]
    )

    writer = SummaryWriter("./logs")

    # 4. Training loop
    total_steps = 0
    last_save_step = 0  # used below to trigger periodic checkpointing
    obs_dict = env.reset()

    # Convert the observation dict into an array of shape (n_agents, obs_dim),
    # assuming every agent has the same observation dimension.
    agents = list(obs_dict.keys())
    current_obs = np.stack([obs_dict[a] for a in agents])

    # If a separate state_gail is needed, assume it is simply the observation.
    current_state_gail = current_obs.copy()

    while total_steps < 1e7:
        # --- Collect data (rollout) ---
        # The PPO buffer holds rollout_length transitions.
        # We could call magail.step here to handle explore() and buffer appends automatically.

        # Note: magail.step calls env.step internally, which may not work for a multi-agent
        # environment that returns dicts; the original PPO.step appears to be written for a
        # single-agent or vectorized env. So we either write the rollout loop by hand or
        # modify PPO.step to accept dict returns.

        # --- Approach: manual rollout adapted to multi-agent ---
        for _ in range(CONFIG["rollout_length"]):
            # 1. Select actions
            actions_list, log_pis_list = magail.explore(current_obs)

            # 2. Assemble the action dict
            action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}

            # 3. Step the environment
            next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)

            # 4. Process the returned values
            # Agent death/respawn would need handling here; for simplicity we assume the
            # set of agents stays fixed.
            next_obs = np.stack([next_obs_dict[a] for a in agents])
            rewards = np.array([rewards_dict[a] for a in agents])
            dones = np.array([dones_dict[a] for a in agents])
            # "truncated" is usually carried in infos or implied by dones; check your gym version.
            terminated = dones                 # simplification
            truncated = [False] * len(agents)  # simplification
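            # A hedged alternative, assuming the env follows the Gym TimeLimit convention
            # of flagging time-limit truncation in the per-agent info dict:
            # truncated = [infos_dict[a].get("TimeLimit.truncated", False) for a in agents]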

            # 5. Store in the buffer
            # Recompute the current actor's mean and std for the PPO update
            with torch.no_grad():
                # Either recompute here or have explore() return them;
                # magail.actor(state) returns the mean (deterministic action).
                means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
                stds = magail.actor.log_stds.exp().detach().cpu().numpy()
                # With a StateIndependentPolicy the std may be a shared parameter,
                # so broadcast it to one row per agent.
                if stds.shape[0] != len(agents):
                    stds = np.repeat(stds, len(agents), axis=0)

            magail.buffer.append(
                current_obs, current_state_gail, actions_list,
                rewards, dones, terminated, log_pis_list,
                next_obs, next_obs,  # next_state_gail = next_obs
                means, stds
            )

            current_obs = next_obs
            current_state_gail = next_obs  # update the GAIL state
            total_steps += len(agents)     # N agents advance per env step

            # Handle episode termination
            if all(dones):
                obs_dict = env.reset()
                current_obs = np.stack([obs_dict[a] for a in agents])
                current_state_gail = current_obs
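            # If only some agents terminate (assumption: the env drops finished agents from
            # its return dicts), the agent list would need rebuilding each step, e.g.
            # agents = [a for a in agents if a in next_obs_dict], with the buffer layout
            # adjusted accordingly; that case is not handled in this simplified loop.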

        # --- Update ---
        # Once the buffer is full (rollout_length transitions), run the update.
        magail.update(writer, total_steps)

        # Save checkpoints roughly every 10k steps. total_steps advances by
        # len(agents) * rollout_length per iteration, so an exact modulus check would
        # almost never fire; compare against the last checkpoint instead.
        if total_steps - last_save_step >= 10000:
            last_save_step = total_steps
            magail.save_models("./models")
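
# Entry point (assumed): run training when executed as a script.
if __name__ == "__main__":
    train()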