finish filter

This commit is contained in:
QuanyiLi
2023-05-07 21:29:14 +01:00
parent 468cff9040
commit 1c81399256
23 changed files with 84 additions and 0 deletions

View File

@@ -0,0 +1,63 @@
import os
import os.path
from metadrive.type import MetaDriveType
from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet.builder.filters import ScenarioFilter
from scenarionet.builder.utils import combine_multiple_dataset
def test_filter_dataset():
"""
It is just a runnable test
"""
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")]
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg"))
output_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
# ========================= test 1 =========================
# nuscenes data has no light
# light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
target_dist=30,
condition="greater")
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[sdc_driving_condition]
)
assert len(summary) > 0
# ========================= test 2 =========================
num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
number_threshold=50,
object_type=MetaDriveType.PEDESTRIAN,
condition="greater")
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[num_condition])
assert len(summary) > 0
# ========================= test 3 =========================
traffic_light = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[traffic_light])
assert len(summary) > 0
if __name__ == '__main__':
test_filter_dataset()

View File

@@ -36,6 +36,16 @@ def test_filter_dataset():
break
assert in_
sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
target_dist=5,
condition="greater")
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[sdc_driving_condition])
assert len(summary) == 8
# ========================= test 2 =========================
num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
@@ -53,6 +63,17 @@ def test_filter_dataset():
for a in answer:
assert a in summary
num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
number_threshold=50,
condition="greater")
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[num_condition])
assert len(summary) > 0
if __name__ == '__main__':
test_filter_dataset()