代替点云

This commit is contained in:
2025-12-07 20:15:35 +08:00
parent 113e86bda2
commit 8cb86115c2
10 changed files with 527 additions and 45 deletions

View File

@@ -4,7 +4,7 @@ from simple_idm_policy import ConstantVelocityPolicy
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
@@ -40,10 +40,10 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
# ✅ 环境创建移到循环外面,避免重复创建
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"horizon": horizon,
"use_render": render,
"use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True,
"reactive_traffic": False, # 回放模式下不需要反应式交通
"manual_control": False,
@@ -57,20 +57,33 @@ def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=Fa
"spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
},
agent2policy=None # 回放模式不需要统一策略
)
try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes):
print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None:
print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}")
# ✅ 如果不是指定场景,使用seed来遍历不同场景
seed = scenario_id if scenario_id is not None else episode
# ✅ 如果不是指定场景,使用循环的场景索引
if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed)
# 为每个车辆分配 ReplayPolicy
@@ -180,11 +193,11 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
env = MultiAgentScenarioEnv(
config={
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
"is_multi_agent": True,
"num_controlled_agents": 3,
"horizon": horizon,
"use_render": render,
"use_render": render, # 如果False会完全禁用渲染避免LANE_FREEWAY错误
"sequential_seed": True,
"reactive_traffic": True,
"manual_control": False,
@@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu
"spawn_vehicles": spawn_vehicles,
"spawn_pedestrians": spawn_pedestrians,
"spawn_cyclists": spawn_cyclists,
# ✅ 关键:设置可用场景数量
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
},
agent2policy=ConstantVelocityPolicy(target_speed=50)
)
try:
# 获取可用场景数量
num_scenarios = env.config.get("num_scenarios", 1)
print(f"可用场景数量: {num_scenarios}")
for episode in range(num_episodes):
print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None:
print(f"场景ID: {scenario_id}")
else:
# 循环使用场景
scenario_idx = episode % num_scenarios
print(f"使用场景索引: {scenario_idx}")
print(f"{'='*50}")
seed = scenario_id if scenario_id is not None else episode
# ✅ 如果不是指定场景,使用循环的场景索引
if scenario_id is not None:
seed = scenario_id
else:
seed = episode % num_scenarios
obs = env.reset(seed=seed)
actual_horizon = env.config["horizon"]