Fix bug for generating dataset (#2)
* update parameters for scripts
* update write function
* modify waymo script
* use exist ok instead of overwrite
* remove TODO
* rename to combine_dataset
* use exist_ok and force_overwrite together
* format
* test
* create env for each thread
* restore
* fix bug
* fix pg bug
* fix
* fix bug
* add assert
* don't return done info
* to dict
* add test
* only compare sdc
* no store map
* release memory
* add start index to argument
* test
* format some settings/flags
* add tmp path
* add tmp dir
* test all scripts
* suppress warning
* suppress warning
* format
* test memory leak
* fix memory leak
* remove useless functions
* imap
* thread-1 process for avoiding memory leak
* add list()
* rename
* verify existence
* verify completeness
* test
* add test
* add default value
* add limit
* use script
* add annotation
* test script
* fix bug
* fix bug
* add author4
* add overwrite
* fix bug
* fix
* combine overwrite
* fix bug
* gpu007
* add result save dir
* adjust sequence
* fix test bug
* disable bash script
* add episode length limit
* move scripts to root dir
* format
* fix test
@@ -3,3 +3,6 @@ import os
 SCENARIONET_PACKAGE_PATH = os.path.dirname(os.path.abspath(__file__))
 SCENARIONET_REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 SCENARIONET_DATASET_PATH = os.path.join(SCENARIONET_REPO_PATH, "dataset")
+
+# use this dir to store junk files generated by testing
+TMP_PATH = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp")
@@ -1,3 +1,4 @@
+import pkg_resources  # for suppress warning
 import copy
 import logging
 import os
@@ -5,6 +6,7 @@ import os.path as osp
 import pickle
 import shutil
 from typing import Callable, List
+import tqdm

 from metadrive.scenario.scenario_description import ScenarioDescription

@@ -25,13 +27,19 @@ def try_generating_summary(file_folder)
     return summary


-def combine_multiple_dataset(
-    output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True, filters: List[Callable] = None
+def combine_dataset(
+    output_path,
+    *dataset_paths,
+    exist_ok=False,
+    overwrite=False,
+    try_generate_missing_file=True,
+    filters: List[Callable] = None
 ):
     """
     Combine multiple datasets. Each dataset should have a dataset_summary.pkl
     :param output_path: The path to store the output dataset
-    :param force_overwrite: If True, overwrite the output_path even if it exists
+    :param exist_ok: If True, though the output_path already exist, still write into it
+    :param overwrite: If True, overwrite existing dataset_summary.pkl and mapping.pkl. Otherwise, raise error
     :param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
     :param dataset_paths: Path of each dataset
     :param filters: a set of filters to choose which scenario to be selected and added into this combined dataset
@@ -39,24 +47,26 @@ def combine_multiple_dataset(
     """
     filters = filters or []
     output_abs_path = osp.abspath(output_path)
-    if os.path.exists(output_abs_path):
-        if not force_overwrite:
-            raise FileExistsError("Output path already exists!")
-        else:
-            shutil.rmtree(output_abs_path)
-    os.makedirs(output_abs_path, exist_ok=False)
+    os.makedirs(output_abs_path, exist_ok=exist_ok)
+    summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
+    mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
+    for file in [summary_file, mapping_file]:
+        if os.path.exists(file):
+            if overwrite:
+                os.remove(file)
+            else:
+                raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path))

     summaries = {}
     mappings = {}

     # collect
-    for dataset_path in dataset_paths:
+    for dataset_path in tqdm.tqdm(dataset_paths):
         abs_dir_path = osp.abspath(dataset_path)
         # summary
         assert osp.exists(abs_dir_path), "Wrong dataset path. Can not find dataset at: {}".format(abs_dir_path)
         if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
             if try_generate_missing_file:
-                # TODO add test for 1. number dataset 2. missing summary dataset 3. missing mapping dataset
                 summary = try_generating_summary(abs_dir_path)
             else:
                 raise FileNotFoundError("Can not find summary file for dataset: {}".format(abs_dir_path))
@@ -96,8 +106,6 @@ def combine_multiple_dataset(
             summaries.pop(file)
             mappings.pop(file)

-    summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
-    mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
     save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)

     return summaries, mappings
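The new exist_ok / overwrite pair replaces the old force_overwrite switch: exist_ok only controls whether the two .pkl files may be written into a directory that already exists, while overwrite decides whether an existing dataset_summary.pkl / dataset_mapping.pkl may be replaced. A minimal sketch of the intended call pattern; the paths below are hypothetical and not part of this commit:

    from scenarionet.builder.filters import ScenarioFilter
    from scenarionet.builder.utils import combine_dataset

    # keep only scenarios whose self-driving car moves more than 20 meters
    sdc_filter = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=20, condition="greater")

    # exist_ok=True: the output directory may already exist (it simply becomes a dataset);
    # overwrite=True: additionally allow replacing its summary/mapping .pkl files.
    summaries, mappings = combine_dataset(
        "/tmp/combined_dataset",    # hypothetical output dataset
        "/data/waymo", "/data/pg",  # hypothetical source datasets
        exist_ok=True,
        overwrite=True,
        try_generate_missing_file=True,
        filters=[sdc_filter],
    )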
scenarionet/combine_dataset.py (new file)
@@ -0,0 +1,55 @@
+import pkg_resources  # for suppress warning
+import argparse
+from scenarionet.builder.filters import ScenarioFilter
+from scenarionet.builder.utils import combine_dataset
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset_path",
+        required=True,
+        help="The name of the new combined dataset. "
+        "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
+        "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
+        "that directory into a dataset."
+    )
+    parser.add_argument(
+        '--from_datasets',
+        required=True,
+        nargs='+',
+        default=[],
+        help="Which datasets to combine. It takes any number of directory path as input"
+    )
+    parser.add_argument(
+        "--exist_ok",
+        action="store_true",
+        help="Still allow to write, if the dir exists already. "
+        "This write will only create two .pkl files and this directory will become a dataset."
+    )
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
+        "whether to overwrite both files"
+    )
+    parser.add_argument(
+        "--sdc_moving_dist_min",
+        default=20,
+        help="Selecting case with sdc_moving_dist > this value. "
+        "We will add more filter conditions in the future."
+    )
+    args = parser.parse_args()
+    target = args.sdc_moving_dist_min
+    filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")]
+
+    if len(args.from_datasets) != 0:
+        combine_dataset(
+            args.dataset_path,
+            *args.from_datasets,
+            exist_ok=args.exist_ok,
+            overwrite=args.overwrite,
+            try_generate_missing_file=True,
+            filters=filters
+        )
+    else:
+        raise ValueError("No source dataset are provided. Abort.")
@@ -1,6 +1,6 @@
+import pkg_resources  # for suppress warning
 import argparse
 import os

 from scenarionet import SCENARIONET_DATASET_PATH
 from scenarionet.converter.nuplan.utils import get_nuplan_scenarios, convert_nuplan_scenario
 from scenarionet.converter.utils import write_to_directory
@@ -14,22 +14,31 @@ if __name__ == '__main__':
         "--dataset_path",
         "-d",
         default=os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
-        help="The path of the dataset"
+        help="A directory, the path to place the data"
     )
-    parser.add_argument("--version", "-v", default='v1.1', help="version")
-    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
+    parser.add_argument("--version", "-v", default='v1.1', help="version of the raw data")
+    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
+    parser.add_argument(
+        "--raw_data_path",
+        type=str,
+        default=os.path.join(os.getenv("NUPLAN_DATA_ROOT"), "nuplan-v1.1/splits/mini"),
+        help="the place store .db files"
+    )
+    parser.add_argument("--test", action="store_true", help="for test use only. convert one log")
     args = parser.parse_args()

-    force_overwrite = args.overwrite
+    overwrite = args.overwrite
     dataset_name = args.dataset_name
     output_path = args.dataset_path
     version = args.version

-    data_root = os.path.join(os.getenv("NUPLAN_DATA_ROOT"), "nuplan-v1.1/splits/mini")
+    data_root = args.raw_data_path
     map_root = os.getenv("NUPLAN_MAPS_ROOT")
-    scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
-    # scenarios = get_nuplan_scenarios(data_root, map_root)
+    if args.test:
+        scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
+    else:
+        scenarios = get_nuplan_scenarios(data_root, map_root)

     write_to_directory(
         convert_func=convert_nuplan_scenario,
@@ -37,6 +46,6 @@ if __name__ == '__main__':
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite,
+        overwrite=overwrite,
         num_workers=args.num_workers
     )
@@ -1,6 +1,6 @@
+import pkg_resources  # for suppress warning
 import argparse
 import os.path

 from scenarionet import SCENARIONET_DATASET_PATH
 from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_nuscenes_scenarios
 from scenarionet.converter.utils import write_to_directory
@@ -14,20 +14,25 @@ if __name__ == '__main__':
         "--dataset_path",
         "-d",
         default=os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
-        help="The path of the dataset"
+        help="directory, The path to place the data"
     )
-    parser.add_argument("--version", "-v", default='v1.0-mini', help="version")
-    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
+    parser.add_argument(
+        "--version",
+        "-v",
+        default='v1.0-mini',
+        help="version of nuscenes data, scenario of this version will be converted "
+    )
+    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
     args = parser.parse_args()

-    force_overwrite = args.overwrite
+    overwrite = args.overwrite
     dataset_name = args.dataset_name
     output_path = args.dataset_path
     version = args.version

     dataroot = '/home/shady/data/nuscenes'
-    scenarios, nusc = get_nuscenes_scenarios(dataroot, version)
+    scenarios, nuscs = get_nuscenes_scenarios(dataroot, version, args.num_workers)

     write_to_directory(
         convert_func=convert_nuscenes_scenario,
@@ -35,7 +40,7 @@
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite,
-        nuscenes=nusc,
-        num_workers=args.num_workers
+        overwrite=overwrite,
+        num_workers=args.num_workers,
+        nuscenes=nuscs,
     )
@@ -1,8 +1,8 @@
+import pkg_resources  # for suppress warning
 import argparse
 import os.path

 import metadrive
-from metadrive.policy.idm_policy import IDMPolicy

 from scenarionet import SCENARIONET_DATASET_PATH
 from scenarionet.converter.pg.utils import get_pg_scenarios, convert_pg_scenario
@@ -14,19 +14,24 @@ if __name__ == '__main__':
         "--dataset_name", "-n", default="pg", help="Dataset name, will be used to generate scenario files"
     )
     parser.add_argument(
-        "--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "pg"), help="The path of the dataset"
+        "--dataset_path",
+        "-d",
+        default=os.path.join(SCENARIONET_DATASET_PATH, "pg"),
+        help="directory, The path to place the data"
     )
     parser.add_argument("--version", "-v", default=metadrive.constants.DATA_VERSION, help="version")
-    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
+    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
+    parser.add_argument("--num_scenarios", type=int, default=64, help="how many scenarios to generate (default: 30)")
+    parser.add_argument("--start_index", type=int, default=0, help="which index to start")
     args = parser.parse_args()

-    force_overwrite = args.overwrite
+    overwrite = args.overwrite
     dataset_name = args.dataset_name
     output_path = args.dataset_path
     version = args.version

-    scenario_indices, env = get_pg_scenarios(30, IDMPolicy)
+    scenario_indices = get_pg_scenarios(args.start_index, args.num_scenarios)

     write_to_directory(
         convert_func=convert_pg_scenario,
@@ -34,7 +39,6 @@
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite,
-        env=env,
-        num_workers=args.num_workers
+        overwrite=overwrite,
+        num_workers=args.num_workers,
     )
@@ -1,8 +1,9 @@
+import pkg_resources  # for suppress warning
 import argparse
 import logging
 import os

-from scenarionet import SCENARIONET_DATASET_PATH
+from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
 from scenarionet.converter.utils import write_to_directory
 from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios

@@ -14,19 +15,27 @@ if __name__ == '__main__':
         "--dataset_name", "-n", default="waymo", help="Dataset name, will be used to generate scenario files"
     )
     parser.add_argument(
-        "--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"), help="The path of the dataset"
+        "--dataset_path",
+        "-d",
+        default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
+        help="A directory, the path to place the converted data"
     )
     parser.add_argument("--version", "-v", default='v1.2', help="version")
-    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
+    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
+    parser.add_argument(
+        "--raw_data_path",
+        default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
+        help="The directory stores all waymo tfrecord"
+    )
     args = parser.parse_args()

-    force_overwrite = args.overwrite
+    overwrite = args.overwrite
     dataset_name = args.dataset_name
     output_path = args.dataset_path
     version = args.version

-    waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "../waymo_origin")
+    waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
     scenarios = get_waymo_scenarios(waymo_data_directory)

     write_to_directory(
@@ -35,6 +44,6 @@ if __name__ == '__main__':
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite,
+        overwrite=overwrite,
         num_workers=args.num_workers
     )
@@ -388,7 +388,11 @@ def convert_nuscenes_scenario(scene, version, nuscenes: NuScenes):
     return result


-def get_nuscenes_scenarios(dataroot, version):
+def get_nuscenes_scenarios(dataroot, version, num_workers=2):
     nusc = NuScenes(version=version, dataroot=dataroot)
     scenarios = nusc.scene
-    return scenarios, nusc
+
+    def _get_nusc():
+        return NuScenes(version=version, dataroot=dataroot)
+
+    return scenarios, [_get_nusc() for _ in range(num_workers)]
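get_nuscenes_scenarios now hands back one NuScenes devkit handle per worker, so each conversion process can query the devkit independently instead of sharing (and pickling) a single instance. A small sketch of how the pieces fit together; the dataroot is hypothetical:

    from scenarionet.converter.nuscenes.utils import get_nuscenes_scenarios

    num_workers = 2
    # scenarios: list of scene records; nuscs: one NuScenes instance per worker
    scenarios, nuscs = get_nuscenes_scenarios("/data/nuscenes", "v1.0-mini", num_workers)
    assert len(nuscs) == num_workers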
@@ -1,32 +1,29 @@
-from metadrive.envs.metadrive_env import MetaDriveEnv
 import logging

 from metadrive.scenario.scenario_description import ScenarioDescription as SD


 def convert_pg_scenario(scenario_index, version, env):
     """
     Simulate to collect PG Scenarios
-    :param scenario_index: the index to export
+    :param scenario_index: the index to export [env.start_seed, env.start_seed + num_scenarios_per_worker]
     :param version: place holder
     :param env: metadrive env instance
     """
+    #
+    # if (scenario_index - env.config["start_seed"]) % reset_freq == 0:
+    #     # for avoiding memory leak
+    #     env.close()
+
     logging.disable(logging.INFO)
     policy = lambda x: [0, 1]  # placeholder
-    scenarios, done_info = env.export_scenarios(policy, scenario_index=[scenario_index], to_dict=False)
+    scenarios, done_info = env.export_scenarios(
+        policy, scenario_index=[scenario_index], max_episode_length=500, suppress_warning=True, to_dict=False
+    )
     scenario = scenarios[scenario_index]
     assert scenario[SD.VERSION] == version, "Data version mismatch"
     return scenario


-def get_pg_scenarios(num_scenarios, policy, start_seed=0):
-    env = MetaDriveEnv(
-        dict(
-            start_seed=start_seed,
-            num_scenarios=num_scenarios,
-            traffic_density=0.2,
-            agent_policy=policy,
-            crash_vehicle_done=False,
-            map=2
-        )
-    )
-    return [i for i in range(num_scenarios)], env
+def get_pg_scenarios(start_index, num_scenarios):
+    return [i for i in range(start_index, start_index + num_scenarios)]
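convert_pg_scenario still expects an env keyword, but get_pg_scenarios now only hands back seed indices; the MetaDrive environment is created later inside each worker (see the write_to_directory_single_worker hunk below), which is part of the memory-leak fix. A sketch of the index bookkeeping, assuming 64 scenarios per sub-dataset:

    from scenarionet.converter.pg.utils import get_pg_scenarios

    # worker 0 converts seeds 0..63, worker 1 converts seeds 64..127, and so on
    indices_worker_0 = get_pg_scenarios(start_index=0, num_scenarios=64)
    indices_worker_1 = get_pg_scenarios(start_index=64, num_scenarios=64)
    assert indices_worker_0[-1] + 1 == indices_worker_1[0]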
@@ -10,11 +10,15 @@ import shutil
 from functools import partial

 import numpy as np
+import psutil
 import tqdm
+from metadrive.envs.metadrive_env import MetaDriveEnv
+from metadrive.policy.idm_policy import IDMPolicy
 from metadrive.scenario import ScenarioDescription as SD

-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset
 from scenarionet.common_utils import save_summary_anda_mapping
+from scenarionet.converter.pg.utils import convert_pg_scenario

 logger = logging.getLogger(__file__)

@@ -61,33 +65,34 @@ def contains_explicit_return(f):


 def write_to_directory(
-    convert_func,
-    scenarios,
-    output_path,
-    dataset_version,
-    dataset_name,
-    force_overwrite=False,
-    num_workers=8,
-    **kwargs
+    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
 ):
     # make sure dir not exist
+    kwargs_for_workers = [{} for _ in range(num_workers)]
+    for key, value in kwargs.items():
+        for i in range(num_workers):
+            kwargs_for_workers[i][key] = value[i]
+
     save_path = copy.deepcopy(output_path)
     if os.path.exists(output_path):
-        if not force_overwrite:
+        if not overwrite:
             raise ValueError(
                 "Directory {} already exists! Abort. "
-                "\n Try setting force_overwrite=True or adding --overwrite".format(output_path)
+                "\n Try setting overwrite=True or adding --overwrite".format(output_path)
             )
+        else:
+            shutil.rmtree(output_path)
+    os.makedirs(save_path, exist_ok=False)
+
     basename = os.path.basename(output_path)
-    dir = os.path.dirname(output_path)
+    # dir = os.path.dirname(output_path)
     for i in range(num_workers):
-        output_path = os.path.join(dir, "{}_{}".format(basename, str(i)))
-        if os.path.exists(output_path):
-            if not force_overwrite:
+        subdir = os.path.join(output_path, "{}_{}".format(basename, str(i)))
+        if os.path.exists(subdir):
+            if not overwrite:
                 raise ValueError(
                     "Directory {} already exists! Abort. "
-                    "\n Try setting force_overwrite=True or adding --overwrite".format(output_path)
+                    "\n Try setting overwrite=True or adding --overwrite".format(subdir)
                 )
     # get arguments for workers
     num_files = len(scenarios)
@@ -104,9 +109,9 @@ def write_to_directory(
             end_idx = num_files
         else:
             end_idx = (i + 1) * num_files_each_worker
-        output_path = os.path.join(dir, "{}_{}".format(basename, str(i)))
-        output_pathes.append(output_path)
-        argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs, i, output_path])
+        subdir = os.path.join(output_path, "{}_{}".format(basename, str(i)))
+        output_pathes.append(subdir)
+        argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs_for_workers[i], i, subdir])

     # prefill arguments
     func = partial(
@@ -114,26 +119,24 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite
+        overwrite=overwrite
     )

     # Run, workers and process result from worker
-    with multiprocessing.Pool(num_workers) as p:
-        all_result = list(p.imap(func, argument_list))
-    combine_multiple_dataset(
-        save_path, *output_pathes, force_overwrite=force_overwrite, try_generate_missing_file=False
-    )
-    return all_result
+    with multiprocessing.Pool(num_workers, maxtasksperchild=10) as p:
+        ret = list(p.imap(func, argument_list))
+        # call ret to block the process
+    combine_dataset(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)


-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, force_overwrite=False):
+def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
-        force_overwrite=force_overwrite,
+        overwrite=overwrite,
         worker_index=args[2],
         **args[1]
     )
@@ -146,7 +149,8 @@ def write_to_directory_single_worker(
     dataset_version,
     dataset_name,
     worker_index=0,
-    force_overwrite=False,
+    overwrite=False,
+    report_memory_freq=None,
     **kwargs
 ):
     """
@@ -169,13 +173,10 @@ def write_to_directory_single_worker(
     # make real save dir
     delay_remove = None
     if os.path.exists(save_path):
-        if force_overwrite:
+        if overwrite:
             delay_remove = save_path
         else:
-            raise ValueError(
-                "Directory already exists! Abort."
-                "\n Try setting force_overwrite=True or using --overwrite"
-            )
+            raise ValueError("Directory already exists! Abort." "\n Try setting overwrite=True or using --overwrite")

     summary_file = SD.DATASET.SUMMARY_FILE
     mapping_file = SD.DATASET.MAPPING_FILE
@@ -185,6 +186,23 @@ def write_to_directory_single_worker(

     summary = {}
     mapping = {}
+
+    # for pg scenario only
+    if convert_func is convert_pg_scenario:
+        env = MetaDriveEnv(
+            dict(
+                start_seed=scenarios[0],
+                num_scenarios=len(scenarios),
+                traffic_density=0.15,
+                agent_policy=IDMPolicy,
+                crash_vehicle_done=False,
+                store_map=False,
+                map=2
+            )
+        )
+        kwargs["env"] = env
+
+    count = 0
     for scenario in tqdm.tqdm(scenarios, desc="Worker Index: {}".format(worker_index)):
         # convert scenario
         sd_scenario = convert_func(scenario, dataset_version, **kwargs)
@@ -217,6 +235,10 @@ def write_to_directory_single_worker(
         with open(p, "wb") as f:
             pickle.dump(sd_scenario, f)

+        if report_memory_freq is not None and (count) % report_memory_freq == 0:
+            print("Current Memory: {}".format(process_memory()))
+        count += 1
+
     # store summary file
     save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, mapping)

@@ -226,4 +248,8 @@ def write_to_directory_single_worker(
         shutil.rmtree(delay_remove)
     os.rename(output_path, save_path)

-    return summary, mapping
+
+def process_memory():
+    process = psutil.Process(os.getpid())
+    mem_info = process.memory_info()
+    return mem_info.rss / 1024 / 1024  # mb
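After this change write_to_directory splits every list-valued keyword argument element-wise across the workers (kwargs_for_workers[i][key] = value[i]), which is how the per-worker NuScenes handles above reach each subprocess, and the PG environment is built inside the worker rather than pickled from the parent. A hedged sketch of the calling convention, using hypothetical paths:

    from scenarionet.converter.nuscenes.utils import get_nuscenes_scenarios, convert_nuscenes_scenario
    from scenarionet.converter.utils import write_to_directory

    num_workers = 2
    scenarios, nuscs = get_nuscenes_scenarios("/data/nuscenes", "v1.0-mini", num_workers)

    write_to_directory(
        convert_func=convert_nuscenes_scenario,
        scenarios=scenarios,
        output_path="/tmp/nuscenes_sn",  # hypothetical output directory
        dataset_version="v1.0-mini",
        dataset_name="nuscenes",
        overwrite=True,
        num_workers=num_workers,
        nuscenes=nuscs,  # one entry per worker; write_to_directory hands entry i to worker i
    )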
@@ -395,7 +395,6 @@ def convert_waymo_scenario(scenario, version):
     md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id)
     md_scenario[SD.METADATA]["dataset"] = "waymo"
     md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id[:id_end]
-    # TODO LQY Can we infer it?
     md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id[id_end + 1:]
     md_scenario[SD.METADATA]["track_length"] = track_length

@@ -1,12 +1,18 @@
+import pkg_resources  # for suppress warning
 import argparse

 from scenarionet.verifier.error import ErrorFile

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--file", "-f", required=True, help="The path of the error file")
-    parser.add_argument("--dataset_path", "-d", required=True, help="The path of the generated dataset")
+    parser.add_argument("--file", "-f", required=True, help="The path of the error file, should be xyz.json")
+    parser.add_argument("--dataset_path", "-d", required=True, help="The path of the newly generated dataset")
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
-    parser.add_argument("--broken", action="store_true", help="Generate dataset containing only broken files")
+    parser.add_argument(
+        "--broken",
+        action="store_true",
+        help="By default, only successful scenarios will be picked to build the new dataset. "
+        "If turn on this flog, it will generate dataset containing only broken scenarios."
+    )
     args = parser.parse_args()
     ErrorFile.generate_dataset(args.file, args.dataset_path, args.overwrite, args.broken)
@@ -1,6 +1,6 @@
+import pkg_resources  # for suppress warning
 import argparse
 import os

 from metadrive.envs.scenario_env import ScenarioEnv
 from metadrive.policy.replay_policy import ReplayEgoCarPolicy
 from metadrive.scenario.utils import get_number_of_scenarios
@@ -39,7 +39,7 @@
         }
     )
     for seed in range(num_scenario if args.scenario_index is not None else 1000000):
-        env.reset(force_seed=seed if args.scenario_index is not None else args.scenario_index)
+        env.reset(force_seed=seed if args.scenario_index is None else args.scenario_index)
         for t in range(10000):
             o, r, d, info = env.step([0, 0])
             if env.config["use_render"]:
@@ -1,22 +0,0 @@
-import argparse
-
-from scenarionet.builder.filters import ScenarioFilter
-from scenarionet.builder.utils import combine_multiple_dataset
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--to", required=True, help="Dataset path, a directory")
-    parser.add_argument('--from_datasets', required=True, nargs='+', default=[])
-    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
-    parser.add_argument("--sdc_moving_dist_min", default=0, help="Selecting case with sdc_moving_dist > this value")
-    args = parser.parse_args()
-    filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=20, condition="greater")]
-
-    if len(args.from_datasets) != 0:
-        combine_multiple_dataset(
-            args.to,
-            *args.from_datasets,
-            force_overwrite=args.overwrite,
-            try_generate_missing_file=True,
-            filters=filters
-        )
@@ -1,13 +0,0 @@
-import argparse
-
-from scenarionet.verifier.utils import verify_loading_into_metadrive, set_random_drop
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_path", required=True, help="Dataset path, a directory")
-    parser.add_argument("--result_save_dir", required=True, help="Dataset path, a directory")
-    parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
-    parser.add_argument("--random_drop", action="store_true", help="Randomly make some scenarios fail. for test only!")
-    args = parser.parse_args()
-    set_random_drop(args.random_drop)
-    verify_loading_into_metadrive(args.dataset_path, args.result_save_dir, num_workers=args.num_workers)
@@ -1,23 +1,28 @@
 import os

-from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH
-from scenarionet.builder.utils import combine_multiple_dataset
-from scenarionet.verifier.utils import verify_loading_into_metadrive
+from metadrive.scenario.scenario_description import ScenarioDescription as SD
+
+from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH
+from scenarionet.builder.utils import combine_dataset
+from scenarionet.common_utils import read_dataset_summary, read_scenario


 def _test_combine_dataset():
     dataset_paths = [
         os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
+        os.path.join(SCENARIONET_DATASET_PATH, "nuscenes", "nuscenes_0"),
         os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
         os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
         os.path.join(SCENARIONET_DATASET_PATH, "pg")
     ]

-    combine_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
-    combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
-    # os.makedirs("verify_results", exist_ok=True)
-    # verify_loading_into_metadrive(combine_path, "verify_results")
-    # assert success
+    combine_path = os.path.join(TMP_PATH, "combine")
+    combine_dataset(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
+    summary, _, mapping = read_dataset_summary(combine_path)
+    for scenario in summary:
+        sd = read_scenario(combine_path, mapping, scenario)
+        SD.sanity_check(sd)
+    print("Test pass")


 if __name__ == '__main__':
@@ -3,9 +3,9 @@ import os.path

 from metadrive.type import MetaDriveType

-from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH
+from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH
 from scenarionet.builder.filters import ScenarioFilter
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset


 def test_filter_dataset():
@@ -17,16 +17,17 @@ def test_filter_dataset():
     dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo"))
     dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg"))

-    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
+    output_path = os.path.join(TMP_PATH, "combine")

     # ========================= test 1 =========================
     # nuscenes data has no light
     # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
     sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition="greater")
-    summary, mapping = combine_multiple_dataset(
+    summary, mapping = combine_dataset(
         output_path,
         *dataset_paths,
-        force_overwrite=True,
+        exist_ok=True,
+        overwrite=True,
         try_generate_missing_file=True,
         filters=[sdc_driving_condition]
     )
@@ -38,8 +39,13 @@ def test_filter_dataset():
         ScenarioFilter.object_number, number_threshold=50, object_type=MetaDriveType.PEDESTRIAN, condition="greater"
     )

-    summary, mapping = combine_multiple_dataset(
-        output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True, filters=[num_condition]
+    summary, mapping = combine_dataset(
+        output_path,
+        *dataset_paths,
+        exist_ok=True,
+        overwrite=True,
+        try_generate_missing_file=True,
+        filters=[num_condition]
     )
     assert len(summary) > 0

@@ -47,8 +53,13 @@ def test_filter_dataset():

     traffic_light = ScenarioFilter.make(ScenarioFilter.has_traffic_light)

-    summary, mapping = combine_multiple_dataset(
-        output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True, filters=[traffic_light]
+    summary, mapping = combine_dataset(
+        output_path,
+        *dataset_paths,
+        exist_ok=True,
+        overwrite=True,
+        try_generate_missing_file=True,
+        filters=[traffic_light]
     )
     assert len(summary) > 0

@@ -5,13 +5,13 @@ import os.path
 from metadrive.scenario.scenario_description import ScenarioDescription as SD

 from scenarionet import SCENARIONET_DATASET_PATH
-from scenarionet import SCENARIONET_PACKAGE_PATH
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
+from scenarionet.builder.utils import combine_dataset
 from scenarionet.common_utils import read_dataset_summary, read_scenario
 from scenarionet.common_utils import recursive_equal
 from scenarionet.verifier.error import ErrorFile
 from scenarionet.verifier.utils import set_random_drop
-from scenarionet.verifier.utils import verify_loading_into_metadrive
+from scenarionet.verifier.utils import verify_dataset


 def test_generate_from_error():
@@ -24,27 +24,27 @@ def test_generate_from_error():
         os.path.join(SCENARIONET_DATASET_PATH, "pg")
     ]

-    dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
-    combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
+    dataset_path = os.path.join(TMP_PATH, "combine")
+    combine_dataset(dataset_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)

     summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
     for scenario_file in sorted_scenarios:
         read_scenario(dataset_path, mapping, scenario_file)
-    success, logs = verify_loading_into_metadrive(
-        dataset_path, result_save_dir="../test_dataset", steps_to_run=1000, num_workers=16
+    success, logs = verify_dataset(
+        dataset_path, result_save_dir="../test_dataset", steps_to_run=1000, num_workers=16, overwrite=True
     )
     set_random_drop(False)
     # get error file
     file_name = ErrorFile.get_error_file_name(dataset_path)
     error_file_path = os.path.join("../test_dataset", file_name)
     # regenerate
-    pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_scenarios")
-    fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios")
+    pass_dataset = os.path.join(TMP_PATH, "passed_scenarios")
+    fail_dataset = os.path.join(TMP_PATH, "failed_scenarios")
     pass_summary, pass_mapping = ErrorFile.generate_dataset(
-        error_file_path, pass_dataset, force_overwrite=True, broken_scenario=False
+        error_file_path, pass_dataset, overwrite=True, broken_scenario=False
     )
     fail_summary, fail_mapping = ErrorFile.generate_dataset(
-        error_file_path, fail_dataset, force_overwrite=True, broken_scenario=True
+        error_file_path, fail_dataset, overwrite=True, broken_scenario=True
     )

     # assert
scenarionet/tests/local_test/_test_memory_leak_pg.py (new file)
@@ -0,0 +1,26 @@
+import os
+
+from metadrive.constants import DATA_VERSION
+
+from scenarionet import TMP_PATH
+from scenarionet.converter.pg.utils import convert_pg_scenario, get_pg_scenarios
+from scenarionet.converter.utils import write_to_directory_single_worker
+
+
+def _test_pg_memory_leak():
+    path = os.path.join(TMP_PATH, "test_memory_leak")
+    scenario_indices = get_pg_scenarios(0, 1000)
+    write_to_directory_single_worker(
+        convert_pg_scenario,
+        scenario_indices,
+        path,
+        DATA_VERSION,
+        "pg",
+        worker_index=0,
+        report_memory_freq=10,
+        overwrite=True
+    )
+
+
+if __name__ == '__main__':
+    _test_pg_memory_leak()
scenarionet/tests/local_test/_test_pg_multiprocess.py (new file)
@@ -0,0 +1,39 @@
+import os
+
+from metadrive.envs.metadrive_env import MetaDriveEnv
+from metadrive.policy.idm_policy import IDMPolicy
+from metadrive.scenario.utils import get_number_of_scenarios, assert_scenario_equal
+
+from scenarionet import SCENARIONET_DATASET_PATH
+from scenarionet.common_utils import read_dataset_summary, read_scenario
+
+if __name__ == '__main__':
+
+    dataset_path = os.path.abspath(os.path.join(SCENARIONET_DATASET_PATH, "pg"))
+    start_seed = 0
+    num_scenario = get_number_of_scenarios(dataset_path)
+
+    # load multi process ret
+    summary, s_list, mapping = read_dataset_summary(dataset_path)
+    to_compare = dict()
+    for k, file in enumerate(s_list[:num_scenario]):
+        to_compare[k + start_seed] = read_scenario(dataset_path, mapping, file).to_dict()
+
+    # generate single process ret
+    env = MetaDriveEnv(
+        dict(
+            start_seed=start_seed,
+            num_scenarios=num_scenario,
+            traffic_density=0.15,
+            agent_policy=IDMPolicy,
+            crash_vehicle_done=False,
+            map=2
+        )
+    )
+    policy = lambda x: [0, 1]  # placeholder
+    ret = env.export_scenarios(
+        policy, [i for i in range(start_seed, start_seed + num_scenario)], return_done_info=False
+    )
+
+    # for i in tqdm.tqdm(range(num_scenario), desc="Assert"):
+    assert_scenario_equal(ret, to_compare, only_compare_sdc=True)
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash

-python ../../scripts/combine_dataset.py --to ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
-python ../../scripts/verify_dataset.py --dataset_path ../tmp/test_combine_dataset --result_save_dir ../tmp/test_combine_dataset --random_drop --num_workers=16
-python ../../scripts/generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_pass
-python ../../scripts/generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_fail --broken
+python ../../combine_dataset.py --overwrite --exist_ok --dataset_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
+python ../../verify_simulation.py --overwrite --dataset_path ../tmp/test_combine_dataset --result_save_dir ../tmp/test_combine_dataset --random_drop --num_workers=16
+python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_pass
+python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_fail --broken
scenarionet/tests/local_test/convert_large_pg.sh (new file)
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+
+bash ../../convert_pg_large.sh ../tmp/pg_large 4 64 8 true
scenarionet/tests/local_test/convert_pg_large.sh (new file)
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Author: GPT-4
+# Usage: ./script_name.sh /path/to/datasets 10 5000 8 true
+
+# check if five arguments are passed
+if [ $# -ne 5 ]; then
+  echo "Usage: $0 <dataset_path> <num_sub_dataset> <num_scenarios_sub_dataset> <num_workers> <overwrite>"
+  exit 1
+fi
+
+# get the number of scenarios, datasets, dataset path, number of workers, and overwrite from command line arguments
+dataset_path=$1
+num_sub_dataset=$2
+num_scenarios_sub_dataset=$3
+num_workers=$4
+overwrite=$5
+
+# initialize start_index
+start_index=0
+
+# run the conversion script in a loop
+for i in $(seq 1 $num_sub_dataset)
+do
+  sub_dataset_path="${dataset_path}/pg_$((i-1))"
+  if [ "$overwrite" = true ]; then
+    python -m scenarionet.scripts.convert_pg -n pg -d $sub_dataset_path --start_index=$start_index --num_workers=$num_workers --num_scenarios=$num_scenarios_sub_dataset --overwrite
+  else
+    python -m scenarionet.scripts.convert_pg -n pg -d $sub_dataset_path --start_index=$start_index --num_workers=$num_workers --num_scenarios=$num_scenarios_sub_dataset
+  fi
+  start_index=$((start_index + num_scenarios_sub_dataset))
+done
+
+# combine the datasets
+if [ "$overwrite" = true ]; then
+  python -m scenarionet.scripts.combine_dataset --dataset_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok
+else
+  python -m scenarionet.scripts.combine_dataset --dataset_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok
+fi
+
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash

-nohup python ../../scripts/convert_nuplan.py --overwrite > nuplan.log 2>&1 &
-nohup python ../../scripts/convert_nuscenes.py --overwrite > nuscenes.log 2>&1 &
-nohup python ../../scripts/convert_pg.py --overwrite > pg.log 2>&1 &
-nohup python ../../scripts/convert_waymo.py --overwrite > waymo.log 2>&1 &
+nohup python ../../convert_nuplan.py --overwrite --test > nuplan.log 2>&1 &
+nohup python ../../convert_nuscenes.py --overwrite > nuscenes.log 2>&1 &
+nohup python ../../convert_pg.py --overwrite > pg.log 2>&1 &
+nohup python ../../convert_waymo.py --overwrite > waymo.log 2>&1 &
@@ -12,10 +12,10 @@ if __name__ == "__main__":
     dataset_name = "nuscenes"
     output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
     version = 'v1.0-mini'
-    force_overwrite = True
+    overwrite = True

     dataroot = '/home/shady/data/nuscenes'
-    scenarios, nusc = get_nuscenes_scenarios(dataroot, version)
+    scenarios, nuscs = get_nuscenes_scenarios(dataroot, version, 2)

     for i in range(5):
         write_to_directory(
@@ -24,6 +24,7 @@ if __name__ == "__main__":
             output_path=output_path + "_{}".format(i),
             dataset_version=version,
             dataset_name=dataset_name,
-            force_overwrite=force_overwrite,
-            nuscenes=nusc
+            overwrite=overwrite,
+            num_workers=2,
+            nuscenes=nuscs
         )
@@ -4,8 +4,8 @@ from metadrive.envs.scenario_env import ScenarioEnv
 from metadrive.policy.replay_policy import ReplayEgoCarPolicy
 from metadrive.scenario.utils import get_number_of_scenarios

-from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH
+from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset

 if __name__ == '__main__':
     dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")]
@@ -13,8 +13,8 @@ if __name__ == '__main__':
     dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo"))
     dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg"))

-    combine_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
+    combine_path = os.path.join(TMP_PATH, "combine")
-    combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
+    combine_dataset(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)

     env = ScenarioEnv(
         {
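Editor's note: the ScenarioEnv( ... call above is truncated in this hunk; it replays the combined dataset. A minimal, self-contained sketch of that replay loop follows. The config keys are assumptions based on common MetaDrive usage, only reset(force_seed=...) and step([0, 0]) appear verbatim in this diff, so treat it as illustrative rather than the committed code.

# Illustrative replay loop, not the committed file; config keys are assumed.
import os

from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.policy.replay_policy import ReplayEgoCarPolicy

from scenarionet import TMP_PATH

combine_path = os.path.join(TMP_PATH, "combine")
env = ScenarioEnv(
    {
        "agent_policy": ReplayEgoCarPolicy,  # assumed key: replay the recorded ego trajectory
        "data_directory": combine_path,      # assumed key: where the combined dataset lives
        "num_scenarios": 10,                 # assumed key: how many scenarios to load
    }
)
try:
    for seed in range(10):
        env.reset(force_seed=seed)
        for _ in range(1000):
            o, r, d, info = env.step([0, 0])
            if d:
                break
finally:
    env.close()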
@@ -1,10 +1,10 @@
 import os
 import os.path

-from scenarionet import SCENARIONET_PACKAGE_PATH
+from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset
 from scenarionet.common_utils import read_dataset_summary, read_scenario
-from scenarionet.verifier.utils import verify_loading_into_metadrive
+from scenarionet.verifier.utils import verify_dataset


 def test_combine_multiple_dataset():
@@ -13,15 +13,15 @@ def test_combine_multiple_dataset():
     test_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset")
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]

-    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
+    output_path = os.path.join(TMP_PATH, "combine")
-    combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
+    combine_dataset(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
     dataset_paths.append(output_path)
     for dataset_path in dataset_paths:
         summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
         for scenario_file in sorted_scenarios:
             read_scenario(dataset_path, mapping, scenario_file)
-        success, result = verify_loading_into_metadrive(
+        success, result = verify_dataset(
-            dataset_path, result_save_dir=test_dataset_path, steps_to_run=1000, num_workers=4
+            dataset_path, result_save_dir=test_dataset_path, steps_to_run=1000, num_workers=4, overwrite=True
         )
         assert success

@@ -3,9 +3,9 @@ import os.path

 from metadrive.type import MetaDriveType

-from scenarionet import SCENARIONET_PACKAGE_PATH
+from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
 from scenarionet.builder.filters import ScenarioFilter
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset


 def test_filter_dataset():
@@ -13,17 +13,18 @@ def test_filter_dataset():
     original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]

-    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
+    output_path = os.path.join(TMP_PATH, "combine")

     # ========================= test 1 =========================
     # nuscenes data has no light
     # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
     sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition="smaller")
     answer = ['sd_nuscenes_v1.0-mini_scene-0553.pkl', '0.pkl', 'sd_nuscenes_v1.0-mini_scene-1100.pkl']
-    summary, mapping = combine_multiple_dataset(
+    summary, mapping = combine_dataset(
         output_path,
         *dataset_paths,
-        force_overwrite=True,
+        exist_ok=True,
+        overwrite=True,
         try_generate_missing_file=True,
         filters=[sdc_driving_condition]
     )
@@ -37,10 +38,11 @@ def test_filter_dataset():
     assert in_, summary.keys()

     sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=5, condition="greater")
-    summary, mapping = combine_multiple_dataset(
+    summary, mapping = combine_dataset(
         output_path,
         *dataset_paths,
-        force_overwrite=True,
+        exist_ok=True,
+        overwrite=True,
         try_generate_missing_file=True,
         filters=[sdc_driving_condition]
     )
@@ -53,8 +55,13 @@ def test_filter_dataset():
     )

     answer = ['sd_nuscenes_v1.0-mini_scene-0061.pkl', 'sd_nuscenes_v1.0-mini_scene-1094.pkl']
-    summary, mapping = combine_multiple_dataset(
+    summary, mapping = combine_dataset(
-        output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True, filters=[num_condition]
+        output_path,
+        *dataset_paths,
+        exist_ok=True,
+        overwrite=True,
+        try_generate_missing_file=True,
+        filters=[num_condition]
     )
     assert len(answer) == len(summary)
     for a in answer:
@@ -62,8 +69,13 @@ def test_filter_dataset():

     num_condition = ScenarioFilter.make(ScenarioFilter.object_number, number_threshold=50, condition="greater")

-    summary, mapping = combine_multiple_dataset(
+    summary, mapping = combine_dataset(
-        output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True, filters=[num_condition]
+        output_path,
+        *dataset_paths,
+        exist_ok=True,
+        overwrite=True,
+        try_generate_missing_file=True,
+        filters=[num_condition]
     )
     assert len(summary) > 0

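Editor's note: the filters argument is the least self-explanatory part of the renamed API, so a minimal usage sketch follows. The output and source paths are hypothetical, and the filter construction simply mirrors the test above.

# Hypothetical example: copy one dataset into a filtered one, keeping only
# scenarios whose self-driving car moves less than 30 m (mirrors the test above).
from scenarionet.builder.filters import ScenarioFilter
from scenarionet.builder.utils import combine_dataset

short_trip_filter = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition="smaller")

summary, mapping = combine_dataset(
    "/tmp/filtered_dataset",      # hypothetical output directory
    "/tmp/source_dataset",        # hypothetical input dataset containing dataset_summary.pkl
    exist_ok=True,
    overwrite=True,
    try_generate_missing_file=True,
    filters=[short_trip_filter]
)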
@@ -4,41 +4,40 @@ import os.path

 from metadrive.scenario.scenario_description import ScenarioDescription as SD

-from scenarionet import SCENARIONET_PACKAGE_PATH
+from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
-from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.builder.utils import combine_dataset
 from scenarionet.common_utils import read_dataset_summary, read_scenario
 from scenarionet.common_utils import recursive_equal
 from scenarionet.verifier.error import ErrorFile
-from scenarionet.verifier.utils import verify_loading_into_metadrive, set_random_drop
+from scenarionet.verifier.utils import verify_dataset, set_random_drop


 def test_generate_from_error():
     set_random_drop(True)
     dataset_name = "nuscenes"
     original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
-    test_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset")
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
-    dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "combine")
+    dataset_path = os.path.join(TMP_PATH, "combine")
-    combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
+    combine_dataset(dataset_path, *dataset_paths, exist_ok=True, try_generate_missing_file=True, overwrite=True)

     summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
     for scenario_file in sorted_scenarios:
         read_scenario(dataset_path, mapping, scenario_file)
-    success, logs = verify_loading_into_metadrive(
+    success, logs = verify_dataset(
-        dataset_path, result_save_dir=test_dataset_path, steps_to_run=1000, num_workers=3
+        dataset_path, result_save_dir=TMP_PATH, steps_to_run=1000, num_workers=3, overwrite=True
     )
     set_random_drop(False)
     # get error file
     file_name = ErrorFile.get_error_file_name(dataset_path)
-    error_file_path = os.path.join(test_dataset_path, file_name)
+    error_file_path = os.path.join(TMP_PATH, file_name)
     # regenerate
-    pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_senarios")
+    pass_dataset = os.path.join(TMP_PATH, "passed_senarios")
-    fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios")
+    fail_dataset = os.path.join(TMP_PATH, "failed_scenarios")
     pass_summary, pass_mapping = ErrorFile.generate_dataset(
-        error_file_path, pass_dataset, force_overwrite=True, broken_scenario=False
+        error_file_path, pass_dataset, overwrite=True, broken_scenario=False
     )
     fail_summary, fail_mapping = ErrorFile.generate_dataset(
-        error_file_path, fail_dataset, force_overwrite=True, broken_scenario=True
+        error_file_path, fail_dataset, overwrite=True, broken_scenario=True
     )

     # assert
scenarionet/tests/test_verify_completeness.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+import os
+import os.path
+
+from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
+from scenarionet.builder.utils import combine_dataset
+from scenarionet.common_utils import read_dataset_summary, read_scenario
+from scenarionet.verifier.utils import verify_dataset, set_random_drop
+
+
+def test_verify_completeness():
+    dataset_name = "nuscenes"
+    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
+    dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
+
+    output_path = os.path.join(TMP_PATH, "combine")
+    combine_dataset(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
+    dataset_path = output_path
+    summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
+    for scenario_file in sorted_scenarios:
+        read_scenario(dataset_path, mapping, scenario_file)
+    set_random_drop(True)
+    success, result = verify_dataset(
+        dataset_path, result_save_dir=TMP_PATH, steps_to_run=0, num_workers=4, overwrite=True
+    )
+    assert not success
+
+    set_random_drop(False)
+    success, result = verify_dataset(
+        dataset_path, result_save_dir=TMP_PATH, steps_to_run=0, num_workers=4, overwrite=True
+    )
+    assert success
+
+
+if __name__ == '__main__':
+    test_verify_completeness()
@@ -51,24 +51,23 @@ class ErrorFile:
         return path

     @classmethod
-    def generate_dataset(cls, error_file_path, new_dataset_path, force_overwrite=False, broken_scenario=False):
+    def generate_dataset(cls, error_file_path, new_dataset_path, overwrite=False, broken_scenario=False):
         """
         Generate a new dataset containing all broken scenarios or all good scenarios
         :param error_file_path: error file path
         :param new_dataset_path: a directory where you want to store your data
-        :param force_overwrite: if new_dataset_path exists, whether to overwrite
+        :param overwrite: if new_dataset_path exists, whether to overwrite
         :param broken_scenario: generate broken scenarios. You can generate such a broken scenarios for debugging
         :return: dataset summary, dataset mapping
         """
-        # TODO Add test!
         new_dataset_path = os.path.abspath(new_dataset_path)
         if os.path.exists(new_dataset_path):
-            if force_overwrite:
+            if overwrite:
                 shutil.rmtree(new_dataset_path)
             else:
                 raise ValueError(
                     "Directory: {} already exists! "
-                    "Set force_overwrite=True to overwrite".format(new_dataset_path)
+                    "Set overwrite=True to overwrite".format(new_dataset_path)
                 )
         os.makedirs(new_dataset_path, exist_ok=False)

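Editor's note: the docstring above describes the regeneration workflow, so a small sketch of how a caller uses it after the force_overwrite to overwrite rename follows. The error-file path is hypothetical; it would be produced by an earlier verify_dataset run.

# Sketch: rebuild the scenarios that failed (or passed) verification.
# error_file_path is the hypothetical output of a previous verify_dataset run.
from scenarionet.verifier.error import ErrorFile

error_file_path = "./error_scenarios_for_combine.json"   # hypothetical file name
failed_dir = "./failed_scenarios"
passed_dir = "./passed_scenarios"

# all scenarios that raised an error, useful for debugging
fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, failed_dir, overwrite=True, broken_scenario=True)
# all scenarios that loaded cleanly
pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, passed_dir, overwrite=True, broken_scenario=False)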
@@ -4,12 +4,14 @@ import os

 import numpy as np

+from scenarionet.common_utils import read_scenario, read_dataset_summary
 from scenarionet.verifier.error import ErrorDescription as ED
 from scenarionet.verifier.error import ErrorFile as EF

 logger = logging.getLogger(__name__)
 import tqdm
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.scenario.scenario_description import ScenarioDescription as SD
 from metadrive.policy.replay_policy import ReplayEgoCarPolicy
 from metadrive.scenario.utils import get_number_of_scenarios
 from functools import partial
@@ -23,9 +25,16 @@ def set_random_drop(drop):
     RANDOM_DROP = drop


-def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=1000, num_workers=8):
+def verify_dataset(dataset_path, result_save_dir, overwrite=False, num_workers=8, steps_to_run=1000):
+    global RANDOM_DROP
     assert os.path.isdir(result_save_dir), "result_save_dir must be a dir, get {}".format(result_save_dir)
     os.makedirs(result_save_dir, exist_ok=True)
+    error_file_name = EF.get_error_file_name(dataset_path)
+    if os.path.exists(os.path.join(result_save_dir, error_file_name)) and not overwrite:
+        raise FileExistsError(
+            "An error_file already exists in result_save_directory. "
+            "Setting overwrite=True to cancel this alert"
+        )
     num_scenario = get_number_of_scenarios(dataset_path)
     if num_scenario < num_workers:
         # single process
@@ -34,7 +43,7 @@ def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=10

     # prepare arguments
     argument_list = []
-    func = partial(loading_wrapper, dataset_path=dataset_path, steps_to_run=steps_to_run)
+    func = partial(loading_wrapper, dataset_path=dataset_path, steps_to_run=steps_to_run, random_drop=RANDOM_DROP)

     num_scenario_each_worker = int(num_scenario // num_workers)
     for i in range(num_workers):
@@ -64,14 +73,35 @@ def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=10
     return success, errors


-def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, steps_to_run, metadrive_config=None):
+def loading_into_metadrive(
-    global RANDOM_DROP
+    start_scenario_index, num_scenario, dataset_path, steps_to_run, metadrive_config=None, random_drop=False
+):
     logger.info(
         "================ Begin Scenario Loading Verification for scenario {}-{} ================ \n".format(
             start_scenario_index, num_scenario + start_scenario_index
         )
     )
     success = True
+    error_msgs = []
+
+    if steps_to_run == 0:
+        summary, scenarios, mapping = read_dataset_summary(dataset_path)
+        index_count = -1
+        for file_name in tqdm.tqdm(scenarios):
+            index_count += 1
+            try:
+                scenario = read_scenario(dataset_path, mapping, file_name)
+                SD.sanity_check(scenario)
+                if random_drop and np.random.rand() < 0.5:
+                    raise ValueError("Random Drop")
+            except Exception as e:
+                file_path = os.path.join(dataset_path, mapping[file_name], file_name)
+                error_msg = ED.make(index_count, file_path, file_name, str(e))
+                error_msgs.append(error_msg)
+                success = False
+                # proceed to next scenario
+                continue
+    else:
         metadrive_config = metadrive_config or {}
         metadrive_config.update(
             {
@@ -85,13 +115,12 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
         )
         env = ScenarioEnv(metadrive_config)
         logging.disable(logging.INFO)
-        error_msgs = []
         desc = "Scenarios: {}-{}".format(start_scenario_index, start_scenario_index + num_scenario)
         for scenario_index in tqdm.tqdm(range(start_scenario_index, start_scenario_index + num_scenario), desc=desc):
             try:
                 env.reset(force_seed=scenario_index)
                 arrive = False
-                if RANDOM_DROP and np.random.rand() < 0.5:
+                if random_drop and np.random.rand() < 0.5:
                     raise ValueError("Random Drop")
                 for _ in range(steps_to_run):
                     o, r, d, info = env.step([0, 0])
@@ -111,6 +140,8 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
     return success, error_msgs


-def loading_wrapper(arglist, dataset_path, steps_to_run):
+def loading_wrapper(arglist, dataset_path, steps_to_run, random_drop):
     assert len(arglist) == 2, "Too much arguments!"
-    return loading_into_metadrive(arglist[0], arglist[1], dataset_path=dataset_path, steps_to_run=steps_to_run)
+    return loading_into_metadrive(
+        arglist[0], arglist[1], dataset_path=dataset_path, steps_to_run=steps_to_run, random_drop=random_drop
+    )
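Editor's note: in short, the renamed verify_dataset now covers both checks: steps_to_run=0 only re-reads and sanity-checks every scenario file, while a positive steps_to_run also replays them in MetaDrive. A minimal sketch of both calls, with a hypothetical dataset path, is below.

# Minimal sketch; the dataset path is hypothetical.
from scenarionet.verifier.utils import verify_dataset

dataset_path = "/data/combined_dataset"   # hypothetical combined dataset directory

# completeness check: read every scenario and run ScenarioDescription.sanity_check
ok, errors = verify_dataset(dataset_path, result_save_dir="./", overwrite=True, num_workers=8, steps_to_run=0)

# simulation check: additionally replay each scenario for up to 1000 steps in MetaDrive
ok, errors = verify_dataset(dataset_path, result_save_dir="./", overwrite=True, num_workers=8, steps_to_run=1000)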
scenarionet/verify_completeness.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import argparse
+
+from scenarionet.verifier.utils import verify_dataset, set_random_drop
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl"
+    )
+    parser.add_argument("--result_save_dir", default="./", help="Where to save the error file")
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="If an error file already exists in result_save_dir, "
+        "whether to overwrite it"
+    )
+    parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
+    parser.add_argument("--random_drop", action="store_true", help="Randomly make some scenarios fail. for test only!")
+    args = parser.parse_args()
+    set_random_drop(args.random_drop)
+    verify_dataset(
+        args.dataset_path, args.result_save_dir, overwrite=args.overwrite, num_workers=args.num_workers, steps_to_run=0
+    )
scenarionet/verify_simulation.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import pkg_resources  # for suppress warning
+import argparse
+from scenarionet.verifier.utils import verify_dataset, set_random_drop
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl"
+    )
+    parser.add_argument("--result_save_dir", default="./", help="Where to save the error file")
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="If an error file already exists in result_save_dir, "
+        "whether to overwrite it"
+    )
+    parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
+    parser.add_argument("--random_drop", action="store_true", help="Randomly make some scenarios fail. for test only!")
+    args = parser.parse_args()
+    set_random_drop(args.random_drop)
+    verify_dataset(args.dataset_path, args.result_save_dir, overwrite=args.overwrite, num_workers=args.num_workers)