代替点云

This commit is contained in:
2025-12-07 20:15:35 +08:00
parent 113e86bda2
commit 8cb86115c2
10 changed files with 527 additions and 45 deletions

115
train_magail.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from Env.scenario_env import MultiAgentScenarioEnv
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer # 假设 Buffer 在这里
# 假设你有加载专家数据的工具
# from utils import load_expert_buffer
# --- 配置 ---
CONFIG = {
"data_dir": "/home/huangfukk/mdsn",
# ... 其他配置 ...
"state_dim": 30, # 你的 Observation 维度
"action_dim": 2, # [steering, throttle]
"rollout_length": 2048, # Buffer 大小
}
def train():
# 1. 环境
env = MultiAgentScenarioEnv(config={...}, agent2policy=None)
# 2. 专家 Buffer (伪代码)
# expert_buffer = RolloutBuffer(...)
# expert_buffer.load(...)
# 必须保证 expert_buffer.sample() 返回 (state, next_state) 供 Discriminator 训练
# 3. 初始化 MAGAIL
magail = MAGAIL(
buffer_exp=expert_buffer, # 传入专家 buffer
input_dim=(CONFIG["state_dim"],),
action_shape=(CONFIG["action_dim"],), # PPO 初始化可能需要 action_shape原代码好像没传
device=torch.device("cuda"),
rollout_length=CONFIG["rollout_length"]
)
writer = SummaryWriter("./logs")
# 4. 训练循环
total_steps = 0
obs_dict = env.reset()
# 将 Dict Obs 转为 Array: (N_agents, Obs_dim)
# 假设所有 Agent 的 Obs 维度相同
agents = list(obs_dict.keys())
current_obs = np.stack([obs_dict[a] for a in agents])
# 如果需要 state_gail假设它就是 obs
current_state_gail = current_obs.copy()
while total_steps < 1e7:
# --- 收集数据 (Rollout) ---
# PPO 的 buffer 长度是 rollout_length
# 我们可以在这里调用 magail.step 来自动处理 explore 和 buffer append
# 注意magail.step 内部调用了 env.step这可能不适用于 Multi-Agent 环境返回 Dict 的情况
# 原 PPO.step 似乎是为 Single-Agent 或 VectorEnv 设计的
# 这里我们需要手动写 Rollout 循环或者修改 PPO.step 以适配 Dict 返回值
# --- 方案:手动 Rollout 适配 Multi-Agent ---
for _ in range(CONFIG["rollout_length"]):
# 1. 决策
actions_list, log_pis_list = magail.explore(current_obs)
# 2. 拼装 Action Dict
action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
# 3. 环境步进
next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
# 4. 处理返回值
# 需要处理 Agent 死亡/重置的情况。这里简化假设 Agent 数量不变
next_obs = np.stack([next_obs_dict[a] for a in agents])
rewards = np.array([rewards_dict[a] for a in agents])
dones = np.array([dones_dict[a] for a in agents])
# truncated 通常在 infos 里或者 dones 里隐含,需根据 gym 版本确认
terminated = dones # 简化
truncated = [False] * len(agents) # 简化
# 5. 存入 Buffer
# 获取当前 Actor 的均值方差用于 PPO 更新
with torch.no_grad():
# 重新计算一遍或者在 explore 里返回
# 这里 magail.actor(state) 返回的是 mean (deterministic action)
means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
stds = magail.actor.log_stds.exp().detach().cpu().numpy()
# 如果是 StateIndependentPolicystd 可能是共享的参数,维度需要广播
if stds.shape[0] != len(agents):
stds = np.repeat(stds, len(agents), axis=0)
magail.buffer.append(
current_obs, current_state_gail, actions_list,
rewards, dones, terminated, log_pis_list,
next_obs, next_obs, # next_state_gail = next_obs
means, stds
)
current_obs = next_obs
current_state_gail = next_obs # 更新 gail state
total_steps += len(agents) # 步数增加 N
# 处理 Done
if all(dones):
obs_dict = env.reset()
current_obs = np.stack([obs_dict[a] for a in agents])
current_state_gail = current_obs
# --- 更新 ---
# 当 Buffer 满时 (rollout_length),调用 update
magail.update(writer, total_steps)
# 保存
if total_steps % 10000 == 0:
magail.save_models("./models")