#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据集检查脚本:统计可用的场景数量
(Dataset inspection script: count the available scenarios.)
"""

import pickle
import os
from pathlib import Path

from metadrive.engine.asset_loader import AssetLoader
||
def check_dataset(data_dir, subfolder="exp_filtered"):
    """Inspect a converted dataset directory and report scenario counts.

    Prints the dataset path, the contents of ``dataset_mapping.pkl`` and
    ``dataset_summary.pkl`` (when present), the number of actual scenario
    ``.pkl`` files found on disk, and the per-subdirectory file counts.

    Args:
        data_dir: Root directory of the dataset.
        subfolder: Data sub-directory to inspect ("exp_filtered" or
            "exp_converted").

    Returns:
        int: Number of scenarios recorded in ``dataset_mapping.pkl``,
        or 0 when the path or the mapping file is missing.
    """
    print("=" * 80)
    print("数据集检查工具")
    print("=" * 80)

    # Resolve the absolute dataset path via MetaDrive's asset loader.
    full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
    print(f"\n数据集路径: {full_path}")

    if not os.path.exists(full_path):
        print(f"❌ 错误:路径不存在: {full_path}")
        # Return 0 (not None) so callers can format the result as a count.
        return 0

    print(f"✅ 路径存在")

    # Read the dataset mapping, if present. It maps scenario file name ->
    # sub-directory. Track the count explicitly instead of probing locals()
    # at the end of the function.
    num_scenarios = 0
    mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
    if os.path.exists(mapping_file):
        print(f"\n读取 dataset_mapping.pkl...")
        with open(mapping_file, 'rb') as f:
            dataset_mapping = pickle.load(f)

        num_scenarios = len(dataset_mapping)
        print(f"✅ 数据集映射文件存在")
        print(f" 映射的场景数量: {len(dataset_mapping)}")

        # Group the mapped scenario files by their sub-directory.
        subdirs = {}
        for filename, subdir in dataset_mapping.items():
            subdirs.setdefault(subdir, []).append(filename)

        print(f"\n场景分布:")
        for subdir, files in subdirs.items():
            print(f" {subdir}: {len(files)} 个场景")

    # Read the dataset summary, if present (per-scenario metadata).
    summary_file = os.path.join(full_path, "dataset_summary.pkl")
    if os.path.exists(summary_file):
        print(f"\n读取 dataset_summary.pkl...")
        with open(summary_file, 'rb') as f:
            dataset_summary = pickle.load(f)

        print(f"✅ 数据集摘要文件存在")
        print(f" 摘要的场景数量: {len(dataset_summary)}")

        # Show the first few scenario IDs with their key stats.
        print(f"\n前10个场景ID:")
        for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
            if isinstance(info, dict):
                track_length = info.get('track_length', 'N/A')
                num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
                print(f" {i}: {scenario_id[:16]}... (时长: {track_length}, 对象数: {num_objects})")

    # Count the scenario .pkl files actually on disk, excluding the two
    # metadata files, to cross-check against the mapping.
    print(f"\n检查实际文件...")
    pkl_files = list(Path(full_path).rglob("*.pkl"))
    scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
    print(f" 实际场景文件数量: {len(scenario_files)}")

    # Per-subdirectory .pkl file counts (top level only).
    print(f"\n子目录结构:")
    for item in os.listdir(full_path):
        item_path = os.path.join(full_path, item)
        if os.path.isdir(item_path):
            pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
            print(f" {item}/: {pkl_count} 个pkl文件")

    print("\n" + "=" * 80)
    print("检查完成")
    print("=" * 80)

    return num_scenarios
if __name__ == "__main__":
    # Root directory that contains the converted Waymo datasets.
    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"

    print("\n【检查 exp_filtered 数据集】")
    num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")

    print("\n\n【检查 exp_converted 数据集】")
    num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")

    # Final summary of both dataset variants.
    print("\n\n总结:")
    for label, count in (("exp_filtered", num_filtered), ("exp_converted", num_converted)):
        print(f" {label}: {count} 个场景")