Compare commits
3 Commits
dev
...
train_not_
| Author | SHA1 | Date | |
|---|---|---|---|
| 3f7e183c4b | |||
| b626702cbb | |||
| 22ce995916 |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# 日志文件
|
||||
Env/logs/
|
||||
*.log
|
||||
6
.vscode/settings.json
vendored
6
.vscode/settings.json
vendored
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"cursorpyright.analysis.extraPaths": [
|
||||
"/home/huangfukk/mdsn/metadrive",
|
||||
"/home/huangfukk/mdsn/scenarionet"
|
||||
]
|
||||
}
|
||||
27
Algorithm/__init__.py
Normal file
27
Algorithm/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
MAGAIL Algorithm Package
|
||||
|
||||
多智能体生成对抗模仿学习算法实现
|
||||
"""
|
||||
|
||||
from .magail import MAGAIL
|
||||
from .ppo import PPO
|
||||
from .disc import GAILDiscrim
|
||||
from .bert import Bert
|
||||
from .policy import StateIndependentPolicy
|
||||
from .buffer import RolloutBuffer
|
||||
from .utils import Normalizer, build_mlp, reparameterize, evaluate_lop_pi
|
||||
|
||||
__all__ = [
|
||||
'MAGAIL',
|
||||
'PPO',
|
||||
'GAILDiscrim',
|
||||
'Bert',
|
||||
'StateIndependentPolicy',
|
||||
'RolloutBuffer',
|
||||
'Normalizer',
|
||||
'build_mlp',
|
||||
'reparameterize',
|
||||
'evaluate_lop_pi',
|
||||
]
|
||||
|
||||
BIN
Algorithm/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
Algorithm/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/bert.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/bert.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/buffer.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/buffer.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/disc.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/disc.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/magail.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/magail.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/magail.cpython-313.pyc
Normal file
BIN
Algorithm/__pycache__/magail.cpython-313.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/policy.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/policy.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/ppo.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/ppo.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Algorithm/__pycache__/utils.cpython-310.pyc
Normal file
BIN
Algorithm/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
@@ -28,17 +28,26 @@ class Bert(nn.Module):
|
||||
self.classifier.train()
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
# x可以是2D (batch_size, input_dim) 或 3D (batch_size, seq_len, feature_dim)
|
||||
is_2d_input = (x.dim() == 2)
|
||||
|
||||
if is_2d_input:
|
||||
# 如果输入是2D,添加一个序列维度
|
||||
x = x.unsqueeze(1) # (batch_size, 1, input_dim)
|
||||
|
||||
# x: (batch_size, seq_len, input_dim)
|
||||
# 线性投影
|
||||
x = self.projection(x) # (batch_size, input_dim, embed_dim)
|
||||
x = self.projection(x) # (batch_size, seq_len, embed_dim)
|
||||
|
||||
batch_size = x.size(0)
|
||||
if self.CLS:
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, 29, embed_dim)
|
||||
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, seq_len+1, embed_dim)
|
||||
|
||||
# 添加位置编码
|
||||
x = x + self.pos_embed
|
||||
# 添加位置编码(截断或扩展以匹配序列长度)
|
||||
seq_len = x.size(1)
|
||||
pos_embed = self.pos_embed[:, :seq_len, :]
|
||||
x = x + pos_embed
|
||||
|
||||
# 转置为(seq_len, batch_size, embed_dim)
|
||||
x = x.permute(1, 0, 2)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from .bert import Bert
|
||||
try:
|
||||
from .bert import Bert
|
||||
except ImportError:
|
||||
from bert import Bert
|
||||
|
||||
|
||||
DISC_LOGIT_INIT_SCALE = 1.0
|
||||
|
||||
@@ -2,21 +2,30 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .disc import GAILDiscrim
|
||||
from .ppo import PPO
|
||||
from .utils import Normalizer
|
||||
try:
|
||||
from .disc import GAILDiscrim
|
||||
from .ppo import PPO
|
||||
from .utils import Normalizer
|
||||
except ImportError:
|
||||
from disc import GAILDiscrim
|
||||
from ppo import PPO
|
||||
from utils import Normalizer
|
||||
|
||||
|
||||
class MAGAIL(PPO):
|
||||
def __init__(self, buffer_exp, input_dim, device,
|
||||
def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
|
||||
disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
|
||||
lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
|
||||
lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True,
|
||||
**kwargs # 接受其他PPO参数
|
||||
):
|
||||
super().__init__(state_shape=input_dim, device=device)
|
||||
super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)
|
||||
self.learning_steps = 0
|
||||
self.learning_steps_disc = 0
|
||||
|
||||
self.disc = GAILDiscrim(input_dim=input_dim)
|
||||
# 如果input_dim是元组,提取第一个元素
|
||||
state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
|
||||
# 判别器输入是state+next_state拼接,所以维度是state_dim*2
|
||||
self.disc = GAILDiscrim(input_dim=state_dim*2).to(device) # 移动到指定设备
|
||||
self.disc_grad_penalty = disc_grad_penalty
|
||||
self.disc_coef = disc_coef
|
||||
self.disc_logit_reg = disc_logit_reg
|
||||
@@ -27,7 +36,9 @@ class MAGAIL(PPO):
|
||||
|
||||
self.normalizer = None
|
||||
if use_gail_norm:
|
||||
self.normalizer = Normalizer(self.state_shape[0]*2)
|
||||
# state_shape已经是元组形式
|
||||
state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
|
||||
self.normalizer = Normalizer(state_dim*2)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.buffer_exp = buffer_exp
|
||||
@@ -52,7 +63,7 @@ class MAGAIL(PPO):
|
||||
# grad penalty
|
||||
sample_expert = states_exp_cp
|
||||
sample_expert.requires_grad = True
|
||||
disc = self.disc.linear(self.disc.trunk(sample_expert))
|
||||
disc = self.disc(sample_expert) # 直接调用forward方法
|
||||
ones = torch.ones(disc.size(), device=disc.device)
|
||||
disc_demo_grad = torch.autograd.grad(disc, sample_expert,
|
||||
grad_outputs=ones,
|
||||
@@ -91,7 +102,8 @@ class MAGAIL(PPO):
|
||||
|
||||
# Samples from current policy trajectories.
|
||||
samples_policy = self.buffer.sample(self.batch_size)
|
||||
states, next_states = samples_policy[1], samples_policy[-3]
|
||||
# samples_policy返回: (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds)
|
||||
states, next_states = samples_policy[0], samples_policy[6] # 修正: 使用states而不是actions
|
||||
states = torch.cat([states, next_states], dim=-1)
|
||||
|
||||
# Samples from expert demonstrations.
|
||||
@@ -129,6 +141,8 @@ class MAGAIL(PPO):
|
||||
return rewards_t.mean().item() + rewards_i.mean().item()
|
||||
|
||||
def save_models(self, path):
|
||||
# 确保目录存在
|
||||
os.makedirs(path, exist_ok=True)
|
||||
torch.save({
|
||||
'actor': self.actor.state_dict(),
|
||||
'critic': self.critic.state_dict(),
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from .utils import build_mlp, reparameterize, evaluate_lop_pi
|
||||
try:
|
||||
from .utils import build_mlp, reparameterize, evaluate_lop_pi
|
||||
except ImportError:
|
||||
from utils import build_mlp, reparameterize, evaluate_lop_pi
|
||||
|
||||
class StateIndependentPolicy(nn.Module):
|
||||
|
||||
|
||||
@@ -3,9 +3,14 @@ import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
from buffer import RolloutBuffer
|
||||
from bert import Bert
|
||||
from policy import StateIndependentPolicy
|
||||
try:
|
||||
from .buffer import RolloutBuffer
|
||||
from .bert import Bert
|
||||
from .policy import StateIndependentPolicy
|
||||
except ImportError:
|
||||
from buffer import RolloutBuffer
|
||||
from bert import Bert
|
||||
from policy import StateIndependentPolicy
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
@@ -55,7 +60,7 @@ class Algorithm(ABC):
|
||||
|
||||
class PPO(Algorithm):
|
||||
|
||||
def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048,
|
||||
def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
|
||||
units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
|
||||
lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
|
||||
value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
|
||||
@@ -66,6 +71,7 @@ class PPO(Algorithm):
|
||||
self.lr_critic = lr_critic
|
||||
self.lr_disc = lr_disc
|
||||
self.auto_lr = auto_lr
|
||||
self.action_shape = action_shape
|
||||
|
||||
self.use_adv_norm = use_adv_norm
|
||||
|
||||
@@ -86,8 +92,10 @@ class PPO(Algorithm):
|
||||
).to(device)
|
||||
|
||||
# Critic.
|
||||
# 如果state_shape是元组,提取第一个元素
|
||||
state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
|
||||
self.critic = Bert(
|
||||
input_dim=state_shape,
|
||||
input_dim=state_dim,
|
||||
output_dim=1
|
||||
).to(device)
|
||||
|
||||
@@ -145,14 +153,12 @@ class PPO(Algorithm):
|
||||
targets, gaes = self.calculate_gae(
|
||||
values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)
|
||||
|
||||
state_list = states.permute(1, 0, 2)
|
||||
action_list = actions.permute(1, 0, 2)
|
||||
|
||||
# 处理批量数据(不需要按智能体分组,因为buffer中已经混合了所有智能体的数据)
|
||||
for i in range(self.epoch_ppo):
|
||||
self.learning_steps_ppo += 1
|
||||
self.update_critic(states, targets, writer)
|
||||
for state, action, log_pi in state_list, action_list, log_pi_list:
|
||||
self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)
|
||||
# 直接使用整个batch进行actor更新
|
||||
self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)
|
||||
|
||||
# self.lr_decay(total_steps, writer)
|
||||
|
||||
|
||||
136
CHANGELOG.md
Normal file
136
CHANGELOG.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# 更新日志
|
||||
|
||||
## 2025-01-20 问题修复与优化
|
||||
|
||||
### ✅ 已解决的问题
|
||||
|
||||
#### 1. 车辆生成位置偏差问题
|
||||
**问题描述:** 部分车辆生成于草坪、停车场等非车道区域
|
||||
|
||||
**解决方案:**
|
||||
- 实现 `_is_position_on_lane()` 方法:检测位置是否在有效车道上
|
||||
- 实现 `_filter_valid_spawn_positions()` 方法:自动过滤非车道区域车辆
|
||||
- 支持容差参数(默认3米)处理边界情况
|
||||
- 在 `reset()` 时自动执行过滤,并输出统计信息
|
||||
|
||||
**配置参数:**
|
||||
```python
|
||||
"filter_offroad_vehicles": True, # 启用/禁用过滤
|
||||
"lane_tolerance": 3.0, # 容差范围(米)
|
||||
"max_controlled_vehicles": 10, # 最大车辆数限制
|
||||
```
|
||||
|
||||
#### 2. 红绿灯信息采集问题
|
||||
**问题描述:**
|
||||
- 部分红绿灯状态为 None
|
||||
- 车道分段时部分车辆无法获取红绿灯状态
|
||||
|
||||
**解决方案:**
|
||||
- 实现 `_get_traffic_light_state()` 方法,采用双重检测策略
|
||||
- 方法1(优先):从导航模块获取当前车道,直接查询(高效)
|
||||
- 方法2(兜底):遍历所有车道匹配位置(处理特殊情况)
|
||||
- 完善异常处理,None 状态返回 0(无红绿灯)
|
||||
- 返回值:0=无/未知, 1=绿灯, 2=黄灯, 3=红灯
|
||||
|
||||
#### 3. 性能优化问题
|
||||
**问题描述:** FPS只有15帧,CPU利用率不到20%
|
||||
|
||||
**解决方案:**
|
||||
- 创建 `run_multiagent_env_fast.py`:激光雷达优化版(30-60 FPS)
|
||||
- 创建 `run_multiagent_env_parallel.py`:多进程并行版(300-600 steps/s)
|
||||
- 提供详细的性能优化文档
|
||||
|
||||
### 📝 修改的文件
|
||||
|
||||
1. **Env/scenario_env.py**
|
||||
- 新增 `_is_position_on_lane()` 方法
|
||||
- 新增 `_filter_valid_spawn_positions()` 方法
|
||||
- 新增 `_get_traffic_light_state()` 方法
|
||||
- 更新 `default_config()` 添加配置参数
|
||||
- 更新 `reset()` 调用过滤逻辑
|
||||
- 更新 `_get_all_obs()` 使用新的红绿灯检测方法
|
||||
|
||||
2. **Env/run_multiagent_env.py**
|
||||
- 添加车道过滤配置参数
|
||||
|
||||
3. **Env/run_multiagent_env_fast.py**
|
||||
- 添加车道过滤配置
|
||||
- 性能优化配置
|
||||
|
||||
4. **Env/run_multiagent_env_parallel.py**
|
||||
- 添加车道过滤配置
|
||||
- 多进程并行实现
|
||||
|
||||
5. **README.md**
|
||||
- 更新问题说明,添加解决方案
|
||||
- 添加配置示例和测试方法
|
||||
- 添加问题解决总结
|
||||
|
||||
6. **新增文件**
|
||||
- `Env/test_lane_filter.py`:功能测试脚本
|
||||
|
||||
### 🧪 测试方法
|
||||
|
||||
```bash
|
||||
# 测试车道过滤和红绿灯检测功能
|
||||
python Env/test_lane_filter.py
|
||||
|
||||
# 运行标准版本(带过滤和可视化)
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 运行高性能版本(适合训练)
|
||||
python Env/run_multiagent_env_fast.py
|
||||
|
||||
# 运行多进程并行版本(最高吞吐量)
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
|
||||
### 💡 使用建议
|
||||
|
||||
1. **调试阶段**:使用 `run_multiagent_env.py`,启用渲染和车道过滤
|
||||
2. **训练阶段**:使用 `run_multiagent_env_fast.py`,关闭渲染,启用所有优化
|
||||
3. **大规模训练**:使用 `run_multiagent_env_parallel.py`,充分利用多核CPU
|
||||
|
||||
### ⚙️ 配置说明
|
||||
|
||||
所有配置参数都可以在创建环境时通过 `config` 字典传递:
|
||||
|
||||
```python
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# 基础配置
|
||||
"data_directory": "...",
|
||||
"is_multi_agent": True,
|
||||
"horizon": 300,
|
||||
|
||||
# 车道过滤(新增)
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0, # 容差3米
|
||||
"max_controlled_vehicles": 10, # 最多10辆车
|
||||
|
||||
# 性能优化
|
||||
"use_render": False,
|
||||
"decision_repeat": 5,
|
||||
...
|
||||
},
|
||||
agent2policy=your_policy
|
||||
)
|
||||
```
|
||||
|
||||
### 🔍 技术细节
|
||||
|
||||
**车道检测逻辑:**
|
||||
1. 使用 `lane.lane.point_on_lane()` 精确检测
|
||||
2. 使用 `lane.local_coordinates()` 计算横向距离
|
||||
3. 支持容差参数处理边界情况
|
||||
|
||||
**红绿灯检测逻辑:**
|
||||
1. 优先从 `vehicle.navigation.current_lane` 获取
|
||||
2. 失败时遍历所有车道查找
|
||||
3. 所有异常均有保护,确保稳定性
|
||||
|
||||
**性能优化原理:**
|
||||
- 减少激光束数量降低计算量
|
||||
- 多进程绕过Python GIL限制
|
||||
- 充分利用多核CPU
|
||||
|
||||
339
Env/DEBUG_GUIDE.md
Normal file
339
Env/DEBUG_GUIDE.md
Normal file
@@ -0,0 +1,339 @@
|
||||
# 调试功能使用指南
|
||||
|
||||
## 📋 概述
|
||||
|
||||
已为车道过滤和红绿灯检测功能添加了详细的调试输出,帮助您诊断和理解代码行为。
|
||||
|
||||
---
|
||||
|
||||
## 🎛️ 调试开关
|
||||
|
||||
### 1. 配置参数
|
||||
|
||||
在创建环境时,可以通过 `config` 参数启用调试模式:
|
||||
|
||||
```python
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# ... 其他配置 ...
|
||||
|
||||
# 🔥 调试开关
|
||||
"debug_lane_filter": True, # 启用车道过滤调试
|
||||
"debug_traffic_light": True, # 启用红绿灯检测调试
|
||||
},
|
||||
agent2policy=your_policy
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 默认值
|
||||
|
||||
两个调试开关默认都是 `False`(关闭),避免正常运行时产生大量日志。
|
||||
|
||||
---
|
||||
|
||||
## 📊 车道过滤调试 (`debug_lane_filter=True`)
|
||||
|
||||
### 输出内容
|
||||
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
|
||||
🔍 开始车道过滤: 共 51 辆车待检测
|
||||
|
||||
车辆 1/51: ID=128
|
||||
🔍 检测位置 (-4.11, 46.76), 容差=3.0m
|
||||
✅ 在车道上 (车道184, 检查了32条)
|
||||
✅ 保留
|
||||
|
||||
车辆 7/51: ID=134
|
||||
🔍 检测位置 (-51.34, -3.77), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
|
||||
... (所有车辆)
|
||||
|
||||
📊 过滤结果: 保留 45 辆, 过滤 6 辆
|
||||
```
|
||||
|
||||
### 调试信息说明
|
||||
|
||||
| 信息 | 含义 |
|
||||
|------|------|
|
||||
| 📍 场景信息统计 | 场景的基本信息(车道数、红绿灯数) |
|
||||
| 🔍 开始车道过滤 | 开始过滤,显示待检测车辆总数 |
|
||||
| 🔍 检测位置 | 车辆的坐标和使用的容差值 |
|
||||
| ✅ 在车道上 | 找到了车辆所在的车道,显示车道ID和检查次数 |
|
||||
| ❌ 不在任何车道上 | 所有车道都检查完了,未找到匹配的车道 |
|
||||
| 📊 过滤结果 | 最终统计:保留多少辆,过滤多少辆 |
|
||||
|
||||
### 典型输出案例
|
||||
|
||||
**情况1:车辆在正常车道上**
|
||||
```
|
||||
车辆 1/51: ID=128
|
||||
🔍 检测位置 (-4.11, 46.76), 容差=3.0m
|
||||
✅ 在车道上 (车道184, 检查了32条)
|
||||
✅ 保留
|
||||
```
|
||||
→ 检查了32条车道后找到匹配的车道184
|
||||
|
||||
**情况2:车辆在草坪/停车场**
|
||||
```
|
||||
车辆 7/51: ID=134
|
||||
🔍 检测位置 (-51.34, -3.77), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
```
|
||||
→ 检查了所有123条车道都不匹配,该车辆被过滤
|
||||
|
||||
---
|
||||
|
||||
## 🚦 红绿灯检测调试 (`debug_traffic_light=True`)
|
||||
|
||||
### 输出内容
|
||||
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🚦 检测车辆红绿灯 - 位置: (-4.1, 46.8)
|
||||
方法1-导航模块:
|
||||
current_lane = <metadrive.component.lane.straight_lane.StraightLane object>
|
||||
lane_index = 184
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
方法2-遍历车道: 开始遍历 123 条车道
|
||||
✓ 找到车辆所在车道: 184 (检查了32条)
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
结果: 返回 0 (无红绿灯/未知)
|
||||
```
|
||||
|
||||
### 调试信息说明
|
||||
|
||||
| 信息 | 含义 |
|
||||
|------|------|
|
||||
| 有红绿灯的车道数 | 统计场景中有多少个红绿灯 |
|
||||
| ⚠️ 场景中没有红绿灯 | 如果数量为0,会特别提示 |
|
||||
| 方法1-导航模块 | 尝试从导航系统获取 |
|
||||
| current_lane | 导航系统返回的当前车道对象 |
|
||||
| lane_index | 车道的唯一标识符 |
|
||||
| has_traffic_light | 该车道是否有红绿灯 |
|
||||
| status | 红绿灯的状态(GREEN/YELLOW/RED/None) |
|
||||
| 方法2-遍历车道 | 兜底方案,遍历所有车道查找 |
|
||||
| ✓ 找到车辆所在车道 | 遍历找到了匹配的车道 |
|
||||
|
||||
### 典型输出案例
|
||||
|
||||
**情况1:场景没有红绿灯**
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🚦 检测车辆红绿灯 - 位置: (-4.1, 46.8)
|
||||
方法1-导航模块:
|
||||
...
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
结果: 返回 0 (无红绿灯/未知)
|
||||
```
|
||||
→ 所有车辆都会返回0,这是正常的
|
||||
|
||||
**情况2:有红绿灯且状态正常**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (10.5, 20.3)
|
||||
方法1-导航模块:
|
||||
current_lane = <...>
|
||||
lane_index = 205
|
||||
has_traffic_light = True
|
||||
status = TRAFFIC_LIGHT_GREEN
|
||||
✅ 方法1成功: 绿灯
|
||||
```
|
||||
→ 方法1直接成功,返回1(绿灯)
|
||||
|
||||
**情况3:红绿灯状态为None**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (10.5, 20.3)
|
||||
方法1-导航模块:
|
||||
current_lane = <...>
|
||||
lane_index = 205
|
||||
has_traffic_light = True
|
||||
status = None
|
||||
⚠️ 方法1: 红绿灯状态为None
|
||||
```
|
||||
→ 有红绿灯,但状态异常,返回0
|
||||
|
||||
**情况4:导航失败,方法2兜底**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (15.2, 30.5)
|
||||
方法1-导航模块: 不可用 (hasattr=True, not_none=False)
|
||||
方法2-遍历车道: 开始遍历 123 条车道
|
||||
✓ 找到车辆所在车道: 210 (检查了45条)
|
||||
has_traffic_light = True
|
||||
status = TRAFFIC_LIGHT_RED
|
||||
✅ 方法2成功: 红灯
|
||||
```
|
||||
→ 方法1失败,方法2兜底成功,返回3(红灯)
|
||||
|
||||
---
|
||||
|
||||
## 🧪 测试方法
|
||||
|
||||
### 方式1:使用测试脚本
|
||||
|
||||
```bash
|
||||
# 标准测试(无详细调试)
|
||||
python Env/test_lane_filter.py
|
||||
|
||||
# 调试模式(详细输出)
|
||||
python Env/test_lane_filter.py --debug
|
||||
```
|
||||
|
||||
### 方式2:在代码中直接启用
|
||||
|
||||
```python
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": "...",
|
||||
"use_render": False,
|
||||
|
||||
# 启用调试
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
# 调试信息会自动输出
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 调试输出控制
|
||||
|
||||
### 场景1:只想看车道过滤
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": False, # 关闭红绿灯调试
|
||||
}
|
||||
```
|
||||
|
||||
### 场景2:只想看红绿灯检测
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": False,
|
||||
"debug_traffic_light": True, # 只看红绿灯
|
||||
}
|
||||
```
|
||||
|
||||
### 场景3:生产环境(关闭所有调试)
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": False,
|
||||
"debug_traffic_light": False,
|
||||
}
|
||||
# 或者直接不设置这两个参数,默认就是False
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💡 常见问题诊断
|
||||
|
||||
### 问题1:所有红绿灯状态都是0
|
||||
|
||||
**检查调试输出:**
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
```
|
||||
|
||||
**结论:** 场景本身没有红绿灯,返回0是正常的
|
||||
|
||||
---
|
||||
|
||||
### 问题2:车辆被过滤但不应该过滤
|
||||
|
||||
**检查调试输出:**
|
||||
```
|
||||
车辆 X: ID=XXX
|
||||
🔍 检测位置 (x, y), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
```
|
||||
|
||||
**可能原因:**
|
||||
1. 车辆确实在非车道区域(草坪/停车场)
|
||||
2. 容差值太小,可以尝试增大 `lane_tolerance`
|
||||
3. 车道数据有问题
|
||||
|
||||
**解决方案:**
|
||||
```python
|
||||
config = {
|
||||
"lane_tolerance": 5.0, # 增大容差到5米
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 问题3:性能下降
|
||||
|
||||
启用调试模式会有大量输出,影响性能:
|
||||
|
||||
**解决方案:**
|
||||
- 只在开发/调试时启用
|
||||
- 生产环境关闭所有调试开关
|
||||
- 或者只测试少量车辆:
|
||||
```python
|
||||
config = {
|
||||
"max_controlled_vehicles": 5, # 只测试5辆车
|
||||
"debug_traffic_light": True,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📌 最佳实践
|
||||
|
||||
1. **开发阶段**:启用调试,理解代码行为
|
||||
2. **调试问题**:根据需要选择性启用调试
|
||||
3. **性能测试**:关闭所有调试
|
||||
4. **生产运行**:永久关闭调试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 调试输出示例
|
||||
|
||||
完整的调试运行示例:
|
||||
|
||||
```bash
|
||||
cd /home/huangfukk/MAGAIL4AutoDrive
|
||||
python Env/test_lane_filter.py --debug
|
||||
```
|
||||
|
||||
输出会包含:
|
||||
- 场景统计信息
|
||||
- 每辆车的详细检测过程
|
||||
- 最终的过滤/检测结果
|
||||
- 性能统计
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- `README.md` - 项目总览和问题解决
|
||||
- `CHANGELOG.md` - 更新日志
|
||||
- `PERFORMANCE_OPTIMIZATION.md` - 性能优化指南
|
||||
|
||||
221
Env/GPU_ACCELERATION.md
Normal file
221
Env/GPU_ACCELERATION.md
Normal file
@@ -0,0 +1,221 @@
|
||||
# GPU加速指南
|
||||
|
||||
## 当前性能瓶颈分析
|
||||
|
||||
从测试结果看,即使关闭渲染,FPS仍然只有15-20左右,主要瓶颈是:
|
||||
|
||||
### 计算量分析(51辆车)
|
||||
```
|
||||
激光雷达计算:
|
||||
- 前向雷达:80束 × 51车 = 4,080次射线检测
|
||||
- 侧向雷达:10束 × 51车 = 510次射线检测
|
||||
- 车道线雷达:10束 × 51车 = 510次射线检测
|
||||
合计:5,100次射线检测/帧
|
||||
|
||||
红绿灯检测:
|
||||
- 遍历所有车道 × 51车 = 数千次几何计算
|
||||
```
|
||||
|
||||
**关键问题**:这些计算都是CPU单线程串行的,无法利用多核和GPU!
|
||||
|
||||
---
|
||||
|
||||
## GPU加速方案
|
||||
|
||||
### 方案1:优化激光雷达计算(已实现)✅
|
||||
|
||||
**优化内容:**
|
||||
1. 减少激光束数量:100束 → 52束(减少48%)
|
||||
2. 优化红绿灯检测:避免遍历所有车道
|
||||
3. 激光雷达缓存:每N帧才重新计算一次
|
||||
|
||||
**预期提升:** 2-4倍(30-60 FPS)
|
||||
|
||||
**使用方法:**
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方案2:MetaDrive GPU渲染(有限支持)
|
||||
|
||||
**说明:**
|
||||
MetaDrive基于Panda3D引擎,理论上支持GPU渲染,但:
|
||||
- GPU主要用于**图形渲染**,不是物理计算
|
||||
- 激光雷达的射线检测仍在CPU上
|
||||
- GPU渲染主要加速可视化,不加速训练
|
||||
|
||||
**启用方法:**
|
||||
```python
|
||||
config = {
|
||||
"use_render": True,
|
||||
"render_mode": "onscreen", # 或 "offscreen"
|
||||
# Panda3D会自动尝试使用GPU
|
||||
}
|
||||
```
|
||||
|
||||
**限制:**
|
||||
- 需要显示器或虚拟显示(Xvfb)
|
||||
- WSL2环境需要配置X11转发
|
||||
- 对无渲染训练无帮助
|
||||
|
||||
---
|
||||
|
||||
### 方案3:使用GPU加速的物理引擎(推荐但需要迁移)
|
||||
|
||||
**选项A:Isaac Gym (NVIDIA)**
|
||||
- 完全在GPU上运行物理模拟和渲染
|
||||
- 可同时模拟数千个环境
|
||||
- **缺点**:需要完全重写环境代码,迁移成本高
|
||||
|
||||
**选项B:IsaacSim/Omniverse**
|
||||
- NVIDIA的高级仿真平台
|
||||
- 支持GPU加速的激光雷达
|
||||
- **缺点**:学习曲线陡峭,环境配置复杂
|
||||
|
||||
**选项C:Brax (Google)**
|
||||
- JAX驱动,完全在GPU/TPU上运行
|
||||
- **缺点**:功能有限,不支持复杂场景
|
||||
|
||||
---
|
||||
|
||||
### 方案4:策略网络GPU加速(推荐)✅
|
||||
|
||||
虽然环境仿真在CPU,但可以让**策略网络在GPU上运行**:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# 创建GPU上的策略模型
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
policy = PolicyNetwork().to(device)
|
||||
|
||||
# 批量处理观测
|
||||
obs_batch = torch.tensor(obs_list).to(device)
|
||||
with torch.no_grad():
|
||||
actions = policy(obs_batch)
|
||||
actions = actions.cpu().numpy()
|
||||
```
|
||||
|
||||
**优势:**
|
||||
- 51辆车的推理可以并行
|
||||
- 如果使用RL训练,GPU加速训练过程
|
||||
- 不需要修改环境代码
|
||||
|
||||
---
|
||||
|
||||
### 方案5:多进程并行(最实用)✅
|
||||
|
||||
既然单个环境受限于CPU单线程,可以**并行运行多个环境**:
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
import os
|
||||
|
||||
def run_single_env(seed):
|
||||
"""运行单个环境实例"""
|
||||
env = MultiAgentScenarioEnv(config=...)
|
||||
obs = env.reset(seed)
|
||||
|
||||
for step in range(1000):
|
||||
actions = {...}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
env.close()
|
||||
return results
|
||||
|
||||
# 使用进程池并行运行
|
||||
if __name__ == "__main__":
|
||||
num_processes = os.cpu_count() # 12600KF有10核20线程
|
||||
seeds = list(range(num_processes))
|
||||
|
||||
with Pool(processes=num_processes) as pool:
|
||||
results = pool.map(run_single_env, seeds)
|
||||
```
|
||||
|
||||
**预期提升:** 接近线性(10核 ≈ 10倍吞吐量)
|
||||
|
||||
**CPU利用率:** 可达80-100%
|
||||
|
||||
---
|
||||
|
||||
## 推荐的完整优化方案
|
||||
|
||||
### 1. 立即可用(已实现)
|
||||
```bash
|
||||
# 使用优化版本,激光束减少+缓存
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
**预期:** 30-60 FPS(2-4倍提升)
|
||||
|
||||
### 2. 短期优化(1-2小时)
|
||||
- 实现多进程并行
|
||||
- 策略网络迁移到GPU
|
||||
|
||||
**预期:** 300-600 FPS(总吞吐量)
|
||||
|
||||
### 3. 中期优化(1-2天)
|
||||
- 使用NumPy矢量化批量处理观测
|
||||
- 优化Python代码热点(用Cython/Numba)
|
||||
|
||||
**预期:** 额外20-30%提升
|
||||
|
||||
### 4. 长期方案(1-2周)
|
||||
- 迁移到Isaac Gym等GPU加速仿真器
|
||||
- 或使用分布式训练框架(Ray/RLlib)
|
||||
|
||||
**预期:** 10-100倍提升
|
||||
|
||||
---
|
||||
|
||||
## 为什么MetaDrive无法直接使用GPU?
|
||||
|
||||
### 架构限制:
|
||||
1. **物理引擎**:使用Bullet/Panda3D的CPU物理引擎
|
||||
2. **射线检测**:串行CPU计算,无法并行
|
||||
3. **Python GIL**:全局解释器锁限制多线程
|
||||
4. **设计目标**:MetaDrive设计时主要考虑灵活性而非极致性能
|
||||
|
||||
### GPU在仿真中的作用:
|
||||
- ✅ **图形渲染**:绘制画面(但我们训练时不需要)
|
||||
- ✅ **神经网络推理/训练**:策略模型计算
|
||||
- ❌ **物理计算**:MetaDrive的物理引擎在CPU
|
||||
- ❌ **传感器模拟**:激光雷达等在CPU
|
||||
|
||||
---
|
||||
|
||||
## 检查GPU是否可用
|
||||
|
||||
```bash
|
||||
# 检查NVIDIA GPU
|
||||
nvidia-smi
|
||||
|
||||
# 检查PyTorch GPU支持
|
||||
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
|
||||
|
||||
# 检查MetaDrive渲染设备
|
||||
python -c "from panda3d.core import GraphicsPipeSelection; print(GraphicsPipeSelection.get_global_ptr().get_default_pipe())"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
| 方案 | 实现难度 | 性能提升 | GPU使用 | 推荐度 |
|
||||
|------|----------|----------|---------|--------|
|
||||
| 减少激光束 | ⭐ | 2-4x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 激光雷达缓存 | ⭐ | 1.5-3x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 多进程并行 | ⭐⭐ | 5-10x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 策略GPU加速 | ⭐⭐ | 2-5x | ✅ | ⭐⭐⭐⭐ |
|
||||
| GPU渲染 | ⭐⭐⭐ | 1.2x | ✅ | ⭐⭐ |
|
||||
| 迁移Isaac Gym | ⭐⭐⭐⭐⭐ | 10-100x | ✅ | ⭐⭐⭐ |
|
||||
|
||||
**结论:**
|
||||
1. 先用已实现的优化(减少激光束+缓存)
|
||||
2. 再实现多进程并行
|
||||
3. 策略网络用GPU训练
|
||||
4. 如果还不够,考虑迁移到GPU仿真器
|
||||
|
||||
413
Env/LOGGING_GUIDE.md
Normal file
413
Env/LOGGING_GUIDE.md
Normal file
@@ -0,0 +1,413 @@
|
||||
# 日志记录功能使用指南
|
||||
|
||||
## 📋 概述
|
||||
|
||||
为所有运行脚本添加了日志记录功能,可以将终端输出同时保存到文本文件,方便后续分析和问题排查。
|
||||
|
||||
---
|
||||
|
||||
## 🎯 功能特点
|
||||
|
||||
1. **双向输出**:同时输出到终端和文件,不影响实时查看
|
||||
2. **自动管理**:使用上下文管理器,自动处理文件开启/关闭
|
||||
3. **灵活配置**:支持自定义文件名和日志目录
|
||||
4. **时间戳命名**:默认使用时间戳生成唯一文件名
|
||||
5. **无缝集成**:只需添加命令行参数,无需修改代码
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速使用
|
||||
|
||||
### 1. 基础用法
|
||||
|
||||
```bash
|
||||
# 不启用日志(默认)
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 启用日志记录
|
||||
python Env/run_multiagent_env.py --log
|
||||
|
||||
# 或使用短选项
|
||||
python Env/run_multiagent_env.py -l
|
||||
```
|
||||
|
||||
### 2. 自定义文件名
|
||||
|
||||
```bash
|
||||
# 使用自定义日志文件名
|
||||
python Env/run_multiagent_env.py --log --log-file=my_test.log
|
||||
|
||||
# 测试脚本也支持
|
||||
python Env/test_lane_filter.py --log --log-file=test_results.log
|
||||
```
|
||||
|
||||
### 3. 组合使用调试和日志
|
||||
|
||||
```bash
|
||||
# 测试脚本:调试模式 + 日志记录
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 会生成类似:test_debug_20251021_123456.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📁 日志文件位置
|
||||
|
||||
默认日志目录:`Env/logs/`
|
||||
|
||||
### 文件命名规则
|
||||
|
||||
| 脚本 | 默认文件名格式 | 示例 |
|
||||
|------|---------------|------|
|
||||
| `run_multiagent_env.py` | `run_YYYYMMDD_HHMMSS.log` | `run_20251021_143022.log` |
|
||||
| `run_multiagent_env_fast.py` | `run_fast.log` | `run_fast.log` |
|
||||
| `test_lane_filter.py` | `test_{mode}_YYYYMMDD_HHMMSS.log` | `test_debug_20251021_143500.log` |
|
||||
|
||||
**说明**:
|
||||
- `YYYYMMDD_HHMMSS` 是时间戳(年月日_时分秒)
|
||||
- `{mode}` 是测试模式(`standard` 或 `debug`)
|
||||
|
||||
---
|
||||
|
||||
## 📝 所有支持的脚本
|
||||
|
||||
### 1. run_multiagent_env.py(标准运行脚本)
|
||||
|
||||
```bash
|
||||
# 不启用日志
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 启用日志(自动生成时间戳文件名)
|
||||
python Env/run_multiagent_env.py --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/run_multiagent_env.py --log --log-file=run_test1.log
|
||||
```
|
||||
|
||||
**日志位置**:`Env/logs/run_YYYYMMDD_HHMMSS.log`
|
||||
|
||||
---
|
||||
|
||||
### 2. run_multiagent_env_fast.py(高性能版本)
|
||||
|
||||
```bash
|
||||
# 启用日志
|
||||
python Env/run_multiagent_env_fast.py --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/run_multiagent_env_fast.py --log --log-file=fast_test.log
|
||||
```
|
||||
|
||||
**日志位置**:`Env/logs/run_fast.log`(默认)
|
||||
|
||||
---
|
||||
|
||||
### 3. test_lane_filter.py(测试脚本)
|
||||
|
||||
```bash
|
||||
# 标准测试 + 日志
|
||||
python Env/test_lane_filter.py --log
|
||||
|
||||
# 调试测试 + 日志
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/test_lane_filter.py --log --log-file=my_test.log
|
||||
|
||||
# 组合使用
|
||||
python Env/test_lane_filter.py --debug --log --log-file=debug_run.log
|
||||
```
|
||||
|
||||
**日志位置**:
|
||||
- 标准模式:`Env/logs/test_standard_YYYYMMDD_HHMMSS.log`
|
||||
- 调试模式:`Env/logs/test_debug_YYYYMMDD_HHMMSS.log`
|
||||
|
||||
---
|
||||
|
||||
## 💻 编程接口
|
||||
|
||||
如果您想在代码中直接使用日志功能:
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 方式1:使用上下文管理器(推荐)
|
||||
with setup_logger(log_file="my_log.log", log_dir="logs"):
|
||||
print("这条消息会同时输出到终端和文件")
|
||||
# 运行您的代码
|
||||
# ...
|
||||
|
||||
# 方式2:手动管理
|
||||
from logger_utils import LoggerContext
|
||||
|
||||
logger = LoggerContext(log_file="custom.log", log_dir="output")
|
||||
logger.__enter__() # 开启日志
|
||||
print("输出消息")
|
||||
logger.__exit__(None, None, None) # 关闭日志
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 日志内容示例
|
||||
|
||||
### 标准运行
|
||||
|
||||
```
|
||||
📝 日志记录已启用
|
||||
📁 日志文件: Env/logs/run_20251021_143022.log
|
||||
------------------------------------------------------------
|
||||
💡 提示: 使用 --log 或 -l 参数启用日志记录
|
||||
示例: python run_multiagent_env.py --log
|
||||
自定义文件名: python run_multiagent_env.py --log --log-file=my_run.log
|
||||
------------------------------------------------------------
|
||||
[INFO] Environment: MultiAgentScenarioEnv
|
||||
[INFO] MetaDrive version: 0.4.3
|
||||
...
|
||||
------------------------------------------------------------
|
||||
✅ 日志已保存到: Env/logs/run_20251021_143022.log
|
||||
```
|
||||
|
||||
### 调试模式
|
||||
|
||||
```
|
||||
📝 日志记录已启用
|
||||
📁 日志文件: Env/logs/test_debug_20251021_143500.log
|
||||
------------------------------------------------------------
|
||||
🐛 调试模式启用
|
||||
============================================================
|
||||
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🔍 开始车道过滤: 共 51 辆车待检测
|
||||
...
|
||||
------------------------------------------------------------
|
||||
✅ 日志已保存到: Env/logs/test_debug_20251021_143500.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 高级配置
|
||||
|
||||
### 自定义日志目录
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 指定不同的日志目录
|
||||
with setup_logger(log_file="test.log", log_dir="my_logs"):
|
||||
print("日志会保存到 my_logs/test.log")
|
||||
```
|
||||
|
||||
### 追加模式
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 追加到现有文件(而不是覆盖)
|
||||
with setup_logger(log_file="test.log", mode='a'): # mode='a' 表示追加
|
||||
print("这条消息会追加到文件末尾")
|
||||
```
|
||||
|
||||
### 只重定向特定输出
|
||||
|
||||
```python
|
||||
from logger_utils import LoggerContext
|
||||
|
||||
# 只重定向stdout,不重定向stderr
|
||||
logger = LoggerContext(
|
||||
log_file="test.log",
|
||||
redirect_stdout=True, # 重定向标准输出
|
||||
redirect_stderr=False # 不重定向错误输出
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 命令行参数总结
|
||||
|
||||
| 参数 | 短选项 | 说明 | 示例 |
|
||||
|------|--------|------|------|
|
||||
| `--log` | `-l` | 启用日志记录 | `--log` |
|
||||
| `--log-file=NAME` | 无 | 指定日志文件名 | `--log-file=test.log` |
|
||||
| `--debug` | `-d` | 启用调试模式(test_lane_filter.py) | `--debug` |
|
||||
|
||||
### 参数组合
|
||||
|
||||
```bash
|
||||
# 示例1:标准模式 + 日志
|
||||
python Env/test_lane_filter.py --log
|
||||
|
||||
# 示例2:调试模式 + 日志
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 示例3:调试 + 自定义文件名
|
||||
python Env/test_lane_filter.py -d --log --log-file=my_debug.log
|
||||
|
||||
# 示例4:所有参数
|
||||
python Env/test_lane_filter.py --debug --log --log-file=full_test.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 常见问题
|
||||
|
||||
### Q1: 日志文件在哪里?
|
||||
|
||||
**A**: 默认在 `Env/logs/` 目录下。如果目录不存在,会自动创建。
|
||||
|
||||
```bash
|
||||
# 查看所有日志文件
|
||||
ls -lh Env/logs/
|
||||
|
||||
# 查看最新的日志
|
||||
ls -lt Env/logs/ | head -5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q2: 如何查看日志内容?
|
||||
|
||||
**A**: 使用任何文本编辑器或命令行工具:
|
||||
|
||||
```bash
|
||||
# 方式1:使用cat
|
||||
cat Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式2:使用less(可翻页)
|
||||
less Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式3:查看末尾内容
|
||||
tail -n 50 Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式4:实时监控(适合长时间运行)
|
||||
tail -f Env/logs/run_20251021_143022.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q3: 日志文件太多怎么办?
|
||||
|
||||
**A**: 可以定期清理旧日志:
|
||||
|
||||
```bash
|
||||
# 删除7天前的日志
|
||||
find Env/logs/ -name "*.log" -mtime +7 -delete
|
||||
|
||||
# 只保留最新的10个日志
|
||||
cd Env/logs && ls -t *.log | tail -n +11 | xargs rm -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q4: 日志会影响性能吗?
|
||||
|
||||
**A**: 影响很小,因为:
|
||||
1. 文件I/O是异步的
|
||||
2. 使用了缓冲区
|
||||
3. 立即刷新确保数据不丢失
|
||||
|
||||
如果追求极致性能,建议训练时不启用日志,只在需要分析时启用。
|
||||
|
||||
---
|
||||
|
||||
### Q5: 可以同时记录多个脚本的日志吗?
|
||||
|
||||
**A**: 可以,每个脚本使用不同的日志文件:
|
||||
|
||||
```bash
|
||||
# 终端1
|
||||
python Env/run_multiagent_env.py --log --log-file=script1.log
|
||||
|
||||
# 终端2(同时运行)
|
||||
python Env/test_lane_filter.py --log --log-file=script2.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
### 1. 开发阶段
|
||||
|
||||
```bash
|
||||
# 使用调试模式 + 日志,方便排查问题
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
```
|
||||
|
||||
### 2. 长时间运行
|
||||
|
||||
```bash
|
||||
# 启用日志,避免输出丢失
|
||||
nohup python Env/run_multiagent_env.py --log > /dev/null 2>&1 &
|
||||
|
||||
# 查看实时输出
|
||||
tail -f Env/logs/run_*.log
|
||||
```
|
||||
|
||||
### 3. 批量实验
|
||||
|
||||
```bash
|
||||
# 为每次实验使用不同的日志文件
|
||||
for i in {1..5}; do
|
||||
python Env/run_multiagent_env.py --log --log-file=exp_${i}.log
|
||||
done
|
||||
```
|
||||
|
||||
### 4. 性能测试
|
||||
|
||||
```bash
|
||||
# 不启用日志,获得最佳性能
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- `README.md` - 项目总览
|
||||
- `DEBUG_GUIDE.md` - 调试功能使用指南
|
||||
- `CHANGELOG.md` - 更新日志
|
||||
|
||||
---
|
||||
|
||||
## 🔍 技术细节
|
||||
|
||||
### 实现原理
|
||||
|
||||
1. **TeeLogger类**:实现同时写入终端和文件
|
||||
2. **上下文管理器**:自动管理资源(文件打开/关闭)
|
||||
3. **sys.stdout重定向**:拦截所有print输出
|
||||
4. **即时刷新**:每次写入后立即刷新,确保数据不丢失
|
||||
|
||||
### 源代码
|
||||
|
||||
详见 `Env/logger_utils.py`
|
||||
|
||||
```python
|
||||
# 简化示例
|
||||
class TeeLogger:
|
||||
def write(self, message):
|
||||
self.terminal.write(message) # 输出到终端
|
||||
self.log_file.write(message) # 写入文件
|
||||
self.log_file.flush() # 立即刷新
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 总结
|
||||
|
||||
- ✅ 简单易用:只需添加 `--log` 参数
|
||||
- ✅ 不影响输出:终端仍可实时查看
|
||||
- ✅ 自动管理:文件自动开启/关闭
|
||||
- ✅ 灵活配置:支持自定义文件名和目录
|
||||
- ✅ 完整记录:包含所有调试信息
|
||||
|
||||
立即开始使用:
|
||||
|
||||
```bash
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
```
|
||||
|
||||
131
Env/PERFORMANCE_OPTIMIZATION.md
Normal file
131
Env/PERFORMANCE_OPTIMIZATION.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# MetaDrive 性能优化指南
|
||||
|
||||
## 为什么帧率只有15FPS且CPU利用率不高?
|
||||
|
||||
### 主要原因:
|
||||
|
||||
1. **渲染瓶颈(最主要)**
|
||||
- `use_render: True` + 每帧调用 `env.render()` 会严重限制帧率
|
||||
- MetaDrive 使用 Panda3D 渲染引擎,渲染是**同步阻塞**的
|
||||
- 即使CPU有余力,也要等待渲染完成才能继续下一步
|
||||
- 这就是为什么CPU利用率低但帧率也低的原因
|
||||
|
||||
2. **激光雷达计算开销**
|
||||
- 每帧对每辆车进行3次激光雷达扫描(100个激光束)
|
||||
- 需要进行物理射线检测,计算量较大
|
||||
|
||||
3. **物理引擎同步**
|
||||
- 默认物理步长很小(0.02s),需要频繁计算
|
||||
|
||||
4. **Python GIL限制**
|
||||
- Python全局解释器锁限制了多核并行
|
||||
- 即使是多核CPU,Python单线程性能才是瓶颈
|
||||
|
||||
## 性能优化方案
|
||||
|
||||
### 方案1:关闭渲染(推荐用于训练)
|
||||
**预期提升:10-20倍(150-300+ FPS)**
|
||||
|
||||
```python
|
||||
config = {
|
||||
"use_render": False, # 关闭渲染
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
}
|
||||
```
|
||||
|
||||
### 方案2:降低物理计算频率
|
||||
**预期提升:2-3倍**
|
||||
|
||||
```python
|
||||
config = {
|
||||
"physics_world_step_size": 0.05, # 默认0.02,增大步长
|
||||
"decision_repeat": 5, # 每5个物理步执行一次决策
|
||||
}
|
||||
```
|
||||
|
||||
### 方案3:优化激光雷达
|
||||
**预期提升:1.5-2倍**
|
||||
|
||||
修改 `scenario_env.py` 中的 `_get_all_obs()` 函数:
|
||||
|
||||
```python
|
||||
# 减少激光束数量
|
||||
lidar = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=40, # 从80减到40
|
||||
distance=30,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world
|
||||
)
|
||||
|
||||
# 或者降低扫描频率(每N步才扫描一次)
|
||||
if self.round % 5 == 0:
|
||||
lidar = self.engine.get_sensor("lidar").perceive(...)
|
||||
else:
|
||||
lidar = self.last_lidar[agent_id] # 使用缓存
|
||||
```
|
||||
|
||||
### 方案4:间歇性渲染
|
||||
**适用场景:既需要可视化又想提升性能**
|
||||
|
||||
```python
|
||||
# 每10步渲染一次,而不是每步都渲染
|
||||
if step % 10 == 0:
|
||||
env.render(mode="topdown")
|
||||
```
|
||||
|
||||
### 方案5:使用多进程并行(高级)
|
||||
**预期提升:接近线性(取决于进程数)**
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
|
||||
def run_env(seed):
|
||||
env = MultiAgentScenarioEnv(config=...)
|
||||
# 运行仿真
|
||||
return results
|
||||
|
||||
# 使用进程池并行运行多个环境
|
||||
with Pool(processes=8) as pool:
|
||||
results = pool.map(run_env, range(8))
|
||||
```
|
||||
|
||||
## 文件说明
|
||||
|
||||
- `run_multiagent_env.py` - **标准版本**(无渲染,基础优化)
|
||||
- `run_multiagent_env_fast.py` - **极速版本**(激光雷达优化+缓存)⭐推荐
|
||||
- `run_multiagent_env_parallel.py` - **并行版本**(多进程,最高吞吐量)⭐⭐推荐
|
||||
- `run_multiagent_env_visual.py` - **可视化版本**(有渲染,适合调试)
|
||||
|
||||
## 性能对比
|
||||
|
||||
| 配置 | 单环境FPS | 总吞吐量 | CPU利用率 | 文件 | 适用场景 |
|
||||
|------|-----------|----------|-----------|------|----------|
|
||||
| 原始配置(有渲染) | 15-20 | 15-20 | 15-20% | visual | 实时可视化调试 |
|
||||
| 关闭渲染 | 20-25 | 20-25 | 20-30% | 标准版 | 基础训练 |
|
||||
| 激光雷达优化+缓存 | 30-60 | 30-60 | 30-50% | fast | 快速训练⭐ |
|
||||
| 多进程并行(10核) | 30-60 | 300-600 | 90-100% | parallel | 大规模训练⭐⭐ |
|
||||
|
||||
**说明:**
|
||||
- **单环境FPS**:单个环境实例的帧率
|
||||
- **总吞吐量**:所有进程合计的 steps/second
|
||||
- 12600KF(10核20线程)推荐使用并行版本
|
||||
|
||||
## 建议
|
||||
|
||||
1. **训练时**:使用高性能版本(关闭渲染)
|
||||
2. **调试时**:使用可视化版本,或间歇性渲染
|
||||
3. **大规模实验**:使用多进程并行
|
||||
4. **如果需要GPU加速**:考虑使用GPU渲染或将策略网络部署到GPU上
|
||||
|
||||
## 为什么CPU利用率低?
|
||||
|
||||
- **渲染阻塞**:CPU在等待渲染完成
|
||||
- **Python GIL**:限制了多核利用
|
||||
- **I/O等待**:可能在等待磁盘读取数据
|
||||
- **单线程瓶颈**:MetaDrive主循环是单线程的
|
||||
|
||||
解决方法:关闭渲染 + 多进程并行
|
||||
|
||||
241
Env/QUICK_START.md
Normal file
241
Env/QUICK_START.md
Normal file
@@ -0,0 +1,241 @@
|
||||
# 快速使用指南
|
||||
|
||||
## 🚀 已实现的性能优化
|
||||
|
||||
根据您的测试结果,原始版本FPS只有15左右,现已进行了全面优化。
|
||||
|
||||
---
|
||||
|
||||
## 📊 性能瓶颈分析
|
||||
|
||||
您的CPU是12600KF(10核20线程),但利用率不到20%,原因是:
|
||||
|
||||
1. **激光雷达计算瓶颈**:51辆车 × 100个激光束 = 每帧5100次射线检测
|
||||
2. **红绿灯检测低效**:遍历所有车道进行几何计算
|
||||
3. **Python GIL限制**:单线程执行,无法利用多核
|
||||
4. **计算串行化**:所有车辆依次处理,没有并行
|
||||
|
||||
---
|
||||
|
||||
## 🎯 推荐使用方案
|
||||
|
||||
### 方案1:极速单环境(推荐新手)⭐
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
**优化内容:**
|
||||
- ✅ 激光束:100束 → 52束(减少48%计算量)
|
||||
- ✅ 激光雷达缓存:每3帧才重新计算
|
||||
- ✅ 红绿灯检测优化:避免遍历所有车道
|
||||
- ✅ 关闭所有渲染和调试
|
||||
|
||||
**预期性能:** 30-60 FPS(2-4倍提升)
|
||||
|
||||
---
|
||||
|
||||
### 方案2:多进程并行(推荐训练)⭐⭐
|
||||
```bash
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
|
||||
**优化内容:**
|
||||
- ✅ 同时运行10个独立环境(充分利用10核CPU)
|
||||
- ✅ 每个环境应用所有单环境优化
|
||||
- ✅ CPU利用率可达90-100%
|
||||
|
||||
**预期性能:** 300-600 steps/s(20-40倍总吞吐量)
|
||||
|
||||
---
|
||||
|
||||
### 方案3:可视化调试
|
||||
```bash
|
||||
python Env/run_multiagent_env_visual.py
|
||||
```
|
||||
|
||||
**说明:** 保留渲染功能,FPS约15,仅用于调试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 关于GPU加速
|
||||
|
||||
### GPU能否加速MetaDrive?
|
||||
|
||||
**简短回答:有限支持,主要瓶颈不在GPU**
|
||||
|
||||
**详细说明:**
|
||||
|
||||
1. **物理计算(主要瓶颈)** ❌ 不支持GPU
|
||||
- MetaDrive使用Bullet物理引擎,只在CPU运行
|
||||
- 激光雷达射线检测也在CPU
|
||||
- 这是FPS低的主要原因
|
||||
|
||||
2. **图形渲染** ✅ 支持GPU
|
||||
- Panda3D会自动使用GPU渲染
|
||||
- 但我们训练时关闭了渲染,所以GPU无用武之地
|
||||
|
||||
3. **策略网络** ✅ 支持GPU
|
||||
- 可以把Policy模型放到GPU上
|
||||
- 但环境本身仍在CPU
|
||||
|
||||
### GPU渲染配置(可选)
|
||||
```python
|
||||
config = {
|
||||
"use_render": True,
|
||||
# GPU会自动用于渲染
|
||||
}
|
||||
```
|
||||
|
||||
### 策略网络GPU加速(推荐)
|
||||
```python
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
policy_model = PolicyNet().to(device)
|
||||
|
||||
# 批量推理
|
||||
obs_tensor = torch.tensor(obs_list).to(device)
|
||||
actions = policy_model(obs_tensor)
|
||||
```
|
||||
|
||||
**详细说明请看:** `GPU_ACCELERATION.md`
|
||||
|
||||
---
|
||||
|
||||
## 📈 性能对比
|
||||
|
||||
| 版本 | FPS | CPU利用率 | 改进 |
|
||||
|------|-----|-----------|------|
|
||||
| 原始版本 | 15 | 20% | - |
|
||||
| 极速版本 | 30-60 | 30-50% | 2-4x |
|
||||
| 并行版本 | 30-60/env | 90-100% | 总吞吐20-40x |
|
||||
|
||||
---
|
||||
|
||||
## 💡 使用建议
|
||||
|
||||
### 场景1:快速测试环境
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
单环境,快速验证功能
|
||||
|
||||
### 场景2:大规模数据收集
|
||||
```bash
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
多进程,最大化数据收集速度
|
||||
|
||||
### 场景3:RL训练
|
||||
```bash
|
||||
# 推荐使用Ray RLlib等框架,它们内置了并行环境管理
|
||||
# 或者修改parallel版本,保存经验到replay buffer
|
||||
```
|
||||
|
||||
### 场景4:调试/可视化
|
||||
```bash
|
||||
python Env/run_multiagent_env_visual.py
|
||||
```
|
||||
带渲染,可以看到车辆运行
|
||||
|
||||
---
|
||||
|
||||
## 🔍 性能监控
|
||||
|
||||
所有版本都内置了性能统计,运行时会显示:
|
||||
```
|
||||
Step 100: FPS = 45.23, 车辆数 = 51, 平均步时间 = 22.10ms
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ 高级优化选项
|
||||
|
||||
### 调整激光雷达缓存频率
|
||||
|
||||
编辑 `run_multiagent_env_fast.py`:
|
||||
```python
|
||||
env.lidar_cache_interval = 3 # 改为5可进一步提速(但观测会更旧)
|
||||
```
|
||||
|
||||
### 调整并行进程数
|
||||
|
||||
编辑 `run_multiagent_env_parallel.py`:
|
||||
```python
|
||||
num_workers = 10 # 改为更少的进程数(如果内存不足)
|
||||
```
|
||||
|
||||
### 进一步减少激光束
|
||||
|
||||
编辑 `scenario_env.py` 的 `_get_all_obs()` 函数:
|
||||
```python
|
||||
lidar = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=20, # 从40进一步减少到20
|
||||
distance=20, # 从30减少到20米
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎓 为什么CPU利用率低?
|
||||
|
||||
### 原因分析:
|
||||
|
||||
1. **单线程瓶颈**
|
||||
- Python GIL限制
|
||||
- MetaDrive主循环是单线程的
|
||||
- 即使有10个核心,也只用1个
|
||||
|
||||
2. **I/O等待**
|
||||
- 等待渲染完成(如果开启)
|
||||
- 等待磁盘读取数据
|
||||
|
||||
3. **计算不均衡**
|
||||
- 某些计算很重(激光雷达),某些很轻
|
||||
- CPU在重计算之间有空闲
|
||||
|
||||
### 解决方案:
|
||||
|
||||
✅ **已实现:** 多进程并行(`run_multiagent_env_parallel.py`)
|
||||
- 每个进程占用1个核心
|
||||
- 10个进程可充分利用10核CPU
|
||||
- CPU利用率可达90-100%
|
||||
|
||||
---
|
||||
|
||||
## 📚 相关文档
|
||||
|
||||
- `PERFORMANCE_OPTIMIZATION.md` - 详细的性能优化指南
|
||||
- `GPU_ACCELERATION.md` - GPU加速的完整说明
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
### Q: 为什么关闭渲染后FPS还是只有20?
|
||||
A: 主要瓶颈是激光雷达计算,不是渲染。请使用 `run_multiagent_env_fast.py`。
|
||||
|
||||
### Q: GPU能加速训练吗?
|
||||
A: 环境模拟在CPU,但策略网络可以在GPU上训练。
|
||||
|
||||
### Q: 如何最大化CPU利用率?
|
||||
A: 使用 `run_multiagent_env_parallel.py` 多进程版本。
|
||||
|
||||
### Q: 会影响观测精度吗?
|
||||
A: 激光束减少会略微降低精度,但实践中影响很小。缓存会让观测滞后1-2帧。
|
||||
|
||||
### Q: 如何恢复原始配置?
|
||||
A: 使用 `run_multiagent_env_visual.py` 或修改配置文件中的参数。
|
||||
|
||||
---
|
||||
|
||||
## 🚦 下一步
|
||||
|
||||
1. 先测试 `run_multiagent_env_fast.py`,验证性能提升
|
||||
2. 如果满意,用于日常训练
|
||||
3. 需要大规模训练时,使用 `run_multiagent_env_parallel.py`
|
||||
4. 考虑将策略网络迁移到GPU
|
||||
|
||||
祝训练顺利!🎉
|
||||
|
||||
15
Env/__init__.py
Normal file
15
Env/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Multi-Agent Scenario Environment
|
||||
|
||||
多智能体场景环境
|
||||
"""
|
||||
|
||||
from .scenario_env import MultiAgentScenarioEnv, PolicyVehicle
|
||||
from .simple_idm_policy import ConstantVelocityPolicy
|
||||
|
||||
__all__ = [
|
||||
'MultiAgentScenarioEnv',
|
||||
'PolicyVehicle',
|
||||
'ConstantVelocityPolicy',
|
||||
]
|
||||
|
||||
BIN
Env/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
Env/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Env/__pycache__/logger_utils.cpython-310.pyc
Normal file
BIN
Env/__pycache__/logger_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Env/__pycache__/logger_utils.cpython-313.pyc
Normal file
BIN
Env/__pycache__/logger_utils.cpython-313.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
Env/__pycache__/run_multiagent_env.cpython-310.pyc
Normal file
BIN
Env/__pycache__/run_multiagent_env.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Env/__pycache__/simple_idm_policy.cpython-313.pyc
Normal file
BIN
Env/__pycache__/simple_idm_policy.cpython-313.pyc
Normal file
Binary file not shown.
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据集检查脚本:统计可用的场景数量
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import os
|
||||
from pathlib import Path
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
def check_dataset(data_dir, subfolder="exp_filtered"):
|
||||
"""
|
||||
检查数据集中的场景数量
|
||||
|
||||
Args:
|
||||
data_dir: 数据根目录
|
||||
subfolder: 数据子目录(exp_filtered 或 exp_converted)
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("数据集检查工具")
|
||||
print("=" * 80)
|
||||
|
||||
# 获取完整路径
|
||||
full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
|
||||
print(f"\n数据集路径: {full_path}")
|
||||
|
||||
# 检查文件结构
|
||||
if not os.path.exists(full_path):
|
||||
print(f"❌ 错误:路径不存在: {full_path}")
|
||||
return
|
||||
|
||||
print(f"✅ 路径存在")
|
||||
|
||||
# 读取数据集映射
|
||||
mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
|
||||
if os.path.exists(mapping_file):
|
||||
print(f"\n读取 dataset_mapping.pkl...")
|
||||
with open(mapping_file, 'rb') as f:
|
||||
dataset_mapping = pickle.load(f)
|
||||
|
||||
print(f"✅ 数据集映射文件存在")
|
||||
print(f" 映射的场景数量: {len(dataset_mapping)}")
|
||||
|
||||
# 统计各个子目录的分布
|
||||
subdirs = {}
|
||||
for filename, subdir in dataset_mapping.items():
|
||||
if subdir not in subdirs:
|
||||
subdirs[subdir] = []
|
||||
subdirs[subdir].append(filename)
|
||||
|
||||
print(f"\n场景分布:")
|
||||
for subdir, files in subdirs.items():
|
||||
print(f" {subdir}: {len(files)} 个场景")
|
||||
|
||||
# 读取数据集摘要
|
||||
summary_file = os.path.join(full_path, "dataset_summary.pkl")
|
||||
if os.path.exists(summary_file):
|
||||
print(f"\n读取 dataset_summary.pkl...")
|
||||
with open(summary_file, 'rb') as f:
|
||||
dataset_summary = pickle.load(f)
|
||||
|
||||
print(f"✅ 数据集摘要文件存在")
|
||||
print(f" 摘要的场景数量: {len(dataset_summary)}")
|
||||
|
||||
# 打印前几个场景的ID
|
||||
print(f"\n前10个场景ID:")
|
||||
for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
|
||||
if isinstance(info, dict):
|
||||
track_length = info.get('track_length', 'N/A')
|
||||
num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
|
||||
print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")
|
||||
|
||||
# 检查实际文件
|
||||
print(f"\n检查实际文件...")
|
||||
pkl_files = list(Path(full_path).rglob("*.pkl"))
|
||||
# 排除dataset_mapping和dataset_summary
|
||||
scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
|
||||
print(f" 实际场景文件数量: {len(scenario_files)}")
|
||||
|
||||
# 检查子目录
|
||||
print(f"\n子目录结构:")
|
||||
for item in os.listdir(full_path):
|
||||
item_path = os.path.join(full_path, item)
|
||||
if os.path.isdir(item_path):
|
||||
pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
|
||||
print(f" {item}/: {pkl_count} 个pkl文件")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("检查完成")
|
||||
print("=" * 80)
|
||||
|
||||
# 返回场景数量
|
||||
return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
|
||||
print("\n【检查 exp_filtered 数据集】")
|
||||
num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
|
||||
|
||||
print("\n\n【检查 exp_converted 数据集】")
|
||||
num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
|
||||
|
||||
print("\n\n总结:")
|
||||
print(f" exp_filtered: {num_filtered} 个场景")
|
||||
print(f" exp_converted: {num_converted} 个场景")
|
||||
|
||||
116
Env/example_with_logging.py
Normal file
116
Env/example_with_logging.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
日志记录功能示例
|
||||
演示如何在自定义脚本中使用日志功能
|
||||
"""
|
||||
from logger_utils import setup_logger
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
def example_without_logging():
|
||||
"""示例1:不使用日志"""
|
||||
print("=" * 60)
|
||||
print("示例1:普通输出(不记录日志)")
|
||||
print("=" * 60)
|
||||
|
||||
print("这是普通的print输出")
|
||||
print("只会显示在终端")
|
||||
print("不会保存到文件")
|
||||
print()
|
||||
|
||||
|
||||
def example_with_logging():
|
||||
"""示例2:使用日志记录"""
|
||||
print("=" * 60)
|
||||
print("示例2:使用日志记录")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用with语句,自动管理日志文件
|
||||
with setup_logger(log_file="example_demo.log", log_dir="logs"):
|
||||
print("✅ 这条消息会同时输出到终端和文件")
|
||||
print("✅ 运行一些计算...")
|
||||
|
||||
for i in range(5):
|
||||
print(f" 步骤 {i+1}/5: 处理中...")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("✅ 计算完成!")
|
||||
|
||||
print("日志文件已关闭")
|
||||
print()
|
||||
|
||||
|
||||
def example_custom_filename():
|
||||
"""示例3:使用时间戳命名"""
|
||||
print("=" * 60)
|
||||
print("示例3:自动生成时间戳文件名")
|
||||
print("=" * 60)
|
||||
|
||||
# log_file=None 会自动生成时间戳文件名
|
||||
with setup_logger(log_file=None, log_dir="logs"):
|
||||
print("文件名会自动包含时间戳")
|
||||
print("适合批量实验,避免覆盖")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_append_mode():
|
||||
"""示例4:追加模式"""
|
||||
print("=" * 60)
|
||||
print("示例4:追加到现有文件")
|
||||
print("=" * 60)
|
||||
|
||||
# 第一次写入
|
||||
with setup_logger(log_file="append_test.log", log_dir="logs", mode='w'):
|
||||
print("第一次写入:这会覆盖文件")
|
||||
|
||||
# 第二次写入(追加)
|
||||
with setup_logger(log_file="append_test.log", log_dir="logs", mode='a'):
|
||||
print("第二次写入:这会追加到文件末尾")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_complex_output():
|
||||
"""示例5:复杂输出(包含颜色、格式)"""
|
||||
print("=" * 60)
|
||||
print("示例5:复杂输出格式")
|
||||
print("=" * 60)
|
||||
|
||||
with setup_logger(log_file="complex_output.log", log_dir="logs"):
|
||||
# 模拟多种输出格式
|
||||
print("\n📊 实验统计:")
|
||||
print(" - 实验名称:车道过滤测试")
|
||||
print(" - 开始时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
print(" - 车辆总数:51")
|
||||
print(" - 过滤后:45")
|
||||
print("\n🚦 红绿灯检测:")
|
||||
print(" ✅ 方法1成功:3辆")
|
||||
print(" ✅ 方法2成功:2辆")
|
||||
print(" ⚠️ 未检测到:40辆")
|
||||
print("\n" + "="*50)
|
||||
print("实验完成!")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""运行所有示例"""
|
||||
print("\n" + "🎯 " + "="*56)
|
||||
print("日志记录功能完整示例")
|
||||
print("="*60 + "\n")
|
||||
|
||||
example_without_logging()
|
||||
example_with_logging()
|
||||
example_custom_filename()
|
||||
example_append_mode()
|
||||
example_complex_output()
|
||||
|
||||
print("="*60)
|
||||
print("✅ 所有示例运行完成!")
|
||||
print("📁 查看日志文件:ls -lh logs/")
|
||||
print("="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
170
Env/logger_utils.py
Normal file
170
Env/logger_utils.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
日志工具模块
|
||||
提供将终端输出同时保存到文件的功能
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TeeLogger:
|
||||
"""
|
||||
双向输出类:同时输出到终端和文件
|
||||
"""
|
||||
def __init__(self, filename, mode='w', terminal=None):
|
||||
"""
|
||||
Args:
|
||||
filename: 日志文件路径
|
||||
mode: 文件打开模式 ('w'=覆盖, 'a'=追加)
|
||||
terminal: 原始输出流(通常是sys.stdout或sys.stderr)
|
||||
"""
|
||||
self.terminal = terminal or sys.stdout
|
||||
self.log_file = open(filename, mode, encoding='utf-8')
|
||||
|
||||
def write(self, message):
|
||||
"""写入消息到终端和文件"""
|
||||
self.terminal.write(message)
|
||||
self.log_file.write(message)
|
||||
self.log_file.flush() # 立即写入磁盘
|
||||
|
||||
def flush(self):
|
||||
"""刷新缓冲区"""
|
||||
self.terminal.flush()
|
||||
self.log_file.flush()
|
||||
|
||||
def close(self):
|
||||
"""关闭日志文件"""
|
||||
if self.log_file:
|
||||
self.log_file.close()
|
||||
|
||||
|
||||
class LoggerContext:
|
||||
"""
|
||||
日志上下文管理器
|
||||
使用with语句自动管理日志的开启和关闭
|
||||
"""
|
||||
def __init__(self, log_file=None, log_dir="logs", mode='w',
|
||||
redirect_stdout=True, redirect_stderr=True):
|
||||
"""
|
||||
Args:
|
||||
log_file: 日志文件名(None则自动生成时间戳文件名)
|
||||
log_dir: 日志目录
|
||||
mode: 文件打开模式 ('w'=覆盖, 'a'=追加)
|
||||
redirect_stdout: 是否重定向标准输出
|
||||
redirect_stderr: 是否重定向标准错误
|
||||
"""
|
||||
self.log_dir = log_dir
|
||||
self.mode = mode
|
||||
self.redirect_stdout = redirect_stdout
|
||||
self.redirect_stderr = redirect_stderr
|
||||
|
||||
# 创建日志目录
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 生成日志文件名
|
||||
if log_file is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
log_file = f"run_{timestamp}.log"
|
||||
|
||||
self.log_path = os.path.join(log_dir, log_file)
|
||||
|
||||
# 保存原始的stdout和stderr
|
||||
self.original_stdout = sys.stdout
|
||||
self.original_stderr = sys.stderr
|
||||
|
||||
# 日志对象
|
||||
self.stdout_logger = None
|
||||
self.stderr_logger = None
|
||||
|
||||
def __enter__(self):
|
||||
"""进入上下文:开启日志"""
|
||||
print(f"📝 日志记录已启用")
|
||||
print(f"📁 日志文件: {self.log_path}")
|
||||
print("-" * 60)
|
||||
|
||||
# 创建TeeLogger对象
|
||||
if self.redirect_stdout:
|
||||
self.stdout_logger = TeeLogger(
|
||||
self.log_path,
|
||||
mode=self.mode,
|
||||
terminal=self.original_stdout
|
||||
)
|
||||
sys.stdout = self.stdout_logger
|
||||
|
||||
if self.redirect_stderr:
|
||||
self.stderr_logger = TeeLogger(
|
||||
self.log_path,
|
||||
mode='a', # stderr总是追加模式
|
||||
terminal=self.original_stderr
|
||||
)
|
||||
sys.stderr = self.stderr_logger
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""退出上下文:关闭日志"""
|
||||
# 恢复原始输出
|
||||
sys.stdout = self.original_stdout
|
||||
sys.stderr = self.original_stderr
|
||||
|
||||
# 关闭日志文件
|
||||
if self.stdout_logger:
|
||||
self.stdout_logger.close()
|
||||
if self.stderr_logger:
|
||||
self.stderr_logger.close()
|
||||
|
||||
print("-" * 60)
|
||||
print(f"✅ 日志已保存到: {self.log_path}")
|
||||
|
||||
# 返回False表示不抑制异常
|
||||
return False
|
||||
|
||||
|
||||
def setup_logger(log_file=None, log_dir="logs", mode='w'):
|
||||
"""
|
||||
快速设置日志记录
|
||||
|
||||
Args:
|
||||
log_file: 日志文件名(None则自动生成)
|
||||
log_dir: 日志目录
|
||||
mode: 文件模式 ('w'=覆盖, 'a'=追加)
|
||||
|
||||
Returns:
|
||||
LoggerContext对象
|
||||
|
||||
Example:
|
||||
with setup_logger("my_test.log"):
|
||||
print("这条消息会同时输出到终端和文件")
|
||||
"""
|
||||
return LoggerContext(log_file=log_file, log_dir=log_dir, mode=mode)
|
||||
|
||||
|
||||
def get_default_log_filename(prefix="run"):
|
||||
"""
|
||||
生成默认的日志文件名(带时间戳)
|
||||
|
||||
Args:
|
||||
prefix: 文件名前缀
|
||||
|
||||
Returns:
|
||||
str: 格式为 "prefix_YYYYMMDD_HHMMSS.log"
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{prefix}_{timestamp}.log"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
print("测试1: 使用默认配置")
|
||||
with setup_logger():
|
||||
print("这是测试消息1")
|
||||
print("这是测试消息2")
|
||||
print("日志记录已结束\n")
|
||||
|
||||
print("测试2: 使用自定义文件名")
|
||||
with setup_logger(log_file="test_custom.log"):
|
||||
print("自定义文件名测试")
|
||||
for i in range(3):
|
||||
print(f" 消息 {i+1}")
|
||||
print("完成")
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
class ReplayPolicy:
|
||||
"""
|
||||
严格回放策略:根据专家轨迹数据,逐帧回放车辆状态
|
||||
"""
|
||||
|
||||
def __init__(self, expert_trajectory, vehicle_id):
|
||||
"""
|
||||
Args:
|
||||
expert_trajectory: 专家轨迹字典,包含 positions, headings, velocities, valid
|
||||
vehicle_id: 车辆ID(用于调试)
|
||||
"""
|
||||
self.trajectory = expert_trajectory
|
||||
self.vehicle_id = vehicle_id
|
||||
self.current_step = 0
|
||||
|
||||
def act(self, observation=None):
|
||||
"""
|
||||
返回动作:在回放模式下返回空动作
|
||||
实际状态由环境直接设置
|
||||
"""
|
||||
return [0.0, 0.0]
|
||||
|
||||
def get_target_state(self, step):
|
||||
"""
|
||||
获取指定时间步的目标状态
|
||||
|
||||
Args:
|
||||
step: 时间步
|
||||
|
||||
Returns:
|
||||
dict: 包含 position, heading, velocity 的字典,如果无效则返回 None
|
||||
"""
|
||||
if step >= len(self.trajectory['valid']):
|
||||
return None
|
||||
|
||||
if not self.trajectory['valid'][step]:
|
||||
return None
|
||||
|
||||
return {
|
||||
'position': self.trajectory['positions'][step],
|
||||
'heading': self.trajectory['headings'][step],
|
||||
'velocity': self.trajectory['velocities'][step]
|
||||
}
|
||||
|
||||
def is_finished(self, step):
|
||||
"""
|
||||
判断轨迹是否已经播放完毕
|
||||
|
||||
Args:
|
||||
step: 当前时间步
|
||||
|
||||
Returns:
|
||||
bool: 如果轨迹已播放完或当前步无效,返回 True
|
||||
"""
|
||||
# 超出轨迹长度
|
||||
if step >= len(self.trajectory['valid']):
|
||||
return True
|
||||
|
||||
# 当前步及之后都无效
|
||||
return not any(self.trajectory['valid'][step:])
|
||||
@@ -1,390 +1,78 @@
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import sys
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
|
||||
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
scenario_id=None, use_scenario_duration=False,
|
||||
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
|
||||
def main(enable_logging=False, log_file=None):
|
||||
"""
|
||||
回放模式:严格按照专家轨迹回放
|
||||
主函数
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
num_episodes: 回合数(如果指定scenario_id,则忽略)
|
||||
horizon: 最大步数(如果use_scenario_duration=True,则自动设置)
|
||||
render: 是否渲染
|
||||
debug: 是否调试模式
|
||||
scenario_id: 指定场景ID(可选)
|
||||
use_scenario_duration: 是否使用场景原始时长
|
||||
spawn_vehicles: 是否生成车辆(默认True)
|
||||
spawn_pedestrians: 是否生成行人(默认True)
|
||||
spawn_cyclists: 是否生成自行车(默认True)
|
||||
enable_logging: 是否启用日志记录到文件
|
||||
log_file: 日志文件名(None则自动生成时间戳文件名)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("运行模式: 专家轨迹回放 (Replay Mode)")
|
||||
if scenario_id is not None:
|
||||
print(f"指定场景ID: {scenario_id}")
|
||||
if use_scenario_duration:
|
||||
print("使用场景原始时长")
|
||||
print("=" * 50)
|
||||
|
||||
# 如果指定了场景ID,只运行1个回合
|
||||
if scenario_id is not None:
|
||||
num_episodes = 1
|
||||
|
||||
# ✅ 环境创建移到循环外面,避免重复创建
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"horizon": horizon,
|
||||
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False, # 回放模式下不需要反应式交通
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": True, # 标记为回放模式
|
||||
"debug": debug,
|
||||
"specific_scenario_id": scenario_id, # 指定场景ID
|
||||
"use_scenario_duration": use_scenario_duration, # 使用场景时长
|
||||
# 对象类型过滤
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
# ✅ 关键:设置可用场景数量
|
||||
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||
},
|
||||
agent2policy=None # 回放模式不需要统一策略
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取可用场景数量
|
||||
num_scenarios = env.config.get("num_scenarios", 1)
|
||||
print(f"可用场景数量: {num_scenarios}")
|
||||
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
else:
|
||||
# 循环使用场景
|
||||
scenario_idx = episode % num_scenarios
|
||||
print(f"使用场景索引: {scenario_idx}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||
if scenario_id is not None:
|
||||
seed = scenario_id
|
||||
else:
|
||||
seed = episode % num_scenarios
|
||||
obs = env.reset(seed=seed)
|
||||
|
||||
# 为每个车辆分配 ReplayPolicy
|
||||
replay_policies = {}
|
||||
for agent_id, vehicle in env.controlled_agents.items():
|
||||
vehicle_id = vehicle.expert_vehicle_id
|
||||
if vehicle_id in env.expert_trajectories:
|
||||
replay_policy = ReplayPolicy(
|
||||
env.expert_trajectories[vehicle_id],
|
||||
vehicle_id
|
||||
)
|
||||
vehicle.set_policy(replay_policy)
|
||||
replay_policies[agent_id] = replay_policy
|
||||
|
||||
# 输出场景信息
|
||||
actual_horizon = env.config["horizon"]
|
||||
print(f"初始化完成:")
|
||||
print(f" 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" 专家轨迹数: {len(env.expert_trajectories)}")
|
||||
print(f" 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" 实际Horizon: {actual_horizon} 步")
|
||||
|
||||
step_count = 0
|
||||
active_vehicles_count = []
|
||||
|
||||
while True:
|
||||
# 在回放模式下,直接使用专家轨迹设置车辆状态
|
||||
for agent_id, vehicle in list(env.controlled_agents.items()):
|
||||
vehicle_id = vehicle.expert_vehicle_id
|
||||
if vehicle_id in env.expert_trajectories and agent_id in replay_policies:
|
||||
target_state = replay_policies[agent_id].get_target_state(env.round)
|
||||
if target_state is not None:
|
||||
# 直接设置车辆状态(绕过物理引擎)
|
||||
# 只使用xy坐标,保持车辆在地面上
|
||||
position_2d = target_state['position'][:2]
|
||||
vehicle.set_position(position_2d)
|
||||
vehicle.set_heading_theta(target_state['heading'])
|
||||
vehicle.set_velocity(target_state['velocity'][:2] if len(target_state['velocity']) > 2 else target_state['velocity'])
|
||||
|
||||
# 使用空动作进行步进
|
||||
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
if render:
|
||||
env.render(mode="topdown")
|
||||
|
||||
step_count += 1
|
||||
active_vehicles_count.append(len(env.controlled_agents))
|
||||
|
||||
# 每50步打印一次状态
|
||||
if step_count % 50 == 0:
|
||||
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
|
||||
|
||||
# 调试模式下打印车辆高度信息
|
||||
if debug and len(env.controlled_agents) > 0:
|
||||
sample_vehicle = list(env.controlled_agents.values())[0]
|
||||
z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
|
||||
print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f"\n回合结束统计:")
|
||||
print(f" 总步数: {step_count}")
|
||||
print(f" 最大同时车辆数: {max(active_vehicles_count) if active_vehicles_count else 0}")
|
||||
print(f" 平均车辆数: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
|
||||
if use_scenario_duration:
|
||||
print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
|
||||
break
|
||||
finally:
|
||||
# ✅ 确保环境被正确关闭
|
||||
env.close()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("回放完成!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
scenario_id=None, use_scenario_duration=False,
|
||||
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
|
||||
"""
|
||||
仿真模式:使用自定义策略控制车辆
|
||||
车辆根据专家数据的初始位姿生成,然后由策略控制
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
num_episodes: 回合数
|
||||
horizon: 最大步数
|
||||
render: 是否渲染
|
||||
debug: 是否调试模式
|
||||
scenario_id: 指定场景ID(可选)
|
||||
use_scenario_duration: 是否使用场景原始时长
|
||||
spawn_vehicles: 是否生成车辆(默认True)
|
||||
spawn_pedestrians: 是否生成行人(默认True)
|
||||
spawn_cyclists: 是否生成自行车(默认True)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("运行模式: 策略仿真 (Simulation Mode)")
|
||||
if scenario_id is not None:
|
||||
print(f"指定场景ID: {scenario_id}")
|
||||
if use_scenario_duration:
|
||||
print("使用场景原始时长")
|
||||
print("=" * 50)
|
||||
|
||||
# 如果指定了场景ID,只运行1个回合
|
||||
if scenario_id is not None:
|
||||
num_episodes = 1
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
# "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": horizon,
|
||||
"use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误
|
||||
"horizon": 300,
|
||||
"use_render": True,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": False, # 仿真模式
|
||||
"debug": debug,
|
||||
"specific_scenario_id": scenario_id, # 指定场景ID
|
||||
"use_scenario_duration": use_scenario_duration, # 使用场景时长
|
||||
# 对象类型过滤
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
# ✅ 关键:设置可用场景数量
|
||||
#"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数
|
||||
"manual_control": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True, # 启用车道区域过滤,过滤草坪等非车道区域的车辆
|
||||
"lane_tolerance": 3.0, # 车道检测容差(米),可根据需要调整
|
||||
"max_controlled_vehicles": 2, # 限制最大车辆数(可选,None表示不限制)
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取可用场景数量
|
||||
num_scenarios = env.config.get("num_scenarios", 1)
|
||||
print(f"可用场景数量: {num_scenarios}")
|
||||
obs = env.reset(0)
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
else:
|
||||
# 循环使用场景
|
||||
scenario_idx = episode % num_scenarios
|
||||
print(f"使用场景索引: {scenario_idx}")
|
||||
print(f"{'='*50}")
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
env.render(mode="topdown")
|
||||
|
||||
# ✅ 如果不是指定场景,使用循环的场景索引
|
||||
if scenario_id is not None:
|
||||
seed = scenario_id
|
||||
else:
|
||||
seed = episode % num_scenarios
|
||||
obs = env.reset(seed=seed)
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
actual_horizon = env.config["horizon"]
|
||||
print(f"初始化完成:")
|
||||
print(f" 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" 实际Horizon: {actual_horizon} 步")
|
||||
|
||||
step_count = 0
|
||||
total_reward = 0.0
|
||||
|
||||
while True:
|
||||
# 使用策略生成动作
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
if render:
|
||||
env.render(mode="topdown")
|
||||
|
||||
step_count += 1
|
||||
total_reward += sum(rewards.values())
|
||||
|
||||
# 每50步打印一次状态
|
||||
if step_count % 50 == 0:
|
||||
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f"\n回合结束统计:")
|
||||
print(f" 总步数: {step_count}")
|
||||
print(f" 总奖励: {total_reward:.2f}")
|
||||
break
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("仿真完成!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["replay", "simulation"],
|
||||
default="simulation",
|
||||
help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default=WAYMO_DATA_DIR,
|
||||
help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="运行回合数 (默认: 1)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--horizon",
|
||||
type=int,
|
||||
default=300,
|
||||
help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_render",
|
||||
action="store_true",
|
||||
help="禁用渲染(加速运行)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="启用调试模式(显示详细日志)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scenario_id",
|
||||
type=int,
|
||||
default=None,
|
||||
help="指定场景ID(可选,如指定则只运行该场景)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_scenario_duration",
|
||||
action="store_true",
|
||||
help="使用场景原始时长作为horizon(自动停止)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_vehicles",
|
||||
action="store_true",
|
||||
help="禁止生成车辆"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_pedestrians",
|
||||
action="store_true",
|
||||
help="禁止生成行人"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_cyclists",
|
||||
action="store_true",
|
||||
help="禁止生成自行车"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "replay":
|
||||
run_replay_mode(
|
||||
data_dir=args.data_dir,
|
||||
num_episodes=args.episodes,
|
||||
horizon=args.horizon,
|
||||
render=not args.no_render,
|
||||
debug=args.debug,
|
||||
scenario_id=args.scenario_id,
|
||||
use_scenario_duration=args.use_scenario_duration,
|
||||
spawn_vehicles=not args.no_vehicles,
|
||||
spawn_pedestrians=not args.no_pedestrians,
|
||||
spawn_cyclists=not args.no_cyclists
|
||||
)
|
||||
else:
|
||||
run_simulation_mode(
|
||||
data_dir=args.data_dir,
|
||||
num_episodes=args.episodes,
|
||||
horizon=args.horizon,
|
||||
render=not args.no_render,
|
||||
debug=args.debug,
|
||||
scenario_id=args.scenario_id,
|
||||
use_scenario_duration=args.use_scenario_duration,
|
||||
spawn_vehicles=not args.no_vehicles,
|
||||
spawn_pedestrians=not args.no_pedestrians,
|
||||
spawn_cyclists=not args.no_cyclists
|
||||
)
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# 解析命令行参数
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 使用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
with setup_logger(log_file=log_file, log_dir=log_dir):
|
||||
main(enable_logging=True, log_file=log_file)
|
||||
else:
|
||||
# 普通运行(只输出到终端)
|
||||
print("💡 提示: 使用 --log 或 -l 参数启用日志记录")
|
||||
print(" 示例: python run_multiagent_env.py --log")
|
||||
print(" 自定义文件名: python run_multiagent_env.py --log --log-file=my_run.log")
|
||||
print("-" * 60)
|
||||
main(enable_logging=False)
|
||||
115
Env/run_multiagent_env_fast.py
Normal file
115
Env/run_multiagent_env_fast.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def main(enable_logging=False):
|
||||
"""极致性能优化版本 - 启用所有优化选项"""
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 关闭所有渲染
|
||||
"use_render": False,
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
"show_fps": False,
|
||||
"debug": False,
|
||||
|
||||
# 物理引擎优化
|
||||
"physics_world_step_size": 0.02,
|
||||
"decision_repeat": 5,
|
||||
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True, # 过滤非车道区域的车辆
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 15, # 限制车辆数以提升性能
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
# 【关键优化】启用激光雷达缓存
|
||||
# 每3帧才重新计算激光雷达,其余帧使用缓存
|
||||
# 可将激光雷达计算量减少到原来的1/3
|
||||
env.lidar_cache_interval = 3
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
# 性能统计
|
||||
start_time = time.time()
|
||||
total_steps = 0
|
||||
|
||||
print("=" * 60)
|
||||
print("极致性能模式")
|
||||
print("激光雷达优化:80→40束 (前向), 10→6束 (侧向+车道线)")
|
||||
print("激光雷达缓存:每3帧计算一次,中间帧使用缓存")
|
||||
print("预期性能提升:3-5倍")
|
||||
print("=" * 60)
|
||||
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
total_steps += 1
|
||||
|
||||
# 每100步输出一次性能统计
|
||||
if step % 100 == 0 and step > 0:
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"Step {step:4d}: FPS = {fps:6.2f}, 车辆数 = {len(env.controlled_agents):3d}, "
|
||||
f"平均步时间 = {1000/fps:.2f}ms")
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
# 最终统计
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print("\n" + "=" * 60)
|
||||
print(f"总计: {total_steps} 步")
|
||||
print(f"耗时: {elapsed:.2f}s")
|
||||
print(f"平均FPS: {fps:.2f}")
|
||||
print(f"单步平均耗时: {1000/fps:.2f}ms")
|
||||
print("=" * 60)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 使用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
with setup_logger(log_file=log_file or "run_fast.log", log_dir=log_dir):
|
||||
main(enable_logging=True)
|
||||
else:
|
||||
# 普通运行(只输出到终端)
|
||||
print("💡 提示: 使用 --log 或 -l 参数启用日志记录")
|
||||
print("-" * 60)
|
||||
main(enable_logging=False)
|
||||
|
||||
156
Env/run_multiagent_env_parallel.py
Normal file
156
Env/run_multiagent_env_parallel.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
多进程并行版本 - 充分利用多核CPU
|
||||
适合大规模数据收集和训练
|
||||
"""
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import time
|
||||
import os
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
|
||||
def run_single_env(args):
|
||||
"""在单个进程中运行一个环境实例"""
|
||||
seed, num_steps, worker_id = args
|
||||
|
||||
# 创建环境(每个进程独立)
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 性能优化
|
||||
"use_render": False,
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
"show_fps": False,
|
||||
"debug": False,
|
||||
|
||||
"physics_world_step_size": 0.02,
|
||||
"decision_repeat": 5,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 15,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
# 启用激光雷达缓存
|
||||
env.lidar_cache_interval = 3
|
||||
|
||||
# 运行仿真
|
||||
start_time = time.time()
|
||||
obs = env.reset(seed)
|
||||
total_steps = 0
|
||||
total_agents = 0
|
||||
|
||||
for step in range(num_steps):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
total_steps += 1
|
||||
total_agents += len(env.controlled_agents)
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed if elapsed > 0 else 0
|
||||
avg_agents = total_agents / total_steps if total_steps > 0 else 0
|
||||
|
||||
env.close()
|
||||
|
||||
return {
|
||||
'worker_id': worker_id,
|
||||
'seed': seed,
|
||||
'steps': total_steps,
|
||||
'elapsed': elapsed,
|
||||
'fps': fps,
|
||||
'avg_agents': avg_agents,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:协调多个并行环境"""
|
||||
# 获取CPU核心数
|
||||
num_cores = cpu_count()
|
||||
# 建议使用物理核心数(12600KF是10核20线程,使用10个进程)
|
||||
num_workers = min(10, num_cores)
|
||||
|
||||
print("=" * 80)
|
||||
print(f"多进程并行模式")
|
||||
print(f"CPU核心数: {num_cores}")
|
||||
print(f"并行进程数: {num_workers}")
|
||||
print(f"每个环境运行: 1000步")
|
||||
print("=" * 80)
|
||||
|
||||
# 准备任务参数
|
||||
num_steps_per_env = 1000
|
||||
tasks = [(seed, num_steps_per_env, worker_id)
|
||||
for worker_id, seed in enumerate(range(num_workers))]
|
||||
|
||||
# 启动多进程池
|
||||
start_time = time.time()
|
||||
|
||||
with Pool(processes=num_workers) as pool:
|
||||
results = pool.map(run_single_env, tasks)
|
||||
|
||||
total_elapsed = time.time() - start_time
|
||||
|
||||
# 统计结果
|
||||
print("\n" + "=" * 80)
|
||||
print("各进程执行结果:")
|
||||
print("-" * 80)
|
||||
print(f"{'Worker':<8} {'Seed':<6} {'Steps':<8} {'Time(s)':<10} {'FPS':<8} {'平均车辆数':<12}")
|
||||
print("-" * 80)
|
||||
|
||||
total_steps = 0
|
||||
total_fps = 0
|
||||
|
||||
for result in results:
|
||||
print(f"{result['worker_id']:<8} "
|
||||
f"{result['seed']:<6} "
|
||||
f"{result['steps']:<8} "
|
||||
f"{result['elapsed']:<10.2f} "
|
||||
f"{result['fps']:<8.2f} "
|
||||
f"{result['avg_agents']:<12.1f}")
|
||||
total_steps += result['steps']
|
||||
total_fps += result['fps']
|
||||
|
||||
print("-" * 80)
|
||||
avg_fps_per_env = total_fps / len(results)
|
||||
total_throughput = total_steps / total_elapsed
|
||||
|
||||
print(f"\n总体统计:")
|
||||
print(f" 总步数: {total_steps}")
|
||||
print(f" 总耗时: {total_elapsed:.2f}s")
|
||||
print(f" 单环境平均FPS: {avg_fps_per_env:.2f}")
|
||||
print(f" 总吞吐量: {total_throughput:.2f} steps/s")
|
||||
print(f" 并行效率: {total_throughput / avg_fps_per_env:.1f}x")
|
||||
print("=" * 80)
|
||||
|
||||
# 与单进程对比
|
||||
print(f"\n性能对比:")
|
||||
print(f" 单进程FPS (预估): ~30 FPS")
|
||||
print(f" 多进程吞吐量: {total_throughput:.2f} steps/s")
|
||||
print(f" 性能提升: {total_throughput / 30:.1f}x")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
60
Env/run_multiagent_env_visual.py
Normal file
60
Env/run_multiagent_env_visual.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import time
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def main():
|
||||
"""带可视化的版本(低FPS,约15帧)"""
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 可视化设置(牺牲性能)
|
||||
"use_render": True,
|
||||
"manual_control": False,
|
||||
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
start_time = time.time()
|
||||
total_steps = 0
|
||||
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
env.render(mode="topdown") # 实时渲染
|
||||
|
||||
total_steps += 1
|
||||
|
||||
if step % 100 == 0 and step > 0:
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"Step {step}: FPS = {fps:.2f}, 车辆数 = {len(env.controlled_agents)}")
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"\n总计: {total_steps} 步,耗时 {elapsed:.2f}s,平均FPS = {fps:.2f}")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -15,7 +15,6 @@ class PolicyVehicle(DefaultVehicle):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.policy = None
|
||||
self.destination = None
|
||||
self.expert_vehicle_id = None # 关联专家车辆ID
|
||||
|
||||
def set_policy(self, policy):
|
||||
self.policy = policy
|
||||
@@ -23,9 +22,6 @@ class PolicyVehicle(DefaultVehicle):
|
||||
def set_destination(self, des):
|
||||
self.destination = des
|
||||
|
||||
def set_expert_vehicle_id(self, vid):
|
||||
self.expert_vehicle_id = vid
|
||||
|
||||
def act(self, observation, policy=None):
|
||||
if self.policy is not None:
|
||||
return self.policy.act(observation)
|
||||
@@ -57,15 +53,13 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
data_directory=None,
|
||||
num_controlled_agents=3,
|
||||
horizon=1000,
|
||||
filter_offroad_vehicles=True, # 车道过滤开关
|
||||
lane_tolerance=3.0, # 车道检测容差(米)
|
||||
replay_mode=False, # 回放模式开关
|
||||
specific_scenario_id=None, # 新增:指定场景ID(仅回放模式)
|
||||
use_scenario_duration=False, # 新增:使用场景原始时长作为horizon
|
||||
# 对象类型过滤选项
|
||||
spawn_vehicles=True, # 是否生成车辆
|
||||
spawn_pedestrians=True, # 是否生成行人
|
||||
spawn_cyclists=True, # 是否生成自行车
|
||||
# 车道检测与过滤配置
|
||||
filter_offroad_vehicles=True, # 是否过滤非车道区域的车辆
|
||||
lane_tolerance=3.0, # 车道检测容差(米),用于放宽边界条件
|
||||
max_controlled_vehicles=None, # 最大可控车辆数限制(None表示不限制)
|
||||
# 调试模式配置
|
||||
debug_traffic_light=False, # 是否启用红绿灯检测调试输出
|
||||
debug_lane_filter=False, # 是否启用车道过滤调试输出
|
||||
))
|
||||
return config
|
||||
|
||||
@@ -75,179 +69,96 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.controlled_agent_ids = []
|
||||
self.obs_list = []
|
||||
self.round = 0
|
||||
self.expert_trajectories = {} # 存储完整专家轨迹
|
||||
self.replay_mode = config.get("replay_mode", False)
|
||||
self.scenario_max_duration = 0 # 场景实际最大时长
|
||||
# 调试模式配置
|
||||
self.debug_traffic_light = config.get("debug_traffic_light", False)
|
||||
self.debug_lane_filter = config.get("debug_lane_filter", False)
|
||||
super().__init__(config)
|
||||
|
||||
def reset(self, seed: Union[None, int] = None):
|
||||
self.round = 0
|
||||
|
||||
if self.logger is None:
|
||||
self.logger = get_logger()
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
|
||||
# ✅ 关键修复:在每次 reset 前清理所有自定义生成的对象
|
||||
if hasattr(self, 'engine') and self.engine is not None:
|
||||
if hasattr(self, 'controlled_agents') and self.controlled_agents:
|
||||
# 先从 agent_manager 中移除
|
||||
if hasattr(self.engine, 'agent_manager'):
|
||||
for agent_id in list(self.controlled_agents.keys()):
|
||||
if agent_id in self.engine.agent_manager.active_agents:
|
||||
self.engine.agent_manager.active_agents.pop(agent_id)
|
||||
|
||||
# 然后清理对象
|
||||
for agent_id, vehicle in list(self.controlled_agents.items()):
|
||||
try:
|
||||
self.engine.clear_objects([vehicle.id])
|
||||
except:
|
||||
pass
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
|
||||
self.lazy_init()
|
||||
self._reset_global_seed(seed)
|
||||
|
||||
if self.engine is None:
|
||||
raise ValueError("Broken MetaDrive instance.")
|
||||
|
||||
# 如果指定了场景ID,修改start_scenario_index
|
||||
if self.config.get("specific_scenario_id") is not None:
|
||||
scenario_id = self.config.get("specific_scenario_id")
|
||||
self.config["start_scenario_index"] = scenario_id
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Using specific scenario ID: {scenario_id}")
|
||||
# 在engine.reset()之前清理对象
|
||||
self.before_reset()
|
||||
|
||||
# ✅ 先初始化引擎和 lanes
|
||||
self.engine.reset()
|
||||
self.reset_sensors()
|
||||
self.engine.taskMgr.step()
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
# 记录专家数据(现在 self.lanes 已经初始化)
|
||||
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
self.expert_trajectories.clear()
|
||||
total_vehicles = 0
|
||||
total_pedestrians = 0
|
||||
total_cyclists = 0
|
||||
filtered_vehicles = 0
|
||||
filtered_by_type = 0
|
||||
self.scenario_max_duration = 0 # 重置场景时长
|
||||
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
else:
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
# id,出现时间,出生点坐标,出生朝向,目的地
|
||||
self.car_birth_info_list.append({
|
||||
'id': track['metadata']['object_id'],
|
||||
'show_time': first_show,
|
||||
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
# 对象类型过滤
|
||||
obj_type = track["type"]
|
||||
|
||||
# 统计对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE:
|
||||
total_vehicles += 1
|
||||
elif obj_type == MetaDriveType.PEDESTRIAN:
|
||||
total_pedestrians += 1
|
||||
elif obj_type == MetaDriveType.CYCLIST:
|
||||
total_cyclists += 1
|
||||
|
||||
# 根据配置过滤对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
|
||||
continue
|
||||
|
||||
# 只处理车辆类型(行人和自行车暂时只做过滤)
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
|
||||
if first_show == -1 or last_show == -1:
|
||||
continue
|
||||
|
||||
# 更新场景最大时长
|
||||
self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
|
||||
|
||||
# 获取车辆初始位置
|
||||
initial_position = (
|
||||
track['state']['position'][first_show, 0],
|
||||
track['state']['position'][first_show, 1]
|
||||
)
|
||||
|
||||
# 车道过滤
|
||||
if self.config.get("filter_offroad_vehicles", True):
|
||||
if not self._is_position_on_lane(initial_position):
|
||||
filtered_vehicles += 1
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(
|
||||
f"Filtering vehicle {track['metadata']['object_id']} - "
|
||||
f"not on lane at position {initial_position}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 存储完整专家轨迹(只使用2D位置,避免高度问题)
|
||||
object_id = track['metadata']['object_id']
|
||||
positions_2d = track['state']['position'].copy()
|
||||
positions_2d[:, 2] = 0 # 将z坐标设为0,让MetaDrive自动处理高度
|
||||
|
||||
self.expert_trajectories[object_id] = {
|
||||
'positions': positions_2d,
|
||||
'headings': track['state']['heading'].copy(),
|
||||
'velocities': track['state']['velocity'].copy(),
|
||||
'valid': track['state']['valid'].copy(),
|
||||
}
|
||||
|
||||
# 保存车辆生成信息
|
||||
self.car_birth_info_list.append({
|
||||
'id': object_id,
|
||||
'show_time': first_show,
|
||||
'begin': initial_position,
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
# 在回放和仿真模式下都清除原始专家车辆
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
|
||||
# 清除专家车辆和过滤的对象
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
|
||||
# 输出统计信息
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"=== 对象统计 ===")
|
||||
self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
|
||||
self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
|
||||
self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
|
||||
self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
|
||||
self.logger.info(f"场景时长: {self.scenario_max_duration} 步")
|
||||
self.engine.reset()
|
||||
self.reset_sensors()
|
||||
self.engine.taskMgr.step()
|
||||
|
||||
# 如果启用场景时长控制,更新horizon
|
||||
if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
|
||||
original_horizon = self.config["horizon"]
|
||||
self.config["horizon"] = self.scenario_max_duration
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
# 调试:场景信息统计
|
||||
if self.debug_lane_filter or self.debug_traffic_light:
|
||||
print(f"\n📍 场景信息统计:")
|
||||
print(f" - 总车道数: {len(self.lanes)}")
|
||||
|
||||
# 统计红绿灯数量
|
||||
if self.debug_traffic_light:
|
||||
traffic_light_lanes = []
|
||||
for lane in self.lanes.values():
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light_lanes.append(lane.lane.index)
|
||||
print(f" - 有红绿灯的车道数: {len(traffic_light_lanes)}")
|
||||
if len(traffic_light_lanes) > 0:
|
||||
print(f" 车道索引: {traffic_light_lanes[:5]}" +
|
||||
(f" ... 共{len(traffic_light_lanes)}个" if len(traffic_light_lanes) > 5 else ""))
|
||||
else:
|
||||
print(f" ⚠️ 场景中没有红绿灯!")
|
||||
|
||||
# 在获取车道信息后,进行车道区域过滤
|
||||
total_cars_before = len(self.car_birth_info_list)
|
||||
valid_count, filtered_count, filtered_list = self._filter_valid_spawn_positions()
|
||||
|
||||
# 输出过滤信息
|
||||
if filtered_count > 0:
|
||||
self.logger.warning(f"车辆生成位置过滤: 原始 {total_cars_before} 辆, "
|
||||
f"有效 {valid_count} 辆, 过滤 {filtered_count} 辆")
|
||||
for filtered_car in filtered_list[:5]: # 只显示前5个
|
||||
self.logger.debug(f" - 过滤车辆 ID={filtered_car['id']}, "
|
||||
f"位置={filtered_car['position']}, "
|
||||
f"原因={filtered_car['reason']}")
|
||||
if filtered_count > 5:
|
||||
self.logger.debug(f" - ... 还有 {filtered_count - 5} 辆车被过滤")
|
||||
|
||||
# 限制最大车辆数(在过滤后应用)
|
||||
max_vehicles = self.config.get("max_controlled_vehicles", None)
|
||||
if max_vehicles is not None and len(self.car_birth_info_list) > max_vehicles:
|
||||
self.car_birth_info_list = self.car_birth_info_list[:max_vehicles]
|
||||
self.logger.info(f"限制最大车辆数为 {max_vehicles} 辆")
|
||||
|
||||
self.logger.info(f"最终生成 {len(self.car_birth_info_list)} 辆可控车辆")
|
||||
|
||||
if self.top_down_renderer is not None:
|
||||
self.top_down_renderer.clear()
|
||||
@@ -256,219 +167,336 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.dones = {}
|
||||
self.episode_rewards = defaultdict(float)
|
||||
self.episode_lengths = defaultdict(int)
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
# 调用父类reset会清理场景
|
||||
super().reset(seed) # 初始化场景
|
||||
|
||||
# 重新生成车辆
|
||||
self._spawn_controlled_agents()
|
||||
|
||||
return self._get_all_obs()
|
||||
|
||||
def _is_position_on_lane(self, position, tolerance=None):
|
||||
"""
|
||||
检测给定位置是否在有效车道范围内
|
||||
|
||||
Args:
|
||||
position: (x, y) 车辆位置坐标
|
||||
tolerance: 容差范围(米),用于放宽检测条件。None时使用配置中的默认值
|
||||
|
||||
Returns:
|
||||
bool: True表示在车道上,False表示在非车道区域(如草坪、停车场等)
|
||||
"""
|
||||
if not hasattr(self, 'lanes') or self.lanes is None:
|
||||
if self.debug_lane_filter:
|
||||
print(f" ⚠️ 车道信息未初始化,默认允许")
|
||||
return True # 如果车道信息未初始化,默认允许生成
|
||||
|
||||
if tolerance is None:
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
# 确保 self.lanes 已初始化
|
||||
if not hasattr(self, 'lanes') or self.lanes is None:
|
||||
if self.config.get("debug", False):
|
||||
self.logger.warning("Lanes not initialized, skipping lane check")
|
||||
return True
|
||||
position_2d = (position[0], position[1])
|
||||
|
||||
position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)
|
||||
if self.debug_lane_filter:
|
||||
print(f" 🔍 检测位置 ({position_2d[0]:.2f}, {position_2d[1]:.2f}), 容差={tolerance}m")
|
||||
|
||||
try:
|
||||
for lane in self.lanes.values():
|
||||
# 方法1:直接检测是否在任一车道上
|
||||
checked_lanes = 0
|
||||
for lane in self.lanes.values():
|
||||
try:
|
||||
checked_lanes += 1
|
||||
if lane.lane.point_on_lane(position_2d):
|
||||
if self.debug_lane_filter:
|
||||
print(f" ✅ 在车道上 (车道{lane.lane.index}, 检查了{checked_lanes}条)")
|
||||
return True
|
||||
except:
|
||||
continue
|
||||
|
||||
lane_start = np.array(lane.lane.start)[:2]
|
||||
lane_end = np.array(lane.lane.end)[:2]
|
||||
lane_vec = lane_end - lane_start
|
||||
lane_length = np.linalg.norm(lane_vec)
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 不在任何车道上 (检查了{checked_lanes}条车道)")
|
||||
|
||||
if lane_length < 1e-6:
|
||||
continue
|
||||
# 方法2:如果严格检测失败,使用容差范围检测(考虑车道边缘)
|
||||
# 注释:此方法已被禁用,如需启用请取消注释
|
||||
# if tolerance > 0:
|
||||
# for lane in self.lanes.values():
|
||||
# try:
|
||||
# # 计算点到车道中心线的距离
|
||||
# lane_obj = lane.lane
|
||||
# # 获取车道长度并检测最近点
|
||||
# s, lateral = lane_obj.local_coordinates(position_2d)
|
||||
|
||||
lane_vec_normalized = lane_vec / lane_length
|
||||
point_vec = position_2d - lane_start
|
||||
projection = np.dot(point_vec, lane_vec_normalized)
|
||||
|
||||
if 0 <= projection <= lane_length:
|
||||
closest_point = lane_start + projection * lane_vec_normalized
|
||||
distance = np.linalg.norm(position_2d - closest_point)
|
||||
if distance <= tolerance:
|
||||
return True
|
||||
except Exception as e:
|
||||
if self.config.get("debug", False):
|
||||
self.logger.warning(f"Lane check error: {e}")
|
||||
return False
|
||||
# # 如果横向距离在容差范围内,认为是有效的
|
||||
# if abs(lateral) <= tolerance and 0 <= s <= lane_obj.length:
|
||||
# return True
|
||||
# except:
|
||||
# continue
|
||||
|
||||
return False
|
||||
|
||||
def _filter_valid_spawn_positions(self):
|
||||
"""
|
||||
过滤掉生成位置不在有效车道上的车辆信息
|
||||
根据配置决定是否执行过滤
|
||||
|
||||
Returns:
|
||||
tuple: (有效车辆数量, 被过滤车辆数量, 被过滤车辆ID列表)
|
||||
"""
|
||||
# 如果配置中禁用了过滤,直接返回
|
||||
if not self.config.get("filter_offroad_vehicles", True):
|
||||
if self.debug_lane_filter:
|
||||
print(f"🚫 车道过滤已禁用")
|
||||
return len(self.car_birth_info_list), 0, []
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n🔍 开始车道过滤: 共 {len(self.car_birth_info_list)} 辆车待检测")
|
||||
|
||||
valid_cars = []
|
||||
filtered_cars = []
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
for idx, car in enumerate(self.car_birth_info_list):
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n车辆 {idx+1}/{len(self.car_birth_info_list)}: ID={car['id']}")
|
||||
|
||||
if self._is_position_on_lane(car['begin'], tolerance=tolerance):
|
||||
valid_cars.append(car)
|
||||
if self.debug_lane_filter:
|
||||
print(f" ✅ 保留")
|
||||
else:
|
||||
filtered_cars.append({
|
||||
'id': car['id'],
|
||||
'position': car['begin'],
|
||||
'reason': '生成位置不在有效车道上(可能在草坪/停车场等区域)'
|
||||
})
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 过滤 (原因: 不在车道上)")
|
||||
|
||||
self.car_birth_info_list = valid_cars
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n📊 过滤结果: 保留 {len(valid_cars)} 辆, 过滤 {len(filtered_cars)} 辆")
|
||||
|
||||
return len(valid_cars), len(filtered_cars), filtered_cars
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
"""
|
||||
生成应该在当前或之前出现的车辆
|
||||
如果round=0且所有车辆的show_time>0,则生成show_time最小的车辆(保证至少有车辆出现)
|
||||
"""
|
||||
vehicles_to_spawn = []
|
||||
|
||||
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
|
||||
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
|
||||
for car in self.car_birth_info_list:
|
||||
if car['show_time'] <= self.round:
|
||||
vehicles_to_spawn.append(car)
|
||||
if car['show_time'] == self.round:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
# 如果当前round没有车辆应该出现,但车辆列表不为空,则生成最早出现的车辆
|
||||
# 这样可以确保在reset时至少有车辆出现
|
||||
# if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0:
|
||||
# if self.config.get("debug", False):
|
||||
# self.logger.debug(
|
||||
# f"No vehicles to spawn at round {self.round}, "
|
||||
# f"spawning earliest vehicle instead"
|
||||
# )
|
||||
# # 找到show_time最小的车辆
|
||||
# earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time'])
|
||||
# vehicles_to_spawn.append(earliest_car)
|
||||
|
||||
for car in vehicles_to_spawn:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
# 避免重复生成
|
||||
if agent_id in self.controlled_agents:
|
||||
continue
|
||||
|
||||
vehicle_config = {}
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config=vehicle_config,
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
|
||||
# 重置车辆状态
|
||||
reset_kwargs = {
|
||||
'position': car['begin'],
|
||||
'heading': car['heading']
|
||||
}
|
||||
|
||||
# 如果启用速度继承,设置初始速度
|
||||
if car.get('velocity') is not None:
|
||||
reset_kwargs['velocity'] = car['velocity']
|
||||
|
||||
vehicle.reset(**reset_kwargs)
|
||||
|
||||
# 设置策略和目的地
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
vehicle.set_expert_vehicle_id(car['id'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# 注册到引擎的 active_agents
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(
|
||||
f"Spawned vehicle {agent_id} at round {self.round} "
|
||||
f"(show_time={car['show_time']}), position {car['begin']}"
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config={},
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
vehicle.reset(position=car['begin'], heading=car['heading'])
|
||||
|
||||
def _get_all_obs(self):
|
||||
self.obs_list = []
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
traffic_light = 0
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# ✅ 关键:注册到引擎的 active_agents,才能参与物理更新
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
def before_reset(self):
|
||||
"""在reset之前清理对象"""
|
||||
# 清理所有可控车辆
|
||||
if hasattr(self, 'controlled_agents') and hasattr(self, 'engine'):
|
||||
# 使用MetaDrive的clear_objects方法清理
|
||||
if hasattr(self.engine, 'clear_objects'):
|
||||
try:
|
||||
self.engine.clear_objects(list(self.controlled_agents.keys()))
|
||||
except:
|
||||
pass
|
||||
|
||||
# 从agent_manager中移除
|
||||
if hasattr(self.engine, 'agent_manager'):
|
||||
for agent_id in list(self.controlled_agents.keys()):
|
||||
if agent_id in self.engine.agent_manager.active_agents:
|
||||
self.engine.agent_manager.active_agents.pop(agent_id)
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
def _get_traffic_light_state(self, vehicle):
|
||||
"""
|
||||
获取车辆当前位置的红绿灯状态(优化版)
|
||||
|
||||
解决问题:
|
||||
1. 部分红绿灯状态为None的问题 - 添加异常处理和默认值
|
||||
2. 车道分段导致无法获取红绿灯的问题 - 优先使用导航模块,失败时回退到遍历
|
||||
|
||||
Returns:
|
||||
int: 0=无红绿灯, 1=绿灯, 2=黄灯, 3=红灯
|
||||
"""
|
||||
traffic_light = 0
|
||||
state = vehicle.get_state()
|
||||
position_2d = state['position'][:2]
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f"\n🚦 检测车辆红绿灯 - 位置: ({position_2d[0]:.1f}, {position_2d[1]:.1f})")
|
||||
|
||||
try:
|
||||
# 方法1:优先尝试从车辆导航模块获取当前车道(更高效)
|
||||
if hasattr(vehicle, 'navigation') and vehicle.navigation is not None:
|
||||
current_lane = vehicle.navigation.current_lane
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法1-导航模块:")
|
||||
print(f" current_lane = {current_lane}")
|
||||
print(f" lane_index = {current_lane.index if current_lane else 'None'}")
|
||||
|
||||
if current_lane:
|
||||
has_light = self.engine.light_manager.has_traffic_light(current_lane.index)
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[current_lane.index].status
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法1: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 导航模块current_lane为None")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
has_nav = hasattr(vehicle, 'navigation')
|
||||
nav_not_none = vehicle.navigation is not None if has_nav else False
|
||||
print(f" 方法1-导航模块: 不可用 (hasattr={has_nav}, not_none={nav_not_none})")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法1异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
# 方法2:遍历所有车道查找(兜底方案,处理车道分段问题)
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法2-遍历车道: 开始遍历 {len(self.lanes)} 条车道")
|
||||
|
||||
found_lane = False
|
||||
checked_lanes = 0
|
||||
|
||||
for lane in self.lanes.values():
|
||||
if lane.lane.point_on_lane(state['position'][:2]):
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if traffic_light == 'TRAFFIC_LIGHT_GREEN':
|
||||
traffic_light = 1
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
|
||||
traffic_light = 2
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_RED':
|
||||
traffic_light = 3
|
||||
try:
|
||||
checked_lanes += 1
|
||||
if lane.lane.point_on_lane(position_2d):
|
||||
found_lane = True
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✓ 找到车辆所在车道: {lane.lane.index} (检查了{checked_lanes}条)")
|
||||
|
||||
has_light = self.engine.light_manager.has_traffic_light(lane.lane.index)
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法2: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
traffic_light = 0
|
||||
break
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
# 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云
|
||||
lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=80,
|
||||
distance=30,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world
|
||||
)
|
||||
nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info(
|
||||
vehicle,
|
||||
detected_objects,
|
||||
perceive_distance=30,
|
||||
num_others=10,
|
||||
add_others_navi=False
|
||||
)
|
||||
if self.debug_traffic_light and not found_lane:
|
||||
print(f" ⚠️ 未找到车辆所在车道 (检查了{checked_lanes}条)")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法2异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 结果: 返回 {traffic_light} (无红绿灯/未知)")
|
||||
|
||||
return traffic_light
|
||||
|
||||
def _get_all_obs(self):
|
||||
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
|
||||
self.obs_list = []
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
|
||||
# 使用优化后的红绿灯检测方法
|
||||
traffic_light = self._get_traffic_light_state(vehicle)
|
||||
|
||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
|
||||
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
||||
+ nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
|
||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ list(vehicle.destination))
|
||||
|
||||
self.obs_list.append(obs)
|
||||
|
||||
return self.obs_list
|
||||
|
||||
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
|
||||
self.round += 1
|
||||
|
||||
# 应用动作
|
||||
for agent_id, action in action_dict.items():
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].before_step(action)
|
||||
|
||||
# 物理引擎步进
|
||||
self.engine.step()
|
||||
|
||||
# 后处理
|
||||
for agent_id in action_dict:
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].after_step()
|
||||
|
||||
# 生成新车辆
|
||||
self._spawn_controlled_agents()
|
||||
|
||||
# 获取观测
|
||||
obs = self._get_all_obs()
|
||||
|
||||
rewards = {aid: 0.0 for aid in self.controlled_agents}
|
||||
dones = {aid: False for aid in self.controlled_agents}
|
||||
|
||||
# ✅ 修复:添加回放模式的完成检查
|
||||
replay_finished = False
|
||||
if self.replay_mode and self.config.get("use_scenario_duration", False):
|
||||
# 检查是否所有专家轨迹都已播放完毕
|
||||
if self.round >= self.scenario_max_duration:
|
||||
replay_finished = True
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")
|
||||
|
||||
dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished
|
||||
|
||||
dones["__all__"] = self.episode_step >= self.config["horizon"]
|
||||
infos = {aid: {} for aid in self.controlled_agents}
|
||||
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
def close(self):
|
||||
# ✅ 清理所有生成的车辆
|
||||
if hasattr(self, 'controlled_agents') and self.controlled_agents:
|
||||
for agent_id, vehicle in list(self.controlled_agents.items()):
|
||||
if vehicle in self.engine.get_objects():
|
||||
self.engine.clear_objects([vehicle.id])
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
super().close()
|
||||
@@ -6,13 +6,8 @@ class ConstantVelocityPolicy:
|
||||
|
||||
def act(self):
|
||||
self.step_num += 1
|
||||
if self.step_num % 30 < 15:
|
||||
throttle = 1.0
|
||||
else:
|
||||
throttle = 1.0
|
||||
# 简单的前进策略:直行 + 较大油门
|
||||
steering = 0.0 # 直行
|
||||
throttle = 0.5 # 中等油门,让车辆有明显运动
|
||||
|
||||
steering = 0.1
|
||||
|
||||
# return [steering, throttle]
|
||||
|
||||
return [0.0,0.05]
|
||||
return [steering, throttle]
|
||||
|
||||
219
Env/test_lane_filter.py
Normal file
219
Env/test_lane_filter.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
测试车道过滤和红绿灯检测功能
|
||||
"""
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def test_lane_filter():
|
||||
"""测试车道过滤功能(基础版)"""
|
||||
print("=" * 60)
|
||||
print("测试1:车道过滤功能(基础)")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建启用过滤的环境
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 车道过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 10,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n启用车道过滤...")
|
||||
obs = env.reset(0)
|
||||
print(f"生成车辆数: {len(env.controlled_agents)}")
|
||||
print(f"观测数据长度: {len(obs)}")
|
||||
|
||||
# 运行几步
|
||||
for step in range(5):
|
||||
actions = {aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
env.close()
|
||||
print("✓ 车道过滤测试通过\n")
|
||||
|
||||
|
||||
def test_lane_filter_debug():
|
||||
"""测试车道过滤功能(详细调试)"""
|
||||
print("=" * 60)
|
||||
print("测试1b:车道过滤功能(详细调试模式)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 车道过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 5, # 只看前5辆车
|
||||
|
||||
# 🔥 启用调试模式
|
||||
"debug_lane_filter": True, # 启用车道过滤调试
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n启用车道过滤调试...")
|
||||
obs = env.reset(0)
|
||||
|
||||
env.close()
|
||||
print("\n✓ 车道过滤调试测试完成\n")
|
||||
|
||||
|
||||
def test_traffic_light():
|
||||
"""测试红绿灯检测功能"""
|
||||
print("=" * 60)
|
||||
print("测试2:红绿灯检测功能(启用详细调试)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
"filter_offroad_vehicles": True,
|
||||
"max_controlled_vehicles": 3, # 只测试3辆车
|
||||
|
||||
# 🔥 启用调试模式
|
||||
"debug_traffic_light": True, # 启用红绿灯调试
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
# 测试红绿灯检测(调试模式会自动输出详细信息)
|
||||
print(f"\n" + "="*60)
|
||||
print(f"开始逐车检测红绿灯状态(共 {len(env.controlled_agents)} 辆车)")
|
||||
print("="*60)
|
||||
|
||||
for idx, (aid, vehicle) in enumerate(list(env.controlled_agents.items())[:3]): # 只测试前3辆
|
||||
print(f"\n【车辆 {idx+1}/3】 ID={aid}")
|
||||
traffic_light = env._get_traffic_light_state(vehicle)
|
||||
state = vehicle.get_state()
|
||||
|
||||
status_text = {0: '无/未知', 1: '绿灯', 2: '黄灯', 3: '红灯'}[traffic_light]
|
||||
print(f"最终结果: 红绿灯状态={traffic_light} ({status_text})\n")
|
||||
|
||||
env.close()
|
||||
print("="*60)
|
||||
print("✓ 红绿灯检测测试完成")
|
||||
print("="*60 + "\n")
|
||||
|
||||
|
||||
def test_without_filter():
|
||||
"""测试禁用过滤的情况"""
|
||||
print("=" * 60)
|
||||
print("测试3:禁用过滤(对比测试)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 禁用过滤
|
||||
"filter_offroad_vehicles": False,
|
||||
"max_controlled_vehicles": None,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n禁用车道过滤...")
|
||||
obs = env.reset(0)
|
||||
print(f"生成车辆数(未过滤): {len(env.controlled_agents)}")
|
||||
|
||||
env.close()
|
||||
print("✓ 禁用过滤测试通过\n")
|
||||
|
||||
|
||||
def run_tests(debug_mode=False):
|
||||
"""运行测试的主函数"""
|
||||
try:
|
||||
if debug_mode:
|
||||
print("🐛 调试模式启用")
|
||||
print("=" * 60 + "\n")
|
||||
test_lane_filter_debug()
|
||||
test_traffic_light()
|
||||
else:
|
||||
print("⚡ 标准测试模式(使用 --debug 参数启用详细调试)")
|
||||
print("=" * 60 + "\n")
|
||||
test_lane_filter()
|
||||
test_traffic_light()
|
||||
test_without_filter()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ 所有测试通过!")
|
||||
print("=" * 60)
|
||||
print("\n功能说明:")
|
||||
print("1. 车道过滤功能已启用,自动过滤非车道区域车辆")
|
||||
print("2. 红绿灯检测采用双重策略,确保稳定获取状态")
|
||||
print("3. 可通过配置参数灵活启用/禁用功能")
|
||||
print("\n使用方法:")
|
||||
print(" python Env/test_lane_filter.py # 标准测试")
|
||||
print(" python Env/test_lane_filter.py --debug # 详细调试")
|
||||
print(" python Env/test_lane_filter.py --log # 保存日志")
|
||||
print(" python Env/test_lane_filter.py --debug --log # 调试+日志")
|
||||
print("\n请运行 run_multiagent_env.py 查看完整效果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
# 解析命令行参数
|
||||
debug_mode = "--debug" in sys.argv or "-d" in sys.argv
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 启用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
|
||||
# 生成默认日志文件名
|
||||
if log_file is None:
|
||||
mode_suffix = "debug" if debug_mode else "standard"
|
||||
from datetime import datetime
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
log_file = f"test_{mode_suffix}_{timestamp}.log"
|
||||
|
||||
with setup_logger(log_file=log_file, log_dir=log_dir):
|
||||
run_tests(debug_mode=debug_mode)
|
||||
else:
|
||||
# 不启用日志,直接运行
|
||||
run_tests(debug_mode=debug_mode)
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简单验证脚本:检查车辆是否正确获取观测空间
|
||||
用法:python verify_observations.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import numpy as np
|
||||
|
||||
def verify_observations(data_dir, scenario_id=0):
|
||||
"""
|
||||
验证观测空间是否正确获取
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
scenario_id: 场景ID
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("观测空间验证工具")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建环境
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"horizon": 300,
|
||||
"use_render": False, # 不渲染,加速运行
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False,
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": True,
|
||||
"debug": True, # 启用调试以查看详细信息
|
||||
"specific_scenario_id": scenario_id,
|
||||
"use_scenario_duration": True,
|
||||
},
|
||||
agent2policy=None
|
||||
)
|
||||
|
||||
# 重置环境
|
||||
print(f"\n加载场景 {scenario_id}...")
|
||||
obs = env.reset(seed=scenario_id)
|
||||
|
||||
# 输出基本信息
|
||||
print(f"\n场景信息:")
|
||||
print(f" - 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" - 观测数量: {len(obs)}")
|
||||
print(f" - 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}")
|
||||
print(f" - 当前回合数 (round): {env.round}")
|
||||
|
||||
# 检查车辆生成信息
|
||||
if len(env.car_birth_info_list) > 0:
|
||||
print(f"\n车辆生成信息分析:")
|
||||
show_times = [car['show_time'] for car in env.car_birth_info_list]
|
||||
print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}")
|
||||
print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}")
|
||||
print(f" - 前5个车辆的 show_time: {show_times[:5]}")
|
||||
else:
|
||||
print(f"\n⚠️ 警告: 车辆生成列表为空!可能原因:")
|
||||
print(f" 1. 所有车辆都被车道过滤移除")
|
||||
print(f" 2. 所有车辆都被类型过滤移除")
|
||||
print(f" 3. 场景数据中没有有效车辆")
|
||||
|
||||
# 验证观测空间
|
||||
print(f"\n" + "=" * 60)
|
||||
print("观测空间验证")
|
||||
print("=" * 60)
|
||||
|
||||
if len(obs) == 0:
|
||||
print("❌ 错误:没有获取到任何观测!")
|
||||
env.close()
|
||||
return False
|
||||
|
||||
# 检查第一个观测
|
||||
first_obs = obs[0]
|
||||
print(f"\n第一个车辆的观测:")
|
||||
print(f" - 观测类型: {type(first_obs)}")
|
||||
print(f" - 观测维度: {len(first_obs)}")
|
||||
|
||||
# 详细解析观测
|
||||
if isinstance(first_obs, (list, np.ndarray)):
|
||||
obs_array = np.array(first_obs)
|
||||
print(f" - 观测形状: {obs_array.shape}")
|
||||
print(f" - 数据类型: {obs_array.dtype}")
|
||||
print(f"\n观测内容分解:")
|
||||
# 新观测划分(与 Env/scenario_env._get_all_obs 对齐):
|
||||
# [x, y] (2) + [vx, vy] (2) + heading (1)
|
||||
# + nearest vehicles info (10 vehicles * 4 = 40)
|
||||
# + side_lidar (10) + lane_line_lidar (10)
|
||||
# + traffic_light (1) + destination (2)
|
||||
lidar_len = 40
|
||||
side_len = 10
|
||||
lane_len = 10
|
||||
pos_slice = slice(0, 2)
|
||||
vel_slice = slice(2, 4)
|
||||
heading_idx = 4
|
||||
lidar_slice = slice(5, 5 + lidar_len)
|
||||
side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
|
||||
lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
|
||||
tl_idx = lane_slice.stop
|
||||
dest_slice = slice(tl_idx + 1, tl_idx + 3)
|
||||
|
||||
print(f" 位置 (x, y): {obs_array[pos_slice]}")
|
||||
print(f" 速度 (vx, vy): {obs_array[vel_slice]}")
|
||||
print(f" 航向角: {obs_array[heading_idx]:.3f} 弧度")
|
||||
print(f" 最近车辆信息: {len(obs_array[lidar_slice])} 维 (10辆*4: 相对x/相对y/相对vx/相对vy)")
|
||||
print(f" 侧向检测: {len(obs_array[side_slice])} 个点")
|
||||
print(f" 车道线检测: {len(obs_array[lane_slice])} 个点")
|
||||
print(f" 交通灯状态: {obs_array[tl_idx]}")
|
||||
print(f" 目的地 (x, y): {obs_array[dest_slice]}")
|
||||
|
||||
# 检查数据有效性
|
||||
print(f"\n数据有效性检查:")
|
||||
has_nan = np.isnan(obs_array).any()
|
||||
has_inf = np.isinf(obs_array).any()
|
||||
|
||||
if has_nan:
|
||||
print(f" ❌ 观测包含 NaN 值!")
|
||||
else:
|
||||
print(f" ✅ 无 NaN 值")
|
||||
|
||||
if has_inf:
|
||||
print(f" ❌ 观测包含 Inf 值!")
|
||||
else:
|
||||
print(f" ✅ 无 Inf 值")
|
||||
|
||||
# 检查最近车辆信息数据(40维)
|
||||
lidar_data = obs_array[lidar_slice]
|
||||
lidar_min = np.min(lidar_data)
|
||||
lidar_max = np.max(lidar_data)
|
||||
print(f"\n 最近车辆特征范围: [{lidar_min:.2f}, {lidar_max:.2f}]")
|
||||
|
||||
if lidar_max > 0:
|
||||
print(f" ✅ 存在非零最近车辆特征")
|
||||
else:
|
||||
print(f" ⚠️ 最近车辆特征全零(可能无邻车或距离过远)")
|
||||
|
||||
# 运行几步,验证观测持续有效
|
||||
print(f"\n" + "=" * 60)
|
||||
print("多步运行验证(前 5 步)")
|
||||
print("=" * 60)
|
||||
|
||||
for step in range(5):
|
||||
# 空动作
|
||||
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
print(f"\nStep {step + 1}:")
|
||||
print(f" - 活跃车辆数: {len(env.controlled_agents)}")
|
||||
print(f" - 观测数量: {len(obs)}")
|
||||
|
||||
if len(obs) > 0:
|
||||
sample_obs = np.array(obs[0])
|
||||
print(f" - 第一辆车位置: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
|
||||
print(f" - 数据有效: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f" - 场景结束")
|
||||
break
|
||||
|
||||
env.close()
|
||||
|
||||
print(f"\n" + "=" * 60)
|
||||
print("✅ 验证完成!观测空间正常工作")
|
||||
print("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="验证观测空间")
|
||||
parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="数据目录")
|
||||
parser.add_argument("--scenario_id", type=int, default=0, help="场景ID")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
verify_observations(args.data_dir, args.scenario_id)
|
||||
543
MAGAIL算法应用指南.md
Normal file
543
MAGAIL算法应用指南.md
Normal file
@@ -0,0 +1,543 @@
|
||||
# MAGAIL算法应用指南
|
||||
|
||||
## 目录
|
||||
1. [Algorithm模块概览](#algorithm模块概览)
|
||||
2. [如何应用到环境](#如何应用到环境)
|
||||
3. [完整训练流程](#完整训练流程)
|
||||
4. [当前实现状态](#当前实现状态)
|
||||
5. [需要完善的部分](#需要完善的部分)
|
||||
|
||||
---
|
||||
|
||||
## Algorithm模块概览
|
||||
|
||||
### 📁 模块文件说明
|
||||
|
||||
```
|
||||
Algorithm/
|
||||
├── bert.py # BERT判别器/价值网络
|
||||
├── disc.py # GAIL判别器(继承BERT)
|
||||
├── policy.py # 策略网络(Actor)
|
||||
├── ppo.py # PPO算法基类
|
||||
├── magail.py # MAGAIL主算法(继承PPO)
|
||||
├── buffer.py # 经验回放缓冲区
|
||||
└── utils.py # 工具函数(标准化等)
|
||||
```
|
||||
|
||||
### 🔗 模块依赖关系
|
||||
|
||||
```
|
||||
MAGAIL (magail.py)
|
||||
├─ 继承 PPO (ppo.py)
|
||||
│ ├─ 使用 RolloutBuffer (buffer.py)
|
||||
│ ├─ 使用 StateIndependentPolicy (policy.py)
|
||||
│ └─ 使用 Bert作为Critic (bert.py)
|
||||
│
|
||||
├─ 使用 GAILDiscrim (disc.py)
|
||||
│ └─ 继承 Bert (bert.py)
|
||||
│
|
||||
└─ 使用 Normalizer (utils.py)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 如何应用到环境
|
||||
|
||||
### ✅ 已完成的准备工作
|
||||
|
||||
我已经为您:
|
||||
|
||||
1. **修复了PPO代码bug**:添加了缺失的`action_shape`参数
|
||||
2. **创建了训练脚本**:`train_magail.py`
|
||||
3. **提供了完整框架**:包含环境初始化、训练循环、模型保存等
|
||||
|
||||
### 🚀 快速开始
|
||||
|
||||
#### 方法1:使用训练脚本(推荐)
|
||||
|
||||
```bash
|
||||
# 基本训练(使用默认参数)
|
||||
python train_magail.py
|
||||
|
||||
# 自定义参数
|
||||
python train_magail.py \
|
||||
--data-dir /path/to/waymo/data \
|
||||
--episodes 1000 \
|
||||
--horizon 300 \
|
||||
--batch-size 256 \
|
||||
--lr-actor 3e-4 \
|
||||
--render # 可视化
|
||||
|
||||
# 查看所有参数
|
||||
python train_magail.py --help
|
||||
```
|
||||
|
||||
#### 方法2:在Jupyter Notebook中使用
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.append('Algorithm')
|
||||
sys.path.append('Env')
|
||||
|
||||
from Algorithm.magail import MAGAIL
|
||||
from Env.scenario_env import MultiAgentScenarioEnv
|
||||
|
||||
# 初始化环境
|
||||
env = MultiAgentScenarioEnv(config={...})
|
||||
|
||||
# 初始化MAGAIL
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer,
|
||||
input_dim=(108,), # 观测维度
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 训练循环
|
||||
for episode in range(1000):
|
||||
obs = env.reset()
|
||||
for step in range(300):
|
||||
actions, log_pis = magail.explore(obs)
|
||||
next_obs, rewards, dones, infos = env.step(actions)
|
||||
# ... 更新逻辑
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 完整训练流程
|
||||
|
||||
### 📊 数据流程图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ MAGAIL训练流程 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
|
||||
第1步: 初始化
|
||||
├─ 加载Waymo专家数据 → ExpertBuffer
|
||||
├─ 创建MAGAIL算法实例
|
||||
│ ├─ Actor (policy.py)
|
||||
│ ├─ Critic (bert.py)
|
||||
│ ├─ Discriminator (disc.py)
|
||||
│ └─ Buffers (buffer.py)
|
||||
└─ 创建多智能体环境
|
||||
|
||||
第2步: 训练循环
|
||||
for episode in range(episodes):
|
||||
├─ env.reset() → 重置环境,生成车辆
|
||||
│
|
||||
for step in range(horizon):
|
||||
├─ obs = env._get_all_obs() # 收集观测
|
||||
│
|
||||
├─ actions = magail.explore(obs) # 策略采样
|
||||
│
|
||||
├─ next_obs, rewards, dones = env.step(actions)
|
||||
│
|
||||
├─ buffer.append(obs, actions, rewards, ...) # 存储经验
|
||||
│
|
||||
└─ if step % rollout_length == 0:
|
||||
├─ 更新判别器
|
||||
│ ├─ 采样策略数据: buffer.sample()
|
||||
│ ├─ 采样专家数据: expert_buffer.sample()
|
||||
│ └─ update_disc(policy_data, expert_data)
|
||||
│
|
||||
├─ 计算GAIL奖励
|
||||
│ └─ reward = -log(1 - D(s, s'))
|
||||
│
|
||||
└─ 更新PPO
|
||||
├─ 计算GAE优势
|
||||
├─ update_actor()
|
||||
└─ update_critic()
|
||||
|
||||
第3步: 评估与保存
|
||||
└─ 保存模型、记录指标
|
||||
```
|
||||
|
||||
### 🔑 关键代码段
|
||||
|
||||
#### 1. 初始化MAGAIL
|
||||
|
||||
```python
|
||||
from Algorithm.magail import MAGAIL
|
||||
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer, # 专家数据缓冲区
|
||||
input_dim=(obs_dim,), # 观测维度 (108,)
|
||||
device=device, # cuda/cpu
|
||||
action_shape=(2,), # 动作维度 [转向, 油门]
|
||||
|
||||
# 判别器参数
|
||||
disc_coef=20.0, # 判别器损失系数
|
||||
disc_grad_penalty=0.1, # 梯度惩罚系数
|
||||
disc_logit_reg=0.25, # Logit正则化
|
||||
disc_weight_decay=0.0005, # 权重衰减
|
||||
lr_disc=3e-4, # 判别器学习率
|
||||
epoch_disc=5, # 判别器更新轮数
|
||||
|
||||
# PPO参数
|
||||
rollout_length=2048, # 更新间隔
|
||||
lr_actor=3e-4, # Actor学习率
|
||||
lr_critic=3e-4, # Critic学习率
|
||||
epoch_ppo=10, # PPO更新轮数
|
||||
batch_size=256, # 批次大小
|
||||
gamma=0.995, # 折扣因子
|
||||
lambd=0.97, # GAE lambda
|
||||
|
||||
# 其他
|
||||
use_gail_norm=True, # 使用数据标准化
|
||||
)
|
||||
```
|
||||
|
||||
#### 2. 环境交互
|
||||
|
||||
```python
|
||||
# 重置环境
|
||||
obs_list = env.reset(episode)
|
||||
|
||||
# 收集观测(所有车辆)
|
||||
obs_array = np.array(env.obs_list) # shape: (n_agents, 108)
|
||||
|
||||
# 策略采样
|
||||
actions, log_pis = magail.explore(obs_array)
|
||||
# actions: list of [转向, 油门] for each agent
|
||||
# log_pis: list of log probabilities
|
||||
|
||||
# 构建动作字典
|
||||
action_dict = {
|
||||
agent_id: actions[i]
|
||||
for i, agent_id in enumerate(env.controlled_agents.keys())
|
||||
}
|
||||
|
||||
# 环境步进
|
||||
next_obs, rewards, dones, infos = env.step(action_dict)
|
||||
```
|
||||
|
||||
#### 3. 模型更新
|
||||
|
||||
```python
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
writer = SummaryWriter('logs')
|
||||
|
||||
# 更新判别器和策略
|
||||
if total_steps % rollout_length == 0:
|
||||
# MAGAIL会自动:
|
||||
# 1. 从buffer采样策略数据
|
||||
# 2. 从expert_buffer采样专家数据
|
||||
# 3. 更新判别器
|
||||
# 4. 计算GAIL奖励
|
||||
# 5. 更新PPO(Actor + Critic)
|
||||
|
||||
reward = magail.update(writer, total_steps)
|
||||
|
||||
print(f"Step {total_steps}, Reward: {reward:.4f}")
|
||||
```
|
||||
|
||||
#### 4. 保存和加载模型
|
||||
|
||||
```python
|
||||
# 保存
|
||||
magail.save_models('outputs/models/checkpoint_100')
|
||||
|
||||
# 加载
|
||||
magail.load_models('outputs/models/checkpoint_100/model.pth')
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 当前实现状态
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
| 模块 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| BERT判别器 | ✅ 完整 | 支持动态车辆数量 |
|
||||
| GAIL判别器 | ✅ 完整 | 包含梯度惩罚、正则化 |
|
||||
| 策略网络 | ✅ 完整 | 高斯策略,重参数化 |
|
||||
| PPO算法 | ✅ 完整 | GAE、裁剪目标、自适应LR |
|
||||
| MAGAIL | ✅ 完整 | 判别器+PPO整合 |
|
||||
| 缓冲区 | ✅ 完整 | 经验存储和采样 |
|
||||
| 数据标准化 | ✅ 完整 | 运行时统计量 |
|
||||
| 环境接口 | ✅ 完整 | 多智能体场景环境 |
|
||||
|
||||
### ⚠️ 需要注意的问题
|
||||
|
||||
#### 1. 多智能体适配问题
|
||||
|
||||
**当前状态:** Algorithm模块设计为单智能体,但环境是多智能体
|
||||
|
||||
**影响:**
|
||||
- `buffer.append()` 接受单个状态-动作对
|
||||
- 但环境返回多个智能体的数据
|
||||
|
||||
**解决方案A:** 将所有智能体视为一个整体
|
||||
```python
|
||||
# 拼接所有智能体的观测
|
||||
all_obs = np.concatenate([obs for obs in obs_list])
|
||||
all_actions = np.concatenate([actions for actions in action_list])
|
||||
```
|
||||
|
||||
**解决方案B:** 为每个智能体独立存储
|
||||
```python
|
||||
for i, agent_id in enumerate(env.controlled_agents):
|
||||
buffer.append(obs_list[i], actions[i], rewards[i], ...)
|
||||
```
|
||||
|
||||
**推荐:** 解决方案B,因为MAGAIL的设计就是处理多智能体的
|
||||
|
||||
#### 2. 专家数据加载
|
||||
|
||||
**当前状态:** `ExpertBuffer` 类只有框架,未实现实际加载
|
||||
|
||||
**需要完善:**
|
||||
```python
|
||||
def _extract_trajectories(self, scenario_data):
|
||||
"""
|
||||
需要根据Waymo数据格式实现
|
||||
|
||||
示例结构:
|
||||
scenario_data = {
|
||||
'tracks': {
|
||||
'vehicle_id': {
|
||||
'states': [...], # 状态序列
|
||||
'actions': [...], # 动作序列(如果有)
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
# TODO: 提取state和next_state对
|
||||
for track_id, track_data in scenario_data['tracks'].items():
|
||||
states = track_data['states']
|
||||
for i in range(len(states) - 1):
|
||||
self.states.append(states[i])
|
||||
self.next_states.append(states[i+1])
|
||||
```
|
||||
|
||||
#### 3. 观测维度对齐
|
||||
|
||||
**当前假设:** 观测维度为108
|
||||
- 位置(2) + 速度(2) + 朝向(1) + 激光雷达(80) + 侧向(10) + 车道线(10) + 红绿灯(1) + 目标点(2) = 108
|
||||
|
||||
**需要验证:** 实际运行时打印观测shape
|
||||
```python
|
||||
obs = env.reset()
|
||||
print(f"观测维度: {len(obs[0]) if obs else 0}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 需要完善的部分
|
||||
|
||||
### 🔨 短期TODO
|
||||
|
||||
#### 1. 修复多智能体buffer问题
|
||||
|
||||
**创建文件:** `Algorithm/multi_agent_buffer.py`
|
||||
|
||||
```python
|
||||
class MultiAgentRolloutBuffer:
|
||||
"""
|
||||
多智能体经验缓冲区
|
||||
|
||||
支持动态数量的智能体
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size, state_shape, action_shape, device):
|
||||
self.buffer_size = buffer_size
|
||||
self.state_shape = state_shape
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
|
||||
# 使用列表存储,支持动态智能体数量
|
||||
self.episodes = []
|
||||
self.current_episode = {
|
||||
'states': [],
|
||||
'actions': [],
|
||||
'rewards': [],
|
||||
'dones': [],
|
||||
'log_pis': [],
|
||||
'next_states': [],
|
||||
}
|
||||
|
||||
def append(self, state, action, reward, done, log_pi, next_state):
|
||||
"""添加单步经验"""
|
||||
self.current_episode['states'].append(state)
|
||||
self.current_episode['actions'].append(action)
|
||||
self.current_episode['rewards'].append(reward)
|
||||
self.current_episode['dones'].append(done)
|
||||
self.current_episode['log_pis'].append(log_pi)
|
||||
self.current_episode['next_states'].append(next_state)
|
||||
|
||||
def finish_episode(self):
|
||||
"""完成一个episode"""
|
||||
self.episodes.append(self.current_episode)
|
||||
self.current_episode = {
|
||||
'states': [],
|
||||
'actions': [],
|
||||
'rewards': [],
|
||||
'dones': [],
|
||||
'log_pis': [],
|
||||
'next_states': [],
|
||||
}
|
||||
|
||||
def sample(self, batch_size):
|
||||
"""采样批次"""
|
||||
# 从所有episode中随机采样
|
||||
all_states = []
|
||||
all_next_states = []
|
||||
|
||||
for episode in self.episodes:
|
||||
all_states.extend(episode['states'])
|
||||
all_next_states.extend(episode['next_states'])
|
||||
|
||||
indices = np.random.choice(len(all_states), batch_size, replace=False)
|
||||
|
||||
states = torch.tensor([all_states[i] for i in indices], device=self.device)
|
||||
next_states = torch.tensor([all_next_states[i] for i in indices], device=self.device)
|
||||
|
||||
return states, next_states
|
||||
```
|
||||
|
||||
#### 2. 实现专家数据加载
|
||||
|
||||
**需要了解:** Waymo数据的实际格式
|
||||
|
||||
```python
|
||||
# 示例:读取一个pkl文件并打印结构
|
||||
import pickle
|
||||
|
||||
with open('Env/exp_converted/exp_converted_0/sd_waymo_*.pkl', 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
print(type(data))
|
||||
print(data.keys() if isinstance(data, dict) else len(data))
|
||||
# 根据实际结构调整加载代码
|
||||
```
|
||||
|
||||
#### 3. 完善训练循环
|
||||
|
||||
**在 `train_magail.py` 中添加:**
|
||||
|
||||
```python
|
||||
# 完整的buffer存储逻辑
|
||||
for i, agent_id in enumerate(env.controlled_agents.keys()):
|
||||
if i < len(obs_array) and i < len(actions):
|
||||
magail.buffer.append(
|
||||
state=obs_array[i],
|
||||
action=actions[i],
|
||||
reward=rewards.get(agent_id, 0.0),
|
||||
done=dones.get(agent_id, False),
|
||||
tm_done=dones.get(agent_id, False),
|
||||
log_pi=log_pis[i],
|
||||
next_state=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
|
||||
next_state_gail=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
|
||||
means=magail.actor.means[i].detach().cpu().numpy(),
|
||||
stds=magail.actor.log_stds.exp()[0].detach().cpu().numpy()
|
||||
)
|
||||
```
|
||||
|
||||
### 🎯 中期TODO
|
||||
|
||||
1. **实现多智能体BERT**:当前BERT接受(batch, N, obs_dim),需要确保正确处理
|
||||
2. **奖励设计**:当前环境奖励为0,需要设计合理的任务奖励
|
||||
3. **评估脚本**:创建评估脚本,可视化训练好的策略
|
||||
4. **超参数调优**:使用wandb或tensorboard进行超参数搜索
|
||||
|
||||
---
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 示例1:简单训练
|
||||
|
||||
```bash
|
||||
# 1. 确保环境正常
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 2. 开始训练(不渲染,快速训练)
|
||||
python train_magail.py \
|
||||
--episodes 100 \
|
||||
--horizon 200 \
|
||||
--rollout-length 1024 \
|
||||
--batch-size 128
|
||||
|
||||
# 3. 查看训练日志
|
||||
tensorboard --logdir outputs/magail_*/logs
|
||||
```
|
||||
|
||||
### 示例2:调试模式
|
||||
|
||||
```bash
|
||||
# 少量episode,启用渲染
|
||||
python train_magail.py \
|
||||
--episodes 5 \
|
||||
--horizon 100 \
|
||||
--render
|
||||
```
|
||||
|
||||
### 示例3:在代码中使用
|
||||
|
||||
```python
|
||||
# test_algorithm.py
|
||||
import sys
|
||||
sys.path.append('Algorithm')
|
||||
|
||||
from Algorithm.magail import MAGAIL
|
||||
import torch
|
||||
|
||||
# 创建虚拟数据测试
|
||||
class DummyExpertBuffer:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
|
||||
def sample(self, batch_size):
|
||||
states = torch.randn(batch_size, 108, device=self.device)
|
||||
next_states = torch.randn(batch_size, 108, device=self.device)
|
||||
return states, next_states
|
||||
|
||||
# 初始化
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
expert_buffer = DummyExpertBuffer(device)
|
||||
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer,
|
||||
input_dim=(108,),
|
||||
device=device,
|
||||
action_shape=(2,),
|
||||
)
|
||||
|
||||
# 测试前向传播
|
||||
test_obs = torch.randn(5, 108, device=device) # 5个智能体
|
||||
actions, log_pis = magail.explore(test_obs)
|
||||
|
||||
print(f"观测形状: {test_obs.shape}")
|
||||
print(f"动作数量: {len(actions)}")
|
||||
print(f"单个动作形状: {actions[0].shape}")
|
||||
print(f"测试成功!✅")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
### ✅ 现在可以做什么
|
||||
|
||||
1. **运行环境测试**:`run_multiagent_env.py` 已经可以正常运行
|
||||
2. **测试算法模块**:Algorithm中的所有模块都已实现
|
||||
3. **开始初步训练**:使用 `train_magail.py`(但需要完善buffer逻辑)
|
||||
|
||||
### ⚠️ 需要您完成的
|
||||
|
||||
1. **调试多智能体buffer**:确保经验正确存储
|
||||
2. **实现专家数据加载**:根据实际数据格式调整
|
||||
3. **验证观测维度**:确认实际观测是否为108维
|
||||
4. **调整训练参数**:根据训练效果调优
|
||||
|
||||
### 🎯 最终目标
|
||||
|
||||
```
|
||||
环境 (Env/) + 算法 (Algorithm/) = 完整的MAGAIL训练系统
|
||||
↓
|
||||
训练出能够模仿专家行为的
|
||||
多智能体自动驾驶策略
|
||||
```
|
||||
|
||||
祝训练顺利!🚀
|
||||
|
||||
480
README.md
480
README.md
@@ -1,401 +1,85 @@
|
||||
# MAGAIL4AutoDrive - 多智能体自动驾驶环境
|
||||
|
||||
基于 MetaDrive 的多智能体自动驾驶仿真与回放环境,支持 Waymo Open Dataset 的专家轨迹回放和自定义策略仿真。
|
||||
|
||||
## 📋 目录
|
||||
|
||||
- [项目简介](#项目简介)
|
||||
- [功能特性](#功能特性)
|
||||
- [环境要求](#环境要求)
|
||||
- [安装步骤](#安装步骤)
|
||||
- [快速开始](#快速开始)
|
||||
- [使用指南](#使用指南)
|
||||
- [项目结构](#项目结构)
|
||||
- [配置说明](#配置说明)
|
||||
- [常见问题](#常见问题)
|
||||
|
||||
## 项目简介
|
||||
|
||||
MAGAIL4AutoDrive 是一个基于 MetaDrive 0.4.3 的多智能体自动驾驶环境,专为模仿学习(Imitation Learning)和强化学习(Reinforcement Learning)研究设计。项目支持从真实世界数据集(如 Waymo Open Dataset)中加载场景,并提供两种核心运行模式:
|
||||
|
||||
- **回放模式(Replay Mode)**:严格按照专家轨迹回放,用于数据可视化和验证
|
||||
- **仿真模式(Simulation Mode)**:使用自定义策略控制车辆,用于算法训练和测试
|
||||
|
||||
## 功能特性
|
||||
|
||||
### 核心功能
|
||||
- ✅ **多智能体支持**:同时控制多辆车辆进行协同仿真
|
||||
- ✅ **专家轨迹回放**:精确回放 Waymo 数据集中的专家驾驶行为
|
||||
- ✅ **自定义策略接口**:灵活接入各种控制策略(IDM、RL 等)
|
||||
- ✅ **智能车道过滤**:自动过滤不在车道上的异常车辆
|
||||
- ✅ **场景时长控制**:支持使用数据集原始场景时长或自定义 horizon
|
||||
- ✅ **丰富的传感器**:LiDAR、侧向检测器、车道线检测器、相机、仪表盘
|
||||
|
||||
### 高级特性
|
||||
- 🎯 指定场景 ID 运行
|
||||
- 🔄 自动场景切换(修复版)
|
||||
- 📊 详细的调试日志输出
|
||||
- 🚗 车辆动态生成与管理
|
||||
- 🎮 支持可视化渲染和无头运行
|
||||
|
||||
## 环境要求
|
||||
|
||||
### 系统要求
|
||||
- **操作系统**:Ubuntu 18.04+ / macOS 10.14+ / Windows 10+
|
||||
- **Python 版本**:3.8 - 3.10
|
||||
- **GPU**:可选,但推荐使用(用于加速渲染)
|
||||
|
||||
### 依赖库
|
||||
```
|
||||
|
||||
metadrive-simulator==0.4.3
|
||||
numpy>=1.19.0
|
||||
pygame>=2.0.0
|
||||
|
||||
```
|
||||
|
||||
## 安装步骤
|
||||
|
||||
### 1. 创建 Conda 环境
|
||||
```
|
||||
|
||||
conda create -n metadrive python=3.10
|
||||
conda activate metadrive
|
||||
|
||||
```
|
||||
|
||||
### 2. 安装 MetaDrive
|
||||
```
|
||||
|
||||
pip install metadrive-simulator==0.4.3
|
||||
|
||||
```
|
||||
|
||||
### 3. 克隆项目
|
||||
```
|
||||
|
||||
git clone https://github.com/your-username/MAGAIL4AutoDrive.git
|
||||
cd MAGAIL4AutoDrive/Env
|
||||
|
||||
```
|
||||
|
||||
### 4. 准备数据集
|
||||
将 Waymo 数据集转换为 MetaDrive 格式并放置在项目目录下:
|
||||
```
|
||||
|
||||
MAGAIL4AutoDrive/Env/
|
||||
├── exp_converted/
|
||||
│ ├── scenario_0/
|
||||
│ ├── scenario_1/
|
||||
│ └── ...
|
||||
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 回放模式(推荐先尝试)
|
||||
```
|
||||
|
||||
|
||||
# 使用场景原始时长回放第一个场景
|
||||
|
||||
python run_multiagent_env.py --mode replay --episodes 1 --use_scenario_duration
|
||||
|
||||
# 回放指定场景
|
||||
|
||||
python run_multiagent_env.py --mode replay --scenario_id 0 --use_scenario_duration
|
||||
|
||||
# 回放多个场景
|
||||
|
||||
python run_multiagent_env.py --mode replay --episodes 3 --use_scenario_duration
|
||||
|
||||
```
|
||||
|
||||
### 仿真模式
|
||||
```
|
||||
|
||||
|
||||
# 使用默认策略运行仿真
|
||||
|
||||
python run_multiagent_env.py --mode simulation --episodes 1
|
||||
|
||||
# 无渲染运行(加速训练)
|
||||
|
||||
python run_multiagent_env.py --mode simulation --episodes 5 --no_render
|
||||
|
||||
```
|
||||
|
||||
## 使用指南
|
||||
|
||||
### 命令行参数
|
||||
|
||||
| 参数 | 类型 | 默认值 | 说明 |
|
||||
|------|------|--------|------|
|
||||
| `--mode` | str | simulation | 运行模式:`replay` 或 `simulation` |
|
||||
| `--data_dir` | str | 当前目录 | Waymo 数据目录路径 |
|
||||
| `--episodes` | int | 1 | 运行回合数 |
|
||||
| `--horizon` | int | 300 | 每回合最大步数 |
|
||||
| `--no_render` | flag | False | 禁用渲染(加速运行) |
|
||||
| `--debug` | flag | False | 启用调试模式 |
|
||||
| `--scenario_id` | int | None | 指定场景 ID |
|
||||
| `--use_scenario_duration` | flag | False | 使用场景原始时长 |
|
||||
| `--no_vehicles` | flag | False | 禁止生成车辆 |
|
||||
| `--no_pedestrians` | flag | False | 禁止生成行人 |
|
||||
| `--no_cyclists` | flag | False | 禁止生成自行车 |
|
||||
|
||||
### 回放模式详解
|
||||
|
||||
回放模式严格按照专家轨迹回放车辆状态,不涉及物理引擎控制。主要用途:
|
||||
- 数据集可视化
|
||||
- 验证数据质量
|
||||
- 生成演示视频
|
||||
|
||||
```bash
|
||||
# 完整参数示例
|
||||
python run_multiagent_env.py \
|
||||
--mode replay \
|
||||
--episodes 1 \
|
||||
--use_scenario_duration \
|
||||
--debug
|
||||
|
||||
# 仅回放车辆,禁止行人和自行车
|
||||
python run_multiagent_env.py \
|
||||
--mode replay \
|
||||
--use_scenario_duration \
|
||||
--no_pedestrians \
|
||||
--no_cyclists
|
||||
```
|
||||
|
||||
**重要提示**:回放模式建议始终启用 `--use_scenario_duration`,否则会出现场景播放完后继续运行的问题。
|
||||
|
||||
### 仿真模式详解
|
||||
|
||||
仿真模式使用自定义策略控制车辆,适合算法开发和测试:
|
||||
|
||||
```bash
|
||||
# 基础仿真
|
||||
python run_multiagent_env.py --mode simulation
|
||||
|
||||
# 长时间训练(无渲染)
|
||||
python run_multiagent_env.py \
|
||||
--mode simulation \
|
||||
--episodes 100 \
|
||||
--horizon 500 \
|
||||
--no_render
|
||||
|
||||
# 仅车辆仿真(用于专注车车交互场景)
|
||||
python run_multiagent_env.py \
|
||||
--mode simulation \
|
||||
--no_pedestrians \
|
||||
--no_cyclists
|
||||
```
|
||||
|
||||
### 自定义策略
|
||||
|
||||
修改 `simple_idm_policy.py` 或创建新的策略类:
|
||||
|
||||
```python
|
||||
class CustomPolicy:
|
||||
def __init__(self, **kwargs):
|
||||
# 初始化策略参数
|
||||
pass
|
||||
|
||||
def act(self, observation=None):
|
||||
# 返回动作 [steering, acceleration]
|
||||
# steering: [-1, 1]
|
||||
# acceleration: [-1, 1]
|
||||
return [0.0, 0.5]
|
||||
```
|
||||
|
||||
在 `run_multiagent_env.py` 中使用:
|
||||
```
|
||||
|
||||
from custom_policy import CustomPolicy
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={...},
|
||||
agent2policy=CustomPolicy()
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
|
||||
MAGAIL4AutoDrive/Env/
|
||||
├── run_multiagent_env.py \# 主运行脚本
|
||||
├── scenario_env.py \# 多智能体场景环境
|
||||
├── replay_policy.py \# 专家轨迹回放策略
|
||||
├── simple_idm_policy.py \# IDM 策略实现
|
||||
├── utils.py \# 工具函数
|
||||
├── ENHANCED_USAGE_GUIDE.md \# 详细使用指南
|
||||
├── README.md \# 本文档
|
||||
└── exp_converted/ \# Waymo 数据集(需自行准备)
|
||||
├── scenario_0/
|
||||
├── scenario_1/
|
||||
└── ...
|
||||
|
||||
```
|
||||
|
||||
### 核心文件说明
|
||||
|
||||
**run_multiagent_env.py**
|
||||
- 主入口脚本
|
||||
- 处理命令行参数
|
||||
- 管理回放和仿真两种模式的运行逻辑
|
||||
|
||||
**scenario_env.py**
|
||||
- 自定义多智能体环境类
|
||||
- 车辆生成与管理
|
||||
- 车道过滤逻辑
|
||||
- 观测空间定义
|
||||
|
||||
**replay_policy.py**
|
||||
- 专家轨迹回放策略
|
||||
- 逐帧状态查询
|
||||
- 轨迹完成判断
|
||||
|
||||
**simple_idm_policy.py**
|
||||
- 简单的恒速策略示例
|
||||
- 可作为自定义策略的模板
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 环境配置参数
|
||||
|
||||
在 `scenario_env.py` 的 `default_config()` 中可修改:
|
||||
|
||||
```python
|
||||
config.update(dict(
|
||||
data_directory=None, # 数据目录
|
||||
num_controlled_agents=3, # 可控车辆数量(仅仿真模式)
|
||||
horizon=1000, # 最大步数
|
||||
filter_offroad_vehicles=True, # 是否过滤车道外车辆
|
||||
lane_tolerance=3.0, # 车道容差(米)
|
||||
replay_mode=False, # 是否为回放模式
|
||||
specific_scenario_id=None, # 指定场景 ID
|
||||
use_scenario_duration=False, # 使用场景原始时长
|
||||
# 对象类型过滤选项
|
||||
spawn_vehicles=True, # 是否生成车辆
|
||||
spawn_pedestrians=True, # 是否生成行人
|
||||
spawn_cyclists=True, # 是否生成自行车
|
||||
))
|
||||
```
|
||||
|
||||
### 传感器配置
|
||||
|
||||
默认启用的传感器(可在环境初始化时修改):
|
||||
- **LiDAR**:80 条激光,探测距离 30 米
|
||||
- **侧向检测器**:10 条激光,探测距离 8 米
|
||||
- **车道线检测器**:10 条激光,探测距离 3 米
|
||||
- **主相机**:分辨率 1200x900
|
||||
- **仪表盘**:车辆状态信息
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q1: 回放模式为什么超出数据集的最大帧数还在继续?
|
||||
**A**: 需要添加 `--use_scenario_duration` 参数。修复版本已在 `scenario_env.py` 中添加了自动检测机制。
|
||||
|
||||
### Q2: 如何切换不同的场景?
|
||||
**A**:
|
||||
- 方法一:使用 `--scenario_id` 指定场景
|
||||
- 方法二:使用 `--episodes N` 自动遍历 N 个场景
|
||||
|
||||
### Q3: 为什么有些车辆没有出现?
|
||||
**A**: 启用了车道过滤功能(`filter_offroad_vehicles=True`),不在车道上的车辆会被过滤。可以通过设置 `lane_tolerance` 调整容差或关闭此功能。
|
||||
|
||||
### Q4: 如何提高运行速度?
|
||||
**A**:
|
||||
- 使用 `--no_render` 禁用可视化
|
||||
- 减少 `num_controlled_agents` 数量
|
||||
- 使用 GPU 加速
|
||||
|
||||
### Q5: 如何控制场景中的对象类型?
|
||||
**A**: 使用对象过滤参数:
|
||||
```bash
|
||||
# 仅车辆,无行人和自行车
|
||||
python run_multiagent_env.py --mode replay --no_pedestrians --no_cyclists
|
||||
|
||||
# 仅行人和自行车,无车辆(特殊场景)
|
||||
python run_multiagent_env.py --mode replay --no_vehicles
|
||||
|
||||
# 调试模式查看过滤统计
|
||||
python run_multiagent_env.py --mode replay --debug --no_pedestrians
|
||||
```
|
||||
|
||||
### Q6: 为什么有些车辆生成在空中?
|
||||
**A**: 已在 v1.2.0 中修复。现在所有车辆位置都只使用 2D 坐标(x, y),z 坐标设为 0,让 MetaDrive 自动处理高度,确保车辆贴在地面上。
|
||||
|
||||
### Q7: 如何导出观测数据?
|
||||
**A**: 在 `run_multiagent_env.py` 中添加数据保存逻辑:
|
||||
```python
|
||||
import pickle
|
||||
|
||||
obs_data = []
|
||||
while True:
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
obs_data.append(obs)
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
with open('observations.pkl', 'wb') as f:
|
||||
pickle.dump(obs_data, f)
|
||||
```
|
||||
|
||||
## 更新日志
|
||||
|
||||
### v1.2.0 (2025-10-26)
|
||||
- ✅ 修复车辆生成高度问题(车辆悬空)
|
||||
- ✅ 添加对象类型过滤功能(车辆/行人/自行车)
|
||||
- ✅ 新增命令行参数:`--no_vehicles`、`--no_pedestrians`、`--no_cyclists`
|
||||
- ✅ 改进调试信息输出,显示各类型对象统计
|
||||
- ✅ 优化位置处理逻辑,只使用 2D 坐标避免高度问题
|
||||
|
||||
### v1.1.0 (2025-10-26)
|
||||
- ✅ 修复回放模式超出场景时长问题
|
||||
- ✅ 添加场景自动切换功能
|
||||
- ✅ 改进 `replay_policy.py`,新增 `is_finished()` 方法
|
||||
- ✅ 优化 `scenario_env.py` 的 done 判断逻辑
|
||||
- ✅ 修复多回合运行时的对象清理问题
|
||||
|
||||
### v1.0.0 (初始版本)
|
||||
- 基础多智能体环境实现
|
||||
- 回放和仿真两种模式
|
||||
- 车道过滤功能
|
||||
- Waymo 数据集支持
|
||||
|
||||
## 贡献指南
|
||||
|
||||
欢迎提交 Issue 和 Pull Request!
|
||||
|
||||
### 提交 Issue
|
||||
- 请详细描述问题和复现步骤
|
||||
- 附上运行日志和错误信息
|
||||
- 说明运行环境(OS、Python 版本等)
|
||||
|
||||
### 提交 PR
|
||||
- Fork 本项目
|
||||
- 创建特性分支:`git checkout -b feature/your-feature`
|
||||
- 提交更改:`git commit -m 'Add some feature'`
|
||||
- 推送分支:`git push origin feature/your-feature`
|
||||
- 提交 Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目基于 MIT 许可证开源。
|
||||
|
||||
## 致谢
|
||||
|
||||
- [MetaDrive](https://github.com/metadriverse/metadrive) - 优秀的驾驶仿真平台
|
||||
- [Waymo Open Dataset](https://waymo.com/open/) - 高质量的自动驾驶数据集
|
||||
|
||||
## 联系方式
|
||||
|
||||
如有问题或建议,请通过以下方式联系:
|
||||
- GitHub Issues: [项目 Issues 页面]
|
||||
- Email: huangfukk@xxx.com
|
||||
# MAGAIL4AutoDrive
|
||||
### 1.1 环境搭建
|
||||
环境核心代码封装于`Env`文件夹,通过运行`run_multiagent_env.py`即可启动多智能体交互环境,该脚本的核心功能为读取各智能体(车辆)的动作指令,并将其传入`env.step()`方法中完成仿真执行。
|
||||
|
||||
**性能优化版本:** 针对原始版本FPS低(15帧)和CPU利用率不足的问题,已提供多个优化版本:
|
||||
- `run_multiagent_env_fast.py` - 激光雷达优化版(30-60 FPS,2-4倍提升)⭐推荐
|
||||
- `run_multiagent_env_parallel.py` - 多进程并行版(300-600 steps/s总吞吐量,充分利用多核CPU)⭐⭐推荐
|
||||
- 详见 `Env/QUICK_START.md` 快速使用指南
|
||||
|
||||
当前已初步实现`Env.senario_env.MultiAgentScenarioEnv.reset()`车辆生成函数,具体逻辑如下:首先读取专家数据集中各车辆的初始位姿信息;随后对原始数据进行清洗,剔除车辆 Agent 实例信息,记录核心参数(车辆 ID、初始生成位置、朝向角、生成时间戳、目标终点坐标);最后调用`_spawn_controlled_agents()`函数,依据清洗后的参数在指定时间、指定位置生成搭载自动驾驶算法的可控车辆。
|
||||
|
||||
**✅ 已解决:车辆生成位置偏差问题**
|
||||
- **问题描述**:部分车辆生成于草坪、停车场等非车道区域,原因是专家数据记录误差或停车场特殊标注
|
||||
- **解决方案**:实现了`_is_position_on_lane()`车道区域检测机制和`_filter_valid_spawn_positions()`过滤函数
|
||||
- 检测逻辑:通过`point_on_lane()`判断位置是否在车道上,支持容差参数(默认3米)处理边界情况
|
||||
- 双重检测:优先使用精确检测,失败时使用容差范围检测,确保车道边缘车辆不被误过滤
|
||||
- 自动过滤:在`reset()`时自动过滤非车道区域车辆,并输出过滤统计信息
|
||||
- **配置参数**:
|
||||
- `filter_offroad_vehicles=True`:启用/禁用车道过滤功能
|
||||
- `lane_tolerance=3.0`:车道检测容差(米),可根据场景调整
|
||||
- `max_controlled_vehicles=10`:限制最大车辆数(可选)
|
||||
- **使用示例**:在环境配置中设置上述参数即可自动启用,运行时会显示过滤信息(如"过滤5辆,保留45辆")
|
||||
|
||||
|
||||
### 1.2 观测获取
|
||||
观测信息采集功能通过`Env.senario_env.MultiAgentScenarioEnv._get_all_obs()`函数实现,该函数支持遍历所有可控车辆并采集多维度观测数据,当前已实现的观测维度包括:车辆实时位置坐标、朝向角、行驶速度、雷达扫描点云(含障碍物与车道线特征)、导航信息(因场景复杂度较低,暂采用目标终点坐标直接作为导航输入)。
|
||||
|
||||
**✅ 已解决:红绿灯信息采集问题**
|
||||
- **问题描述**:
|
||||
- 问题1:部分红绿灯状态值为`None`,导致异常或错误判断
|
||||
- 问题2:车道分段设计时,部分区域车辆无法匹配到红绿灯
|
||||
- **解决方案**:实现了`_get_traffic_light_state()`优化方法,采用多级检测策略
|
||||
- **方法1(优先)**:从车辆导航模块`vehicle.navigation.current_lane`获取当前车道,直接查询红绿灯状态(高效,自动处理车道分段)
|
||||
- **方法2(兜底)**:遍历所有车道,通过`point_on_lane()`判断车辆位置,查找对应红绿灯(处理导航失败情况)
|
||||
- **异常处理**:对状态为`None`的情况返回0(无红绿灯),所有异常均有try-except保护,确保不会中断程序
|
||||
- **返回值规范**:0=无红绿灯/未知, 1=绿灯, 2=黄灯, 3=红灯
|
||||
- **优势**:双重保障机制,优先用高效方法,失败时自动切换到兜底方案,确保所有场景都能正确获取红绿灯信息
|
||||
|
||||
|
||||
### 1.3 算法模块
|
||||
本方案的核心创新点在于对 GAIL 算法的判别器进行改进,使其适配多智能体场景下 “输入长度动态变化”(车辆数量不固定)的特性,实现对整体交互场景的分类判断,进而满足多智能体自动驾驶环境的训练需求。算法核心代码封装于`Algorithm.bert.Bert`类,具体实现逻辑如下:
|
||||
|
||||
1. 输入层处理:输入数据为维度`(N, input_dim)`的矩阵(其中`N`为当前场景车辆数量,`input_dim`为单车辆固定观测维度),初始化`Bert`类时需设置`input_dim`,确保输入维度匹配;
|
||||
2. 嵌入层与位置编码:通过`projection`线性投影层将单车辆观测维度映射至预设的嵌入维度(`embed_dim`),随后叠加可学习的位置编码(`pos_embed`),以捕捉观测序列的时序与空间关联信息;
|
||||
3. Transformer 特征提取:嵌入后的特征向量输入至多层`Transformer`网络(层数由`num_layers`参数控制),完成高阶特征交互与抽象;
|
||||
4. 分类头设计:提供两种特征聚合与分类方案:若开启`CLS`模式,在嵌入层前拼接 1 个可学习的`CLS`标记,最终取`CLS`标记对应的特征向量输入全连接层完成分类;若关闭`CLS`模式,则对`Transformer`输出的所有车辆特征向量进行序列维度均值池化,再将池化后的全局特征输入全连接层。分类器支持可选的`Tanh`激活函数,以适配不同场景下的输出分布需求。
|
||||
|
||||
|
||||
### 1.4 动作执行
|
||||
在当前环境测试阶段,暂沿用腾达的动作执行框架:为每辆可控车辆分配独立的`policy`模型,将单车辆观测数据输入对应`policy`得到动作指令后,传入`env.step()`完成仿真;同时在`before_step`阶段调用`_set_action()`函数,将动作指令绑定至车辆实例,最终由 MetaDrive 仿真系统完成物理动力学计算与场景渲染。
|
||||
|
||||
后续优化方向为构建 "参数共享式统一模型框架",具体设计如下:所有车辆共用 1 个`policy`模型,通过参数共享机制实现模型的全局统一维护。该框架具备三重优势:一是避免多车辆独立模型带来的训练偏差(如不同模型训练程度不一致);二是解决车辆数量动态变化时的模型管理问题(车辆新增无需额外初始化模型,车辆减少不丢失模型训练信息);三是支持动作指令的并行计算,可显著提升每一步决策的迭代效率,适配大规模多智能体交互场景的训练需求。
|
||||
|
||||
---
|
||||
|
||||
**Happy Driving! 🚗💨**
|
||||
## 问题解决总结
|
||||
|
||||
### ✅ 已完成的优化
|
||||
|
||||
1. **车辆生成位置偏差** - 实现车道区域检测和自动过滤,配置参数:`filter_offroad_vehicles`, `lane_tolerance`, `max_controlled_vehicles`
|
||||
2. **红绿灯信息采集** - 采用双重检测策略(导航模块+遍历兜底),处理None状态和车道分段问题
|
||||
3. **性能优化** - 提供多个优化版本(fast/parallel),FPS从15提升到30-60,支持多进程充分利用CPU
|
||||
|
||||
### 🧪 测试方法
|
||||
```bash
|
||||
# 测试车道过滤和红绿灯检测
|
||||
python Env/test_lane_filter.py
|
||||
|
||||
# 运行标准版本(带过滤)
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 运行高性能版本
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
### 📝 配置示例
|
||||
```python
|
||||
config = {
|
||||
# 车道过滤
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0, # 容差范围(米)
|
||||
"max_controlled_vehicles": 10, # 最大车辆数
|
||||
# 其他配置...
|
||||
}
|
||||
```
|
||||
|
||||
103
analyze_expert_data.py
Normal file
103
analyze_expert_data.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
分析Waymo专家数据的结构
|
||||
|
||||
运行: python analyze_expert_data.py
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def analyze_pkl_file(filepath):
|
||||
"""分析单个pkl文件的结构"""
|
||||
print(f"\n{'='*80}")
|
||||
print(f"分析文件: {os.path.basename(filepath)}")
|
||||
print(f"{'='*80}")
|
||||
|
||||
with open(filepath, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"\n1. 数据类型: {type(data)}")
|
||||
print(f" 文件大小: {os.path.getsize(filepath) / 1024:.1f} KB")
|
||||
|
||||
if isinstance(data, dict):
|
||||
print(f"\n2. 字典结构:")
|
||||
print(f" 键数量: {len(data)}")
|
||||
print(f" 键列表: {list(data.keys())[:10]}")
|
||||
|
||||
# 详细分析每个键
|
||||
for i, (key, value) in enumerate(list(data.items())[:5]):
|
||||
print(f"\n 键 [{i+1}]: '{key}'")
|
||||
print(f" 类型: {type(value)}")
|
||||
|
||||
if isinstance(value, dict):
|
||||
print(f" 子键: {list(value.keys())}")
|
||||
|
||||
# 分析子字典
|
||||
for subkey, subvalue in list(value.items())[:3]:
|
||||
print(f" - {subkey}: {type(subvalue)}", end="")
|
||||
if isinstance(subvalue, np.ndarray):
|
||||
print(f" shape={subvalue.shape}, dtype={subvalue.dtype}")
|
||||
elif isinstance(subvalue, dict):
|
||||
print(f" keys={list(subvalue.keys())[:5]}")
|
||||
elif isinstance(subvalue, (list, tuple)):
|
||||
print(f" len={len(subvalue)}")
|
||||
else:
|
||||
print(f" = {subvalue}")
|
||||
|
||||
elif isinstance(value, np.ndarray):
|
||||
print(f" Shape: {value.shape}, dtype: {value.dtype}")
|
||||
print(f" 示例: {value.flatten()[:5]}")
|
||||
elif isinstance(value, (list, tuple)):
|
||||
print(f" 长度: {len(value)}")
|
||||
if len(value) > 0:
|
||||
print(f" 第一个元素: {type(value[0])}")
|
||||
|
||||
elif isinstance(data, (list, tuple)):
|
||||
print(f"\n2. 列表/元组结构:")
|
||||
print(f" 长度: {len(data)}")
|
||||
if len(data) > 0:
|
||||
print(f" 第一个元素类型: {type(data[0])}")
|
||||
if isinstance(data[0], dict):
|
||||
print(f" 第一个元素的键: {list(data[0].keys())}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def find_trajectory_data(data, max_depth=3, current_depth=0, path=""):
|
||||
"""递归查找可能包含轨迹数据的字段"""
|
||||
if current_depth > max_depth:
|
||||
return
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
new_path = f"{path}.{key}" if path else key
|
||||
|
||||
# 查找可能是轨迹的数据(通常是时间序列数组)
|
||||
if isinstance(value, np.ndarray):
|
||||
if len(value.shape) >= 2 and value.shape[0] > 10: # 可能是时间序列
|
||||
print(f" 🎯 可能的轨迹数据: {new_path}")
|
||||
print(f" Shape: {value.shape}, dtype: {value.dtype}")
|
||||
print(f" 前3个值: {value[:3]}")
|
||||
|
||||
# 继续递归
|
||||
elif isinstance(value, dict):
|
||||
find_trajectory_data(value, max_depth, current_depth + 1, new_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 分析第一个数据文件
|
||||
data_dir = "Env/exp_converted/exp_converted_0"
|
||||
pkl_files = [f for f in os.listdir(data_dir) if f.startswith('sd_waymo')]
|
||||
|
||||
if pkl_files:
|
||||
filepath = os.path.join(data_dir, pkl_files[0])
|
||||
data = analyze_pkl_file(filepath)
|
||||
|
||||
print(f"\n\n{'='*80}")
|
||||
print("查找可能的轨迹数据...")
|
||||
print(f"{'='*80}")
|
||||
find_trajectory_data(data)
|
||||
else:
|
||||
print("未找到数据文件!")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205002/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205002/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205133/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205133/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205320/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205320/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205507/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205507/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205656/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205656/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205825/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205825/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_205842/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_205842/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_210006/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_210006/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_210055/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_210055/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_210302/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_210302/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_210523/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_210523/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251021_210644/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251021_210644/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_160448/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_160448/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_161725/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_161725/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_161806/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_161806/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_161924/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_161924/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162104/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162104/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162133/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162133/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162311/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162311/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162445/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162445/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162527/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162527/models/best_model/model.pth
Normal file
Binary file not shown.
Binary file not shown.
BIN
outputs/magail_20251022_162558/models/best_model/model.pth
Normal file
BIN
outputs/magail_20251022_162558/models/best_model/model.pth
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user