代替点云
This commit is contained in:
115
train_magail.py
Normal file
115
train_magail.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from Env.scenario_env import MultiAgentScenarioEnv
|
||||
from Algorithm.magail import MAGAIL
|
||||
from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里
|
||||
# 假设你有加载专家数据的工具
|
||||
# from utils import load_expert_buffer
|
||||
|
||||
# --- 配置 ---
|
||||
CONFIG = {
|
||||
"data_dir": "/home/huangfukk/mdsn",
|
||||
# ... 其他配置 ...
|
||||
"state_dim": 30, # 你的 Observation 维度
|
||||
"action_dim": 2, # [steering, throttle]
|
||||
"rollout_length": 2048, # Buffer 大小
|
||||
}
|
||||
|
||||
def train():
|
||||
# 1. 环境
|
||||
env = MultiAgentScenarioEnv(config={...}, agent2policy=None)
|
||||
|
||||
# 2. 专家 Buffer (伪代码)
|
||||
# expert_buffer = RolloutBuffer(...)
|
||||
# expert_buffer.load(...)
|
||||
# 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练
|
||||
|
||||
# 3. 初始化 MAGAIL
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer, # 传入专家 buffer
|
||||
input_dim=(CONFIG["state_dim"],),
|
||||
action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape,原代码好像没传?
|
||||
device=torch.device("cuda"),
|
||||
rollout_length=CONFIG["rollout_length"]
|
||||
)
|
||||
|
||||
writer = SummaryWriter("./logs")
|
||||
|
||||
# 4. 训练循环
|
||||
total_steps = 0
|
||||
obs_dict = env.reset()
|
||||
|
||||
# 将 Dict Obs 转为 Array: (N_agents, Obs_dim)
|
||||
# 假设所有 Agent 的 Obs 维度相同
|
||||
agents = list(obs_dict.keys())
|
||||
current_obs = np.stack([obs_dict[a] for a in agents])
|
||||
|
||||
# 如果需要 state_gail,假设它就是 obs
|
||||
current_state_gail = current_obs.copy()
|
||||
|
||||
while total_steps < 1e7:
|
||||
# --- 收集数据 (Rollout) ---
|
||||
# PPO 的 buffer 长度是 rollout_length
|
||||
# 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append
|
||||
|
||||
# 注意:magail.step 内部调用了 env.step,这可能不适用于 Multi-Agent 环境返回 Dict 的情况
|
||||
# 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的
|
||||
# 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值
|
||||
|
||||
# --- 方案:手动 Rollout 适配 Multi-Agent ---
|
||||
for _ in range(CONFIG["rollout_length"]):
|
||||
# 1. 决策
|
||||
actions_list, log_pis_list = magail.explore(current_obs)
|
||||
|
||||
# 2. 拼装 Action Dict
|
||||
action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
|
||||
|
||||
# 3. 环境步进
|
||||
next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
|
||||
|
||||
# 4. 处理返回值
|
||||
# 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变
|
||||
next_obs = np.stack([next_obs_dict[a] for a in agents])
|
||||
rewards = np.array([rewards_dict[a] for a in agents])
|
||||
dones = np.array([dones_dict[a] for a in agents])
|
||||
# truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认
|
||||
terminated = dones # 简化
|
||||
truncated = [False] * len(agents) # 简化
|
||||
|
||||
# 5. 存入 Buffer
|
||||
# 获取当前 Actor 的均值方差用于 PPO 更新
|
||||
with torch.no_grad():
|
||||
# 重新计算一遍或者在 explore 里返回
|
||||
# 这里 magail.actor(state) 返回的是 mean (deterministic action)
|
||||
means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
|
||||
stds = magail.actor.log_stds.exp().detach().cpu().numpy()
|
||||
# 如果是 StateIndependentPolicy,std 可能是共享的参数,维度需要广播
|
||||
if stds.shape[0] != len(agents):
|
||||
stds = np.repeat(stds, len(agents), axis=0)
|
||||
|
||||
magail.buffer.append(
|
||||
current_obs, current_state_gail, actions_list,
|
||||
rewards, dones, terminated, log_pis_list,
|
||||
next_obs, next_obs, # next_state_gail = next_obs
|
||||
means, stds
|
||||
)
|
||||
|
||||
current_obs = next_obs
|
||||
current_state_gail = next_obs # 更新 gail state
|
||||
total_steps += len(agents) # 步数增加 N
|
||||
|
||||
# 处理 Done
|
||||
if all(dones):
|
||||
obs_dict = env.reset()
|
||||
current_obs = np.stack([obs_dict[a] for a in agents])
|
||||
current_state_gail = current_obs
|
||||
|
||||
# --- 更新 ---
|
||||
# 当 Buffer 满时 (rollout_length),调用 update
|
||||
magail.update(writer, total_steps)
|
||||
|
||||
# 保存
|
||||
if total_steps % 10000 == 0:
|
||||
magail.save_models("./models")
|
||||
|
||||
Reference in New Issue
Block a user