完成回放模式与仿真模式;过滤不在车道上生成的车辆;新增对行人、自行车的过滤功能
This commit is contained in:
@@ -1,40 +1,362 @@
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from Env.simple_idm_policy import ConstantVelocityPolicy
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
# Default root directory of the scenario data. It is used as the default
# value of the --data_dir command-line option.
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
||||
|
||||
def main():
|
||||
|
||||
def _attach_replay_policies(env):
    """Install a ReplayPolicy on every controlled agent that has an expert trajectory.

    Args:
        env: Environment exposing ``controlled_agents`` and ``expert_trajectories``.

    Returns:
        Dict mapping agent id to the ReplayPolicy that was installed on it.
        Agents whose ``expert_vehicle_id`` has no trajectory are left out.
    """
    replay_policies = {}
    for agent_id, vehicle in env.controlled_agents.items():
        vehicle_id = vehicle.expert_vehicle_id
        if vehicle_id in env.expert_trajectories:
            policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
            vehicle.set_policy(policy)
            replay_policies[agent_id] = policy
    return replay_policies


def _apply_expert_states(env, replay_policies):
    """Overwrite each replayed vehicle's pose with its expert state for ``env.round``."""
    # Iterate over a snapshot of the agent dict, as the original code did
    # (presumably because agents can be removed between steps — confirm).
    for agent_id, vehicle in list(env.controlled_agents.items()):
        if vehicle.expert_vehicle_id not in env.expert_trajectories:
            continue
        policy = replay_policies.get(agent_id)
        if policy is None:
            continue
        target_state = policy.get_target_state(env.round)
        if target_state is None:
            continue
        # Set the state directly instead of driving the vehicle with actions.
        # Only the x/y components are used, so the z value of the expert data
        # is never applied to the vehicle.
        vehicle.set_position(target_state['position'][:2])
        vehicle.set_heading_theta(target_state['heading'])
        # Slicing to two components is a no-op for 2-D velocities.
        vehicle.set_velocity(target_state['velocity'][:2])


def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Replay mode: move every vehicle strictly along its expert trajectory.

    Args:
        data_dir: Data directory; the scenarios are read from its
            ``training_20s`` sub-directory.
        num_episodes: Number of episodes (forced to 1 when ``scenario_id`` is given).
        horizon: Maximum number of steps, forwarded to the environment config.
            According to the original notes it is set automatically when
            ``use_scenario_duration`` is true; that logic lives in the
            environment and is not visible here.
        render: Whether to render (top-down view after every step).
        debug: Whether to enable debug output.
        scenario_id: Optional scenario id; also used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration
            (forwarded to the environment config).
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    print("=" * 50)
    print("运行模式: 专家轨迹回放 (Replay Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is only replayed once.
    if scenario_id is not None:
        num_episodes = 1

    # The environment is created once, outside the episode loop.
    env = MultiAgentScenarioEnv(
        config={
            # The stale duplicate "data_directory" entry that pointed at the
            # module-level default instead of ``data_dir`` has been removed;
            # it was overridden by this entry anyway.
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": False,  # reactive traffic is not needed when replaying
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane filtering
            "lane_tolerance": 3.0,
            "replay_mode": True,  # mark the run as replay mode
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filters
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=None,  # replay mode does not use one shared policy
    )

    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without an explicit scenario, the episode index is the seed so
            # that successive episodes walk through different scenarios.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            replay_policies = _attach_replay_policies(env)

            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {env.config['horizon']} 步")

            step_count = 0
            active_vehicles_count = []

            while True:
                _apply_expert_states(env, replay_policies)

                # Step with no-op actions; the poses were already set above.
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                _, _, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Progress report every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                    # In debug mode also report the height of one sample vehicle.
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = next(iter(env.controlled_agents.values()))
                        position = sample_vehicle.position
                        z_pos = position[2] if len(position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if dones["__all__"]:
                    if active_vehicles_count:
                        peak = max(active_vehicles_count)
                        mean = sum(active_vehicles_count) / len(active_vehicles_count)
                    else:
                        peak = 0
                        mean = 0
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 最大同时车辆数: {peak}")
                    print(f" 平均车辆数: {mean:.1f}")
                    if use_scenario_duration:
                        print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
                    break
    finally:
        # Always close the environment, even if an episode raised.
        env.close()

    print("\n" + "=" * 50)
    print("回放完成!")
    print("=" * 50)
|
||||
|
||||
|
||||
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Simulation mode: vehicles are driven by a policy instead of replayed.

    According to the original notes, vehicles are spawned at the initial poses
    from the expert data and are then controlled by the policy. Every agent
    uses ``ConstantVelocityPolicy(target_speed=50)``.

    Args:
        data_dir: Data directory; the scenarios are read from its
            ``training_20s`` sub-directory.
        num_episodes: Number of episodes (forced to 1 when ``scenario_id`` is given).
        horizon: Maximum number of steps, forwarded to the environment config.
        render: Whether to render (top-down view after every step).
        debug: Whether to enable debug output.
        scenario_id: Optional scenario id; also used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration
            (forwarded to the environment config).
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    print("=" * 50)
    print("运行模式: 策略仿真 (Simulation Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is only simulated once.
    if scenario_id is not None:
        num_episodes = 1

    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            # NOTE(review): this key may be a leftover from the previous
            # version of the script — confirm it is still wanted.
            "num_controlled_agents": 3,
            # The hard-coded duplicates ("horizon": 300, "use_render": True,
            # "manual_control": True) were stale; the parameter-driven values
            # below are the ones that took effect.
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": True,
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane filtering
            "lane_tolerance": 3.0,
            "replay_mode": False,  # simulation mode
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filters
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=ConstantVelocityPolicy(target_speed=50),
    )

    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without an explicit scenario, the episode index is the seed.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {env.config['horizon']} 步")

            step_count = 0
            total_reward = 0.0

            while True:
                # Ask every agent's policy for its next action.
                actions = {
                    aid: agent.policy.act()
                    for aid, agent in env.controlled_agents.items()
                }

                _, rewards, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                total_reward += sum(rewards.values())

                # Progress report every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                if dones["__all__"]:
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 总奖励: {total_reward:.2f}")
                    break
    finally:
        # Close the environment exactly once, after all episodes (or on error).
        env.close()

    print("\n" + "=" * 50)
    print("仿真完成!")
    print("=" * 50)
|
||||
|
||||
|
||||
def main():
    """Parse the command-line options and run the selected mode.

    ``--mode replay`` calls ``run_replay_mode``; any other accepted value
    (``simulation``, the default) calls ``run_simulation_mode``. Both receive
    the same keyword arguments derived from the parsed options; the
    ``--no_*`` switches are inverted into the corresponding positive flags.
    """
    cli = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
    add = cli.add_argument

    add("--mode", type=str, choices=["replay", "simulation"], default="simulation",
        help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)")
    add("--data_dir", type=str, default=WAYMO_DATA_DIR,
        help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})")
    add("--episodes", type=int, default=1,
        help="运行回合数 (默认: 1)")
    add("--horizon", type=int, default=300,
        help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)")
    add("--no_render", action="store_true",
        help="禁用渲染(加速运行)")
    add("--debug", action="store_true",
        help="启用调试模式(显示详细日志)")
    add("--scenario_id", type=int, default=None,
        help="指定场景ID(可选,如指定则只运行该场景)")
    add("--use_scenario_duration", action="store_true",
        help="使用场景原始时长作为horizon(自动停止)")
    add("--no_vehicles", action="store_true",
        help="禁止生成车辆")
    add("--no_pedestrians", action="store_true",
        help="禁止生成行人")
    add("--no_cyclists", action="store_true",
        help="禁止生成自行车")

    opts = cli.parse_args()

    # Both modes take exactly the same keyword arguments.
    run_kwargs = dict(
        data_dir=opts.data_dir,
        num_episodes=opts.episodes,
        horizon=opts.horizon,
        render=not opts.no_render,
        debug=opts.debug,
        scenario_id=opts.scenario_id,
        use_scenario_duration=opts.use_scenario_duration,
        spawn_vehicles=not opts.no_vehicles,
        spawn_pedestrians=not opts.no_pedestrians,
        spawn_cyclists=not opts.no_cyclists,
    )

    runner = run_replay_mode if opts.mode == "replay" else run_simulation_mode
    runner(**run_kwargs)
|
||||
|
||||
|
||||
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Reference in New Issue
Block a user