diff --git a/scenarionet/builder/filters.py b/scenarionet/builder/filters.py index 6b89991..76c1521 100644 --- a/scenarionet/builder/filters.py +++ b/scenarionet/builder/filters.py @@ -53,6 +53,18 @@ class ScenarioFilter: def no_traffic_light(metadata, file_path): return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] == 0 + @staticmethod + def no_pedestrian(metadata, file_path): + """Return True if the scenario has no pedestrians""" + num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get("PEDESTRIAN", 0) + return num == 0 + + @staticmethod + def no_cyclist(metadata, file_path): + """Return True if the scenario has no cyclists""" + num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get("CYCLIST", 0) + return num == 0 + @staticmethod def no_overpass(metadata, file_path): """ diff --git a/scenarionet/filter.py b/scenarionet/filter.py index 6d0fe7e..01987bd 100644 --- a/scenarionet/filter.py +++ b/scenarionet/filter.py @@ -3,6 +3,7 @@ desc = "Filter unwanted scenarios out and build a new database" if __name__ == '__main__': import argparse + from metadrive.type import MetaDriveType from scenarionet.builder.filters import ScenarioFilter from scenarionet.builder.utils import merge_database @@ -59,6 +60,31 @@ if __name__ == '__main__': "--exclude_ids", nargs='+', default=[], help="Scenarios with indicated name will NOT be selected" ) + parser.add_argument( + "--num_vehicle", + action="store_true", + help="add this flag to select cases with vehicle_num < max_num_vehicle" + ) + + parser.add_argument( + "--max_num_vehicle", + default=50, + type=int, + help="case will be selected if num_vehicle < this argument" + ) + + parser.add_argument( + "--no_pedestrian", + action="store_true", + help="Scenarios with pedestrians WON'T be selected" + ) + + parser.add_argument( + "--no_cyclist", + action="store_true", + help="Scenarios with cyclists WON'T be selected" + ) + args = parser.parse_args() target = args.sdc_moving_dist_min obj_threshold = args.max_num_object @@ -75,6 +101,19 @@ if __name__ == '__main__': filters.append(ScenarioFilter.make(ScenarioFilter.no_traffic_light)) if args.id_filter: filters.append(ScenarioFilter.make(ScenarioFilter.id_filter, ids=args.exclude_ids)) + if args.num_vehicle: + filters.append( + ScenarioFilter.make( + ScenarioFilter.object_number, + number_threshold=args.max_num_vehicle, + object_type=MetaDriveType.VEHICLE, + condition=ScenarioFilter.SMALLER + ) + ) + if args.no_pedestrian: + filters.append(ScenarioFilter.make(ScenarioFilter.no_pedestrian)) + if args.no_cyclist: + filters.append(ScenarioFilter.make(ScenarioFilter.no_cyclist)) if len(filters) == 0: raise ValueError("No filters are applied. Abort.")