chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# appended to the __init__.py in the onnxruntime module's 'tools' folder from /tools/python/util/__init__append.py
|
||||
import importlib.util
|
||||
|
||||
# Detect whether PyTorch is available without actually importing it;
# find_spec returns a ModuleSpec (truthy) when the module can be found, else None.
have_torch = importlib.util.find_spec("torch")
if have_torch:
    # Only expose the PyTorch export helpers when torch is installed.
    from .pytorch_export_helpers import infer_input_info  # noqa: F401
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
# need this before the mobile helper imports for some reason
|
||||
logging.basicConfig(format="%(levelname)s: %(message)s")
|
||||
|
||||
from .mobile_helpers import usability_checker # noqa: E402
|
||||
|
||||
|
||||
def check_usability():
    """Command line entry point that analyzes an ONNX model for mobile usability.

    Parses the model path and logging level from the command line, runs the
    usability checker, and logs a recommendation on whether to try the NNAPI
    (Android) / CoreML (iOS) execution providers versus the CPU EP.
    """
    parser = argparse.ArgumentParser(
        description="""Analyze an ONNX model to determine how well it will work in mobile scenarios.""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # The dispatch below always handled "warning"/"error", but the choices only
    # allowed "debug"/"info", making those branches unreachable. Expose the full
    # set of levels (backward compatible: existing values still accepted).
    parser.add_argument(
        "--log_level", choices=["debug", "info", "warning", "error"], default="info", help="Logging level"
    )
    parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path)

    args = parser.parse_args()
    logger = logging.getLogger("check_usability")

    if args.log_level == "debug":
        logger.setLevel(logging.DEBUG)
    elif args.log_level == "info":
        logger.setLevel(logging.INFO)
    elif args.log_level == "warning":
        logger.setLevel(logging.WARNING)
    else:
        logger.setLevel(logging.ERROR)

    # try_eps is truthy when the checker thinks a compiling EP (NNAPI/CoreML) may help
    try_eps = usability_checker.analyze_model(args.model_path, skip_optimize=False, logger=logger)

    if try_eps:
        logger.info(
            "As NNAPI or CoreML may provide benefits with this model it is recommended to compare the "
            "performance of the model using the NNAPI EP on Android, and the CoreML EP on iOS, "
            "against the performance using the CPU EP."
        )
    else:
        logger.info("For optimal performance the model should be used with the CPU EP. ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point.
    check_usability()
|
||||
@@ -0,0 +1,380 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import enum
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
from .file_utils import files_from_file_or_dir, path_match_suffix_ignore_case
|
||||
from .onnx_model_utils import get_optimization_level
|
||||
from .ort_format_model import create_config_from_models
|
||||
|
||||
|
||||
class OptimizationStyle(enum.Enum):
    """How graph optimizations are applied when producing an ORT format model."""

    # Apply optimizations directly and bake the result into the saved model.
    Fixed = 0
    # Apply basic optimizations directly; save others to be applied at runtime if possible.
    Runtime = 1
|
||||
|
||||
|
||||
def _optimization_suffix(optimization_level_str: str, optimization_style: OptimizationStyle, suffix: str):
    """Build a filename suffix that encodes the optimization level and style.

    The level is only included when it differs from the default "all", and a
    ".with_runtime_opt" marker is inserted for Runtime-style optimization.
    """
    level_part = "" if optimization_level_str == "all" else f".{optimization_level_str}"
    style_part = ".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else ""
    return f"{level_part}{style_part}{suffix}"
|
||||
|
||||
|
||||
def _create_config_file_path(
    model_path_or_dir: pathlib.Path,
    output_dir: pathlib.Path | None,
    optimization_level_str: str,
    optimization_style: OptimizationStyle,
    enable_type_reduction: bool,
):
    """Determine where the required-operators configuration file should be written.

    For a directory of models a single config file goes into the output (or model)
    directory; for a single model the config name is derived from the model filename.
    """
    base_name = "required_operators_and_types" if enable_type_reduction else "required_operators"
    config_name = base_name + _optimization_suffix(optimization_level_str, optimization_style, ".config")

    if model_path_or_dir.is_dir():
        return (output_dir or model_path_or_dir) / config_name

    # single model: replace the model's extension with the config name
    model_config_path = model_path_or_dir.with_suffix(f".{config_name}")
    return output_dir / model_config_path.name if output_dir is not None else model_config_path
|
||||
|
||||
|
||||
def _create_session_options(
    optimization_level: ort.GraphOptimizationLevel,
    output_model_path: pathlib.Path,
    custom_op_library: pathlib.Path,
    session_options_config_entries: dict[str, str],
):
    """Create SessionOptions that write the optimized model to `output_model_path`."""
    session_options = ort.SessionOptions()
    session_options.optimized_model_filepath = str(output_model_path)
    session_options.graph_optimization_level = optimization_level

    if custom_op_library:
        session_options.register_custom_ops_library(str(custom_op_library))

    for config_key, config_value in session_options_config_entries.items():
        session_options.add_session_config_entry(config_key, config_value)

    return session_options
|
||||
|
||||
|
||||
def _convert(
    model_path_or_dir: pathlib.Path,
    output_dir: pathlib.Path | None,
    optimization_level_str: str,
    optimization_style: OptimizationStyle,
    custom_op_library: pathlib.Path,
    create_optimized_onnx_model: bool,
    allow_conversion_failures: bool,
    target_platform: str,
    session_options_config_entries: dict[str, str],
) -> list[pathlib.Path]:
    """Convert the ONNX model/s in `model_path_or_dir` to ORT format models.

    For each `.onnx` file found, an ORT format model is written under `output_dir`
    (mirroring the input directory layout). Optionally an optimized ONNX model is
    saved alongside for inspection.

    :return: Paths of the ORT format models that were created.
    :raises ValueError: If no model files are found.
    """
    model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent
    output_dir = output_dir or model_dir

    optimization_level = get_optimization_level(optimization_level_str)

    def is_model_file_to_convert(file_path: pathlib.Path):
        # only process .onnx files
        if not path_match_suffix_ignore_case(file_path, ".onnx"):
            return False
        # ignore any files with an extension of .optimized.onnx which are presumably from previous executions
        # of this script
        if path_match_suffix_ignore_case(file_path, ".optimized.onnx"):
            print(f"Ignoring '{file_path}'")
            return False
        return True

    models = files_from_file_or_dir(model_path_or_dir, is_model_file_to_convert)

    if len(models) == 0:
        raise ValueError(f"No model files were found in '{model_path_or_dir}'")

    providers = ["CPUExecutionProvider"]

    # if the optimization level is greater than or equal to 'layout' we manually exclude the NCHWc transformer.
    # It's not applicable to ARM devices, and creates a device specific model which won't run on all hardware.
    # If someone really really really wants to run it they could manually create an optimized onnx model first,
    # or they could comment out this code.
    optimizer_filter = None
    if (
        (optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_ALL)
        or (optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_LAYOUT)
    ) and target_platform != "amd64":
        optimizer_filter = ["NchwcTransformer"]

    converted_models = []

    for model in models:
        try:
            # preserve the path relative to the model dir so the output mirrors the input layout
            relative_model_path = model.relative_to(model_dir)

            (output_dir / relative_model_path).parent.mkdir(parents=True, exist_ok=True)

            ort_target_path = (output_dir / relative_model_path).with_suffix(
                _optimization_suffix(optimization_level_str, optimization_style, ".ort")
            )

            if create_optimized_onnx_model:
                # Create an ONNX file with the same optimization level that will be used for the ORT format file.
                # This allows the ONNX equivalent of the ORT format model to be easily viewed in Netron.
                # If runtime optimizations are saved in the ORT format model, there may be some difference in the
                # graphs at runtime between the ORT format model and this saved ONNX model.
                optimized_target_path = (output_dir / relative_model_path).with_suffix(
                    _optimization_suffix(optimization_level_str, optimization_style, ".optimized.onnx")
                )
                so = _create_session_options(
                    optimization_level, optimized_target_path, custom_op_library, session_options_config_entries
                )
                if optimization_style == OptimizationStyle.Runtime:
                    # Limit the optimizations to those that can run in a model with runtime optimizations.
                    so.add_session_config_entry("optimization.minimal_build_optimizations", "apply")

                print(f"Saving optimized ONNX model {model} to {optimized_target_path}")
                # creating the session writes the optimized model to so.optimized_model_filepath as a side effect
                _ = ort.InferenceSession(
                    str(model), sess_options=so, providers=providers, disabled_optimizers=optimizer_filter
                )

            # Load ONNX model, optimize, and save to ORT format
            so = _create_session_options(
                optimization_level, ort_target_path, custom_op_library, session_options_config_entries
            )
            so.add_session_config_entry("session.save_model_format", "ORT")
            if optimization_style == OptimizationStyle.Runtime:
                so.add_session_config_entry("optimization.minimal_build_optimizations", "save")

            print(f"Converting optimized ONNX model {model} to ORT format model {ort_target_path}")
            _ = ort.InferenceSession(
                str(model), sess_options=so, providers=providers, disabled_optimizers=optimizer_filter
            )

            converted_models.append(ort_target_path)

            # orig_size = os.path.getsize(onnx_target_path)
            # new_size = os.path.getsize(ort_target_path)
            # print("Serialized {} to {}. Sizes: orig={} new={} diff={} new:old={:.4f}:1.0".format(
            #     onnx_target_path, ort_target_path, orig_size, new_size, new_size - orig_size, new_size / orig_size))
        except Exception as e:
            print(f"Error converting {model}: {e}")
            # by default a single failure aborts the whole conversion
            if not allow_conversion_failures:
                raise

    print(f"Converted {len(converted_models)}/{len(models)} models successfully.")

    return converted_models
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line arguments for the ONNX-to-ORT conversion script.

    :return: Parsed arguments, with `optimization_style` values converted from
             their string names to OptimizationStyle enum values.
    """
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="""Convert the ONNX format model/s in the provided directory to ORT format models.
        All files with a `.onnx` extension will be processed. For each one, an ORT format model will be created in the
        given output directory, if specified, or the same directory.
        A configuration file will also be created containing the list of required operators for all
        converted models. This configuration file should be used as input to the minimal build via the
        `--include_ops_by_config` parameter.
        """,
    )

    parser.add_argument(
        "--output_dir",
        type=pathlib.Path,
        help="Provide an output directory for the converted model/s and configuration file. "
        "If unspecified, the converted ORT format model/s will be in the same directory as the ONNX model/s.",
    )

    parser.add_argument(
        "--optimization_style",
        nargs="+",
        default=[OptimizationStyle.Fixed.name, OptimizationStyle.Runtime.name],
        choices=[e.name for e in OptimizationStyle],
        help="Style of optimization to perform on the ORT format model. "
        "Multiple values may be provided. The conversion will run once for each value. "
        "The general guidance is to use models optimized with "
        f"'{OptimizationStyle.Runtime.name}' style when using NNAPI or CoreML and "
        f"'{OptimizationStyle.Fixed.name}' style otherwise. "
        f"'{OptimizationStyle.Fixed.name}': Run optimizations directly before saving the ORT "
        "format model. This bakes in any platform-specific optimizations. "
        f"'{OptimizationStyle.Runtime.name}': Run basic optimizations directly and save certain "
        "other optimizations to be applied at runtime if possible. This is useful when using a "
        "compiling EP like NNAPI or CoreML that may run an unknown (at model conversion time) "
        "number of nodes. The saved optimizations can further optimize nodes not assigned to the "
        "compiling EP at runtime.",
    )

    parser.add_argument(
        "--enable_type_reduction",
        action="store_true",
        help="Add operator specific type information to the configuration file to potentially reduce "
        "the types supported by individual operator implementations.",
    )

    parser.add_argument(
        "--custom_op_library",
        type=pathlib.Path,
        default=None,
        help="Provide path to shared library containing custom operator kernels to register.",
    )

    parser.add_argument(
        "--save_optimized_onnx_model",
        action="store_true",
        help="Save the optimized version of each ONNX model. "
        "This will have the same level of optimizations applied as the ORT format model.",
    )

    parser.add_argument(
        "--allow_conversion_failures",
        action="store_true",
        help="Whether to proceed after encountering model conversion failures.",
    )

    parser.add_argument(
        "--target_platform",
        type=str,
        default=None,
        choices=["arm", "amd64"],
        help="Specify the target platform where the exported model will be used. "
        "This parameter can be used to choose between platform-specific options, "
        "such as QDQIsInt8Allowed(arm), NCHWc (amd64) and NHWC (arm/amd64) format, different "
        "optimizer level options, etc.",
    )

    parser.add_argument(
        "model_path_or_dir",
        type=pathlib.Path,
        help="Provide path to ONNX model or directory containing ONNX model/s to convert. "
        "All files with a .onnx extension, including those in subdirectories, will be "
        "processed.",
    )

    parsed_args = parser.parse_args()
    # convert the style names (validated against the enum by `choices`) into enum values
    parsed_args.optimization_style = [OptimizationStyle[style_str] for style_str in parsed_args.optimization_style]
    return parsed_args
|
||||
|
||||
|
||||
def convert_onnx_models_to_ort(
    model_path_or_dir: pathlib.Path,
    output_dir: pathlib.Path | None = None,
    optimization_styles: list[OptimizationStyle] | None = None,
    custom_op_library_path: pathlib.Path | None = None,
    target_platform: str | None = None,
    save_optimized_onnx_model: bool = False,
    allow_conversion_failures: bool = False,
    enable_type_reduction: bool = False,
):
    """Convert ONNX model/s to ORT format and generate the required-operators config file.

    :param model_path_or_dir: Path to a single ONNX model or a directory of models.
    :param output_dir: Directory to write output to. Defaults to the model directory.
    :param optimization_styles: OptimizationStyle values to convert with. One conversion runs per style.
    :param custom_op_library_path: Optional shared library containing custom operator kernels.
    :param target_platform: Optional target platform ("arm" or "amd64") for platform-specific options.
    :param save_optimized_onnx_model: Also save an optimized ONNX model alongside each ORT format model.
    :param allow_conversion_failures: Continue converting remaining models if one fails.
    :param enable_type_reduction: Include per-operator type information in the config file.
    :raises FileNotFoundError: If the model path or custom op library cannot be found.
    """
    if output_dir is not None:
        if not output_dir.is_dir():
            output_dir.mkdir(parents=True)
        output_dir = output_dir.resolve(strict=True)

    optimization_styles = optimization_styles or []

    # setting optimization level is not expected to be needed by typical users, but it can be set with this
    # environment variable
    optimization_level_str = os.getenv("ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL", "all")
    model_path_or_dir = model_path_or_dir.resolve()
    custom_op_library = custom_op_library_path.resolve() if custom_op_library_path else None

    if not model_path_or_dir.is_dir() and not model_path_or_dir.is_file():
        raise FileNotFoundError(f"Model path '{model_path_or_dir}' is not a file or directory.")

    if custom_op_library and not custom_op_library.is_file():
        raise FileNotFoundError(f"Unable to find custom operator library '{custom_op_library}'")

    session_options_config_entries = {}

    # QDQIsInt8Allowed is an arm-specific option (see the --target_platform help text in parse_args)
    if target_platform is not None and target_platform == "arm":
        session_options_config_entries["session.qdqisint8allowed"] = "1"
    else:
        session_options_config_entries["session.qdqisint8allowed"] = "0"

    for optimization_style in optimization_styles:
        print(
            f"Converting models with optimization style '{optimization_style.name}' and level '{optimization_level_str}'"
        )

        converted_models = _convert(
            model_path_or_dir=model_path_or_dir,
            output_dir=output_dir,
            optimization_level_str=optimization_level_str,
            optimization_style=optimization_style,
            custom_op_library=custom_op_library,
            create_optimized_onnx_model=save_optimized_onnx_model,
            allow_conversion_failures=allow_conversion_failures,
            target_platform=target_platform,
            session_options_config_entries=session_options_config_entries,
        )

        with contextlib.ExitStack() as context_stack:
            if optimization_style == OptimizationStyle.Runtime:
                # Convert models again without runtime optimizations.
                # Runtime optimizations may not end up being applied, so we need to use both converted models with and
                # without runtime optimizations to get a complete set of ops that may be needed for the config file.
                model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent
                # the temporary directory (and its models) is cleaned up when the ExitStack closes
                temp_output_dir = context_stack.enter_context(
                    tempfile.TemporaryDirectory(dir=model_dir, suffix=".without_runtime_opt")
                )
                session_options_config_entries_for_second_conversion = session_options_config_entries.copy()
                # Limit the optimizations to those that can run in a model with runtime optimizations.
                session_options_config_entries_for_second_conversion["optimization.minimal_build_optimizations"] = (
                    "apply"
                )

                print(
                    "Converting models again without runtime optimizations to generate a complete config file. "
                    "These converted models are temporary and will be deleted."
                )
                converted_models += _convert(
                    model_path_or_dir=model_path_or_dir,
                    output_dir=temp_output_dir,
                    optimization_level_str=optimization_level_str,
                    optimization_style=OptimizationStyle.Fixed,
                    custom_op_library=custom_op_library,
                    create_optimized_onnx_model=False,  # not useful as they would be created in a temp directory
                    allow_conversion_failures=allow_conversion_failures,
                    target_platform=target_platform,
                    session_options_config_entries=session_options_config_entries_for_second_conversion,
                )

            print(
                f"Generating config file from ORT format models with optimization style '{optimization_style.name}' and level '{optimization_level_str}'"
            )

            config_file = _create_config_file_path(
                model_path_or_dir,
                output_dir,
                optimization_level_str,
                optimization_style,
                enable_type_reduction,
            )

            create_config_from_models(converted_models, config_file, enable_type_reduction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and run the conversion.
    args = parse_args()
    convert_onnx_models_to_ort(
        args.model_path_or_dir,
        output_dir=args.output_dir,
        optimization_styles=args.optimization_style,
        custom_op_library_path=args.custom_op_library,
        target_platform=args.target_platform,
        save_optimized_onnx_model=args.save_optimized_onnx_model,
        allow_conversion_failures=args.allow_conversion_failures,
        enable_type_reduction=args.enable_type_reduction,
    )
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import typing
|
||||
|
||||
|
||||
def path_match_suffix_ignore_case(path: pathlib.Path | str, suffix: str) -> bool:
    """Return True if `path` ends with `suffix`, compared case-insensitively."""
    path_str = path if isinstance(path, str) else str(path)
    # casefold gives aggressive case-insensitive matching on both sides
    return path_str.casefold().endswith(suffix.casefold())
|
||||
|
||||
|
||||
def files_from_file_or_dir(
    file_or_dir_path: pathlib.Path | str, predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True
) -> list[pathlib.Path]:
    """
    Gets the files in `file_or_dir_path` satisfying `predicate`.
    If `file_or_dir_path` is a file, the single file is considered. Otherwise, all files in the directory
    (recursively) are considered.
    :param file_or_dir_path: Path to a file or directory.
    :param predicate: Predicate to determine if a file is included.
    :return: A list of files.
    """
    root_path = file_or_dir_path if isinstance(file_or_dir_path, pathlib.Path) else pathlib.Path(file_or_dir_path)

    if not root_path.is_dir():
        # single file: include it if the predicate accepts it
        return [root_path] if predicate(root_path) else []

    selected = []
    for dir_name, _, file_names in os.walk(root_path):
        for file_name in file_names:
            candidate = pathlib.Path(dir_name, file_name)
            if predicate(candidate):
                selected.append(candidate)
    return selected
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
def get_logger(name, level=logging.DEBUG):
    """Return a logger named `name`, configured with the shared message format and `level`."""
    # basicConfig is a no-op if the root logger already has handlers
    logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s")
    configured_logger = logging.getLogger(name)
    configured_logger.setLevel(level)
    return configured_logger
|
||||
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
from .onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed
|
||||
|
||||
|
||||
def make_dynamic_shape_fixed_helper():
    """Command line entry point to replace a dynamic dimension in an ONNX model with a fixed value.

    Exactly one of --dim_param/--dim_value or --input_name/--input_shape must be
    provided. The updated model is written to `output_model`.
    """
    parser = argparse.ArgumentParser(
        f"{os.path.basename(__file__)}:{make_dynamic_shape_fixed_helper.__name__}",
        description="""
                    Assign a fixed value to a dim_param or input shape
                    Provide either dim_param and dim_value or input_name and input_shape.""",
    )

    parser.add_argument(
        "--dim_param", type=str, required=False, help="Symbolic parameter name. Provide dim_value if specified."
    )
    parser.add_argument(
        "--dim_value", type=int, required=False, help="Value to replace dim_param with in the model. Must be > 0."
    )
    parser.add_argument(
        "--input_name",
        type=str,
        required=False,
        help="Model input name to replace shape of. Provide input_shape if specified.",
    )
    parser.add_argument(
        "--input_shape",
        type=lambda x: [int(i) for i in x.split(",")],
        required=False,
        help="Shape to use for input_shape. Provide comma separated list for the shape. "
        "All values must be > 0. e.g. --input_shape 1,3,256,256",
    )

    parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")

    args = parser.parse_args()

    # exactly one of dim_param or input_name must be given, each with a valid
    # (positive) value/shape; anything else is invalid usage
    if (
        (args.dim_param and args.input_name)
        or (not args.dim_param and not args.input_name)
        or (args.dim_param and (not args.dim_value or args.dim_value < 1))
        or (args.input_name and (not args.input_shape or any(value < 1 for value in args.input_shape)))
    ):
        print("Invalid usage.")
        parser.print_help()
        sys.exit(-1)

    # strict=True raises if the input model does not exist
    model = onnx.load(str(args.input_model.resolve(strict=True)))

    if args.dim_param:
        make_dim_param_fixed(model.graph, args.dim_param, args.dim_value)
    else:
        make_input_shape_fixed(model.graph, args.input_name, args.input_shape)

    # update the output shapes to make them fixed if possible.
    fix_output_shapes(model)

    onnx.save(model, str(args.output_model.resolve()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point.
    make_dynamic_shape_fixed_helper()
|
||||
@@ -0,0 +1,50 @@
|
||||
<!--
|
||||
Keep in sync with doco generated from /docs/execution-providers/CoreML-ExecutionProvider.md on the gh_pages branch
|
||||
-->
|
||||
|Operator|Note|
|
||||
|--------|------|
|
||||
|ai.onnx:Add||
|
||||
|ai.onnx:Argmax||
|
||||
|ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|
||||
|ai.onnx:Cast||
|
||||
|ai.onnx:Clip||
|
||||
|ai.onnx:Concat||
|
||||
|ai.onnx:Conv|Only 1D/2D Conv is supported.<br/>Bias if provided must be constant.|
|
||||
|ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>output_padding must have default values.|
|
||||
|ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
|
||||
|ai.onnx:Div||
|
||||
|ai.onnx:Erf||
|
||||
|ai.onnx:Gemm|Input B must be constant.|
|
||||
|ai.onnx:Gelu||
|
||||
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|
||||
|ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|
||||
|ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.<br/>(mode==linear && padding_mode==reflection && align_corners==0) is not supported.|
|
||||
|ai.onnx:GroupNormalization||
|
||||
|ai.onnx:InstanceNormalization||
|
||||
|ai.onnx:LayerNormalization||
|
||||
|ai.onnx:LeakyRelu||
|
||||
|ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
|
||||
|ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|
||||
|ai.onnx:Max||
|
||||
|ai.onnx:Mul||
|
||||
|ai.onnx:Pow|Only supports cases when both inputs are fp32.|
|
||||
|ai.onnx:PRelu||
|
||||
|ai.onnx:Reciprocal|Requires an `epsilon` value (default 1e-4) that ONNX does not provide.|
|
||||
|ai.onnx:ReduceSum||
|
||||
|ai.onnx:ReduceMean||
|
||||
|ai.onnx:ReduceMax||
|
||||
|ai.onnx:Relu||
|
||||
|ai.onnx:Reshape||
|
||||
|ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.|
|
||||
|ai.onnx:Round||
|
||||
|ai.onnx:Shape||
|
||||
|ai.onnx:Slice|starts/ends/axes/steps must be constant initializers.|
|
||||
|ai.onnx:Split|If provided, `splits` must be constant.|
|
||||
|ai.onnx:Sub||
|
||||
|ai.onnx:Sigmoid||
|
||||
|ai.onnx:Softmax||
|
||||
|ai.onnx:Sqrt||
|
||||
|ai.onnx:Squeeze||
|
||||
|ai.onnx:Tanh||
|
||||
|ai.onnx:Transpose||
|
||||
|ai.onnx:Unsqueeze||
|
||||
@@ -0,0 +1,43 @@
|
||||
<!--
|
||||
Keep in sync with doco generated from /docs/execution-providers/CoreML-ExecutionProvider.md on the gh_pages branch
|
||||
-->
|
||||
|Operator|Note|
|
||||
|--------|------|
|
||||
|ai.onnx:Add||
|
||||
|ai.onnx:ArgMax||
|
||||
|ai.onnx:AveragePool|Only 2D Pool is supported.|
|
||||
|ai.onnx:BatchNormalization||
|
||||
|ai.onnx:Cast||
|
||||
|ai.onnx:Clip||
|
||||
|ai.onnx:Concat||
|
||||
|ai.onnx:Conv|Only 1D/2D Conv is supported.<br/>Weights and bias should be constant.|
|
||||
|ai.onnx:DepthToSpace|Only DCR mode DepthToSpace is supported.|
|
||||
|ai.onnx:Div||
|
||||
|ai.onnx:Flatten||
|
||||
|ai.onnx:Gather|Input `indices` with scalar value is not supported.|
|
||||
|ai.onnx:Gemm|Input B should be constant.|
|
||||
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported.|
|
||||
|ai.onnx:GlobalMaxPool|Only 2D Pool is supported.|
|
||||
|ai.onnx:LeakyRelu||
|
||||
|ai.onnx:LRN||
|
||||
|ai.onnx:MatMul|Input B should be constant.|
|
||||
|ai.onnx:MaxPool|Only 2D Pool is supported.|
|
||||
|ai.onnx:Mul||
|
||||
|ai.onnx:Pad|Only constant mode and last two dim padding is supported.<br/>Input pads and constant_value should be constant.<br/>If provided, axes should be constant.|
|
||||
|ai.onnx:Pow|Only supports cases when both inputs are fp32.|
|
||||
|ai.onnx:PRelu|Input slope should be constant.<br/>Input slope should either have shape [C, 1, 1] or have 1 element.|
|
||||
|ai.onnx:Reciprocal||
|
||||
|ai.onnx:ReduceSum||
|
||||
|ai.onnx:Relu||
|
||||
|ai.onnx:Reshape||
|
||||
|ai.onnx:Resize|4D input.<br/>`coordinate_transformation_mode` == `asymmetric`.<br/>`mode` == `linear` or `nearest`.<br/>`nearest_mode` == `floor`.<br/>`exclude_outside` == false<br/>`scales` or `sizes` must be constant.|
|
||||
|ai.onnx:Shape|Attribute `start` with non-default value is not supported.<br/>Attribute `end` is not supported.|
|
||||
|ai.onnx:Sigmoid||
|
||||
|ai.onnx:Slice|Inputs `starts`, `ends`, `axes`, and `steps` should be constant. Empty slice is not supported.|
|
||||
|ai.onnx:Softmax||
|
||||
|ai.onnx:Split|If provided, `splits` must be constant.|
|
||||
|ai.onnx:Squeeze||
|
||||
|ai.onnx:Sqrt||
|
||||
|ai.onnx:Sub||
|
||||
|ai.onnx:Tanh||
|
||||
|ai.onnx:Transpose||
|
||||
@@ -0,0 +1,58 @@
|
||||
<!--
|
||||
Keep in sync with doco generated from /docs/execution-providers/NNAPI-ExecutionProvider.md on the gh_pages branch
|
||||
-->
|
||||
|Operator|Note|
|
||||
|--------|------|
|
||||
|ai.onnx:Abs||
|
||||
|ai.onnx:Add||
|
||||
|ai.onnx:AveragePool|Only 2D Pool is supported.|
|
||||
|ai.onnx:BatchNormalization||
|
||||
|ai.onnx:Cast||
|
||||
|ai.onnx:Clip||
|
||||
|ai.onnx:Concat||
|
||||
|ai.onnx:Conv|Only 2D Conv is supported.<br/>Weights and bias should be constant.|
|
||||
|ai.onnx:DepthToSpace|Only DCR mode DepthToSpace is supported.|
|
||||
|ai.onnx:DequantizeLinear|All quantization scales and zero points should be constant.|
|
||||
|ai.onnx:Div||
|
||||
|ai.onnx:Elu||
|
||||
|ai.onnx:Exp||
|
||||
|ai.onnx:Flatten||
|
||||
|ai.onnx:Floor||
|
||||
|ai.onnx:Gather|Input indices should be constant if not int32 type.|
|
||||
|ai.onnx:Gemm|If input B is not constant, transB should be 1.|
|
||||
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported.|
|
||||
|ai.onnx:GlobalMaxPool|Only 2D Pool is supported.|
|
||||
|ai.onnx:Identity||
|
||||
|ai.onnx:LeakyRelu||
|
||||
|ai.onnx:Log||
|
||||
|ai.onnx:LRN||
|
||||
|ai.onnx:MatMul||
|
||||
|ai.onnx:MaxPool|Only 2D Pool is supported.|
|
||||
|ai.onnx:Max||
|
||||
|ai.onnx:Min||
|
||||
|ai.onnx:Mul||
|
||||
|ai.onnx:Neg||
|
||||
|ai.onnx:Pad|Only constant mode Pad is supported.<br/>Input pads and constant_value should be constant.<br/>Input pads values should be non-negative.|
|
||||
|ai.onnx:Pow||
|
||||
|ai.onnx:PRelu||
|
||||
|ai.onnx:QLinearConv|Only 2D Conv is supported.<br/>Weights and bias should be constant.<br/>All quantization scales and zero points should be constant.|
|
||||
|ai.onnx:QLinearMatMul|All quantization scales and zero points should be constant.|
|
||||
|ai.onnx:QuantizeLinear|All quantization scales and zero points should be constant.|
|
||||
|ai.onnx:ReduceMean||
|
||||
|ai.onnx:Relu||
|
||||
|ai.onnx:Reshape||
|
||||
|ai.onnx:Resize|Only 2D Resize is supported.|
|
||||
|ai.onnx:Sigmoid||
|
||||
|ai.onnx:Sin||
|
||||
|ai.onnx:Slice||
|
||||
|ai.onnx:Softmax||
|
||||
|ai.onnx:Split|Number of splits must evenly divide split axis size. Input split should be constant if provided.|
|
||||
|ai.onnx:Sqrt||
|
||||
|ai.onnx:Squeeze|Input axes should be constant.|
|
||||
|ai.onnx:Sub||
|
||||
|ai.onnx:Tanh||
|
||||
|ai.onnx:Transpose||
|
||||
|ai.onnx:Unsqueeze|Input axes should be constant.|
|
||||
|com.microsoft:QLinearAdd|All quantization scales and zero points should be constant.|
|
||||
|com.microsoft:QLinearAveragePool|Only 2D Pool is supported.<br/>All quantization scales and zero points should be constant.|
|
||||
|com.microsoft:QLinearSigmoid|All quantization scales and zero points should be constant.|
|
||||
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
from collections import deque
|
||||
from enum import IntEnum
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model_utils import ModelProtoWithShapeInfo, get_producer_consumer_maps, is_fixed_size_tensor, optimize_model
|
||||
|
||||
|
||||
class _SupportedOpsChecker:
|
||||
"""
|
||||
Class to process the md file with list of supported ops and caveats for an execution provider.
|
||||
e.g. /tools/ci_build/github/android/nnapi_supported_ops.md
|
||||
/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
|
||||
/tools/ci_build/github/apple/coreml_supported_neuralnetwork_ops.md
|
||||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
self._filename = filename
|
||||
self._ops = {} # op to caveats
|
||||
self._ops_seen = set()
|
||||
|
||||
with open(filename) as f:
|
||||
for line in f:
|
||||
# we're looking for a markdown table with 2 columns. first is op name. second is caveats
|
||||
# op name is domain:op
|
||||
if line.startswith("|"):
|
||||
pieces = line.strip().split("|")
|
||||
if len(pieces) == 4: # pre-first '|'. op, caveat, post-last '|'
|
||||
domain_op = pieces[1]
|
||||
caveat = pieces[2]
|
||||
caveat = caveat.replace("<br/>", " ") # remove some HTML tags
|
||||
# skip lines that don't have the ':' which separates the domain and op
|
||||
# e.g. the table header will fail this check
|
||||
if ":" in domain_op:
|
||||
self._ops[domain_op] = caveat
|
||||
|
||||
def is_op_supported(self, node):
|
||||
domain = node.domain if node.domain else "ai.onnx"
|
||||
domain_op = domain + ":" + node.op_type
|
||||
|
||||
is_supported = domain_op in self._ops
|
||||
if is_supported:
|
||||
self._ops_seen.add(domain_op)
|
||||
|
||||
return is_supported
|
||||
|
||||
def get_caveats(self):
|
||||
caveats = []
|
||||
for op in sorted(self._ops_seen):
|
||||
caveat = self._ops[op]
|
||||
if caveat:
|
||||
caveats.append(f"{op}:{caveat}")
|
||||
|
||||
return caveats
|
||||
|
||||
|
||||
class PartitioningInfo:
    """
    Details of how a graph would be partitioned when run with a specific execution provider (EP):
    node/partition counts plus the reasons nodes were not supported, so a usability analysis and
    an overall recommendation can be reported to the user.
    """

    class TryWithEP(IntEnum):
        """Recommendation on whether the model is worth trying with the EP."""

        # FIX: values were written as one-element tuples `(0,)`/`(1,)` — IntEnum unpacked them to
        # the same ints, so plain ints are behavior-identical and clearer.
        NO = 0
        MAYBE = 1
        YES = 2

    def __init__(
        self,
        num_nodes: int,
        num_supported_nodes: int,
        num_partitions: int,
        supported_ops_checker: _SupportedOpsChecker,
        supported_groups: list[onnx.NodeProto],
        unsupported_ops: set[str],
        nodes_unsupported_due_to_op: int,
        nodes_unsupported_due_to_dynamic_input: int,
        num_unsupported_nodes_due_to_rank: int,
        ops_with_unsupported_rank: set[str],
    ):
        self.num_nodes = num_nodes
        self.num_supported_nodes = num_supported_nodes
        self.num_partitions = num_partitions
        self.supported_ops_checker = supported_ops_checker
        self.supported_groups = supported_groups
        self.unsupported_ops = unsupported_ops
        self.nodes_unsupported_due_to_op = nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input = nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank = num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank = ops_with_unsupported_rank

        # populated by `merge` as subgraph info is merged into the main graph's info
        self.num_subgraphs = 0
        self.num_nodes_in_subgraphs = 0

    def merge(self, other: PartitioningInfo):
        """
        Merge the information from another PartitioningInfo instance into this one.
        """
        self.num_nodes += other.num_nodes
        self.num_supported_nodes += other.num_supported_nodes
        self.num_partitions += other.num_partitions
        self.supported_groups.extend(other.supported_groups)
        self.unsupported_ops.update(other.unsupported_ops)
        self.nodes_unsupported_due_to_op += other.nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input += other.nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank += other.num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank.update(other.ops_with_unsupported_rank)

        # hard assumption that we merge into the main graph partitioning info
        self.num_subgraphs += 1
        self.num_nodes_in_subgraphs += other.num_nodes

    def suitability(self):
        """
        Return a TryWithEP recommendation based on the number of partitions and the percentage
        of nodes the EP supports.
        """
        # semi-arbitrary choices that err on the side of MAYBE.
        # having 1 partition is always preferred, but if that is small it may not be useful.
        # having 2 partitions may be okay if they cover most nodes
        # more than 2 partitions and the device copy cost is almost guaranteed to outweigh the benefit of using the NPU
        # NOTE: This assumes the EP is not CPU based and there is device copy overhead to consider
        pct_supported = self.num_supported_nodes / self.num_nodes * 100
        if self.num_partitions == 1:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.YES
            elif pct_supported > 50:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        if self.num_partitions == 2:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        return PartitioningInfo.TryWithEP.NO

    def print_analysis(self, logger: logging.Logger, ep_name: str):
        """
        Analyze the partitioning information and log the analysis
        :param logger: Logger to use
        :param ep_name: Execution provider name to use in the log messages
        """

        logger.info(
            f"{self.num_partitions} partitions with a total of {self.num_supported_nodes}/{self.num_nodes} "
            f"nodes can be handled by the {ep_name} EP."
        )

        if self.supported_groups:
            logger.info(
                f"\tPartition sizes: [{', '.join([str(len(partition)) for partition in self.supported_groups])}]"
            )

            # dump full groups if debug output is enabled
            for group in self.supported_groups:
                logger.debug(f"Nodes in group: {','.join([f'{node.op_type}:{node.name}' for node in group])}")

        logger.info(f"Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}")
        if self.unsupported_ops:
            logger.info(f"\tUnsupported ops: {','.join(sorted(self.unsupported_ops))}")

        caveats = self.supported_ops_checker.get_caveats()
        if caveats:
            indent = " " * 5
            logger.info(
                "\tCaveats that have not been checked and may result in a node not actually being supported: "
                f"{''.join([os.linesep + indent + caveat for caveat in caveats])}"
            )

        if self.nodes_unsupported_due_to_dynamic_input:
            logger.info(
                "Unsupported nodes due to input having a dynamic shape=%d",
                self.nodes_unsupported_due_to_dynamic_input,
            )

        if self.num_unsupported_nodes_due_to_rank:
            logger.info(f"Unsupported nodes due to rank of input data={self.num_unsupported_nodes_due_to_rank}")
            logger.info(f"\tOps with unsupported rank: {','.join(sorted(self.ops_with_unsupported_rank))}")

        if self.num_subgraphs > 0:
            # TODO: CoreML has a flag. NNAPI doesn't. Either should be able to support a subgraph when treated as a
            # separate graph (only extra detail would be making sure implicit inputs are handled).
            # Merging the subgraph into the parent graph would be more complex.
            # e.g. for CoreML we could potentially convert Loop to while_loop and If to cond if the subgraphs in the
            # control flow node are fully supported.
            # NNAPI also has While and If.

            # It most likely will be necessary to support merging in If nodes with fully supported subgraphs,
            # as the subgraphs in those are often very simple, so the performance cost of going to the CPU EP and back
            # is high.
            logger.info(
                f"{self.num_nodes_in_subgraphs} nodes are in {self.num_subgraphs} subgraphs. "
                "Check EP as to whether subgraphs are supported."
            )

        pct_nodes_using_ep = self.num_supported_nodes / self.num_nodes * 100
        if self.num_partitions == 0:
            logger.info(f"{ep_name} cannot run any nodes in this model.")
        elif self.num_partitions == 1:
            if pct_nodes_using_ep > 75:
                logger.info(
                    f"{ep_name} should work well for this model as there is one partition "
                    f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model."
                )
            elif pct_nodes_using_ep > 50:
                logger.info(
                    f"{ep_name} may work well for this model, however only {pct_nodes_using_ep:.1f}% of nodes "
                    "will use it. Performance testing is required to validate."
                )
            else:
                # FIX: message previously read "will probably not work will" (typo) and used an
                # inconsistent :.2f percentage format vs :.1f in the sibling messages.
                logger.info(
                    f"{ep_name} will probably not work well for this model as only {pct_nodes_using_ep:.1f}% "
                    "of nodes will use it."
                )

        elif self.num_partitions == 2 and pct_nodes_using_ep > 75:
            logger.info(
                f"{ep_name} can be considered for this model as there are two partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes. "
                "Performance testing is required to validate."
            )
        else:
            logger.info(
                f"{ep_name} is not recommended with this model as there are {self.num_partitions} partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model. "
                "This will most likely result in worse performance than just using the CPU EP."
            )
|
||||
|
||||
|
||||
def _check_partitioning_for_graph(
    graph: onnx.GraphProto,
    node_to_producers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    node_to_consumers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    supported_ops_checker: _SupportedOpsChecker,
    outer_scope_initializers: set[str],
    require_fixed_input_sizes: bool,
    value_info: dict[str, onnx.ValueInfoProto],
    max_rank: int = 999,  # max rank if EP has a limitation
):
    """
    Estimate the partitions an EP would create for a single graph.

    :param graph: Graph to process.
    :param node_to_producers: Map of node to the nodes producing its inputs (built for the whole model).
    :param node_to_consumers: Map of node to the nodes consuming its outputs (built for the whole model).
    :param supported_ops_checker: Checker with the ops the EP supports.
    :param outer_scope_initializers: Initializer names from enclosing graphs (relevant when `graph` is a subgraph).
    :param require_fixed_input_sizes: If True, a node is only considered supported when all of its inputs
                                      have fixed-size shapes (or are initializers).
    :param value_info: Lookup of value name to ValueInfoProto, including outer scope values for subgraphs.
    :param max_rank: Maximum tensor rank the EP supports.
    :return: PartitioningInfo with the partitioning details for this graph.
    """
    # initializers have fixed sizes.
    initializers = [i.name for i in graph.initializer]

    def _is_fixed_shape_value(value):
        # True if `value` is known to have a fixed-size shape.
        if value in value_info:
            return is_fixed_size_tensor(value_info[value])

        if value in initializers or value in outer_scope_initializers:
            return True

        # if something has an unknown shape (e.g. something downstream of a Reshape with dynamic input for the shape)
        # it won't have an entry in value_info
        return False

    #
    # Replicate logic from /onnxruntime/core/providers/partitioning_utils.cc:CreateSupportedPartitionNodeGroups
    # to roughly estimate number of partitions for nodes that is_node_supported_fn returns true for.
    #
    # We keep the structure and variable names as close as possible to the C++ implementation to simplify keeping them
    # in sync if future updates are needed.
    #
    # NOTE: CreateSupportedPartitionNodeGroups was recently updated to be QDQ aware so that partitions did not split
    # QDQ node groups. This code does not need to be QDQ aware as splitting a QDQ node group does not affect the total
    # number of partitions or supported nodes.
    #

    # we don't currently support a callback for additional group closure checks in the python implementation
    on_group_closed_fn = None

    supported_groups = []
    # number of inputs from unprocessed nodes (in-degree) per node
    in_degree = {}
    # nodes that are ready to process
    nodes_to_process = deque()  # deque of Node instances
    # nodes that will be processed when considering the next partition node group
    nodes_to_process_with_next_group = deque()

    # initialize in-degrees and find root nodes
    for node in graph.node:
        node_input_edge_count = len(node_to_producers[node]) if node in node_to_producers else 0
        in_degree[node] = node_input_edge_count
        if node_input_edge_count == 0:
            # node is only dependent on graph input or initializers
            nodes_to_process.append(node)

    supported_group = []
    # the partition node group's border is the aggregate of its nodes' output nodes
    supported_group_border = set()
    num_supported_nodes = 0
    num_unsupported_nodes_due_to_op = 0
    num_unsupported_nodes_due_to_dynamic_input = 0
    num_unsupported_nodes_due_to_rank = 0
    unsupported_ops = set()
    ops_with_unsupported_rank = set()

    def close_group():
        # commit the current partition node group (if any) and reset the group state
        if supported_group:
            keep_partition = not on_group_closed_fn or on_group_closed_fn(supported_group)

            if keep_partition:
                supported_groups.append(supported_group.copy())

            supported_group.clear()
            supported_group_border.clear()

    # topological-order traversal: supported nodes accumulate into the current group; an unsupported
    # node on the group border ends the group and is deferred to the next one.
    while nodes_to_process or nodes_to_process_with_next_group:
        if not nodes_to_process:
            close_group()
            nodes_to_process = nodes_to_process_with_next_group
            nodes_to_process_with_next_group = deque()
            continue

        node = nodes_to_process.popleft()

        is_op_supported = supported_ops_checker.is_op_supported(node)
        is_input_shape_supported = not require_fixed_input_sizes or all(_is_fixed_shape_value(i) for i in node.input)

        is_rank_supported = True
        if value_info:
            for node_input in node.input:
                if node_input and node_input in value_info and value_info[node_input].type.HasField("tensor_type"):
                    input_rank = len(value_info[node_input].type.tensor_type.shape.dim)
                    if input_rank > max_rank:
                        is_rank_supported = False
                        break

        # special-case if we can infer the rank from the length of the 'perms' Transpose attribute
        # e.g. this works with SegmentAnything where dynamic Reshape operators result in no shape info.
        if node.op_type == "Transpose" and len(node.attribute[0].ints) > max_rank:
            is_rank_supported = False

        is_node_supported = is_op_supported and is_input_shape_supported and is_rank_supported

        if not is_node_supported:
            if node in supported_group_border:
                # an unsupported node on the border will be processed after the current partition node group
                # so skip any additional processing/counting here
                nodes_to_process_with_next_group.append(node)
                continue

            if not is_op_supported:
                unsupported_ops.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")
                num_unsupported_nodes_due_to_op += 1

            if not is_input_shape_supported:
                num_unsupported_nodes_due_to_dynamic_input += 1

            if not is_rank_supported:
                num_unsupported_nodes_due_to_rank += 1
                ops_with_unsupported_rank.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")

        if is_node_supported:
            num_supported_nodes += 1

            # add node to the partition node group
            supported_group.append(node)

            # remove node from the border and add its outputs to the border
            if node in supported_group_border:  # noqa: FURB132
                supported_group_border.remove(node)

            # for each consumer node add to supported_group_border
            if node in node_to_consumers:
                for consumer in node_to_consumers[node]:
                    supported_group_border.add(consumer)

        # adjust in-degrees of the node outputs and add any new nodes to process
        if node in node_to_consumers:
            for consumer in node_to_consumers[node]:
                consumer_node_in_degree = in_degree[consumer]
                consumer_node_in_degree -= 1
                if consumer_node_in_degree == 0:
                    nodes_to_process.append(consumer)

                in_degree[consumer] = consumer_node_in_degree

    close_group()

    num_nodes = len(graph.node)
    num_partitions = len(supported_groups)

    info = PartitioningInfo(
        num_nodes,
        num_supported_nodes,
        num_partitions,
        supported_ops_checker,
        supported_groups,
        unsupported_ops,
        num_unsupported_nodes_due_to_op,
        num_unsupported_nodes_due_to_dynamic_input,
        num_unsupported_nodes_due_to_rank,
        ops_with_unsupported_rank,
    )

    return info
|
||||
|
||||
|
||||
def check_partitioning(
    main_graph: onnx.GraphProto,
    supported_ops_checker: _SupportedOpsChecker,
    require_fixed_input_sizes: bool,
    max_rank: int = 999,
) -> PartitioningInfo:
    """
    Estimate the partitions the graph will be split into for nodes that is_node_supported_fn returns true for.

    The check on whether a node is supported is purely based on the operator type. Additional limitations
    (e.g. NNAPI EP only supports 2D Conv) are not checked, so partitions may not be 100% accurate. The limitations
    for operators in the partitions are printed so the user can manually check.
    :param main_graph: Graph to process
    :param supported_ops_checker: Checker with info on supported ops.
    :param require_fixed_input_sizes: If True, require that the inputs to a potentially supported node are fixed size
                                      tensors for it to be considered as supported. This requires
                                      onnx.shape_inference.infer_shapes to have been run on the model to populate the
                                      shape information.
                                      If False, shapes are ignored during the check.
    :param max_rank: Set if EP has a limitation on the rank of tensors it supports.
    :return PartitioningInfo instance with details
    """

    if require_fixed_input_sizes and len(main_graph.value_info) == 0 and len(main_graph.node) > 1:
        raise ValueError("Run onnx.shape_inference.infer_shapes on the model to populate the shape information.")

    # create lookup map from ValueInfo for efficiency
    def _update_value_info(graph: onnx.GraphProto, value_to_shape: dict[str, onnx.ValueInfoProto]):
        # add all inputs, outputs and value_info entries of `graph` to the lookup map
        for v in graph.input:
            value_to_shape[v.name] = v
        for v in graph.output:
            value_to_shape[v.name] = v
        for v in graph.value_info:
            value_to_shape[v.name] = v

    # the producer/consumer maps are for the entire model
    node_to_producers, node_to_consumers = get_producer_consumer_maps(main_graph)

    def _check_graph(
        graph: onnx.GraphProto,
        outer_scope_value_info: dict[str, onnx.ValueInfoProto] | None,
        outer_scope_initializers: set[str] | None = None,
        partitioning_info: PartitioningInfo | None = None,
    ) -> PartitioningInfo:
        # check `graph` and recurse into any subgraphs, accumulating results into `partitioning_info`.
        if outer_scope_value_info is not None:
            # extend value info if we're using it. we replace any value shadowed with a local one
            value_info = outer_scope_value_info.copy()
            _update_value_info(graph, value_info)
        else:
            value_info = {}

        if outer_scope_initializers is None:
            outer_scope_initializers = set()

        info = _check_partitioning_for_graph(
            graph,
            node_to_producers,
            node_to_consumers,
            supported_ops_checker,
            outer_scope_initializers,
            require_fixed_input_sizes,
            value_info,
            max_rank,
        )

        if partitioning_info:
            # merge in subgraph info
            partitioning_info.merge(info)
        else:
            # main graph info
            partitioning_info = info

        # setup outer scope initializers. we copy the input set as a model may have multiple subgraphs
        # on multiple levels, so we need to keep the set for each descent separate
        subgraph_outer_scope_initializers = set(outer_scope_initializers)
        for initializer in graph.initializer:
            subgraph_outer_scope_initializers.add(initializer.name)

        for node in graph.node:
            # recurse into nodes with subgraphs
            for attr in node.attribute:
                if attr.HasField("g"):
                    subgraph = attr.g
                    partitioning_info = _check_graph(
                        subgraph, value_info, subgraph_outer_scope_initializers, partitioning_info
                    )

        return partitioning_info

    aggregated_partitioning_info = _check_graph(main_graph, {} if require_fixed_input_sizes else None)

    return aggregated_partitioning_info
|
||||
|
||||
|
||||
def _check_ep_partitioning(
    model: onnx.ModelProto, supported_ops_config: pathlib.Path, require_fixed_input_sizes: bool, max_rank: int = 999
):
    """Run the partitioning check for one EP using its supported-ops config file."""
    return check_partitioning(
        model.graph,
        _SupportedOpsChecker(supported_ops_config),
        require_fixed_input_sizes,
        max_rank,
    )
|
||||
|
||||
|
||||
def check_nnapi_partitions(model, require_fixed_input_sizes: bool):
    """Check how the model would be partitioned if run with the NNAPI EP."""
    # if we're running in the ORT python package the file should be local. otherwise assume we're running from the
    # ORT repo
    script_dir = pathlib.Path(__file__).parent
    config_path = script_dir / "nnapi_supported_ops.md"
    if not config_path.exists():
        ort_root = script_dir.parents[3]
        config_path = ort_root / "tools" / "ci_build" / "github" / "android" / "nnapi_supported_ops.md"

    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes)
|
||||
|
||||
|
||||
def check_coreml_partitions(model: onnx.ModelProto, require_fixed_input_sizes: bool, config_filename: str):
    """Check how the model would be partitioned if run with the CoreML EP using the given ops config."""
    # if we're running in the ORT python package the file should be local. otherwise assume we're running from the
    # ORT repo
    script_dir = pathlib.Path(__file__).parent
    config_path = script_dir / config_filename
    if not config_path.exists():
        ort_root = script_dir.parents[3]
        config_path = ort_root / "tools" / "ci_build" / "github" / "apple" / config_filename

    # CoreML has a tensor rank limitation
    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes, 5)
|
||||
|
||||
|
||||
def check_shapes(graph: onnx.GraphProto, logger: logging.Logger | None = None):
    """
    Check the shapes of graph inputs, values and graph outputs to determine if they have static or dynamic sizes.
    NNAPI does not support dynamically sized values. CoreML does, but it will most likely cost performance.
    :param graph: Graph to check. If shape inferencing has been run the checks on values will be meaningful.
    :param logger: Optional logger for diagnostic information.
    :return: Tuple of List of inputs with dynamic shapes, Number of dynamic values found
    """

    # it's OK if the input is dynamically sized and we do a Resize early to a fixed size.
    # it's not good if lots of ops have dynamic inputs

    num_fixed_values = 0
    num_dynamic_values = 0

    dynamic_inputs = []
    for i in graph.input:
        if not is_fixed_size_tensor(i):
            dynamic_inputs.append(i)
            # split/join to remove repeated whitespace and newlines from str(i)
            if logger:
                logger.info(f"Input is not a fixed size tensor: {' '.join(str(i).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    dynamic_outputs = []
    for o in graph.output:
        if not is_fixed_size_tensor(o):
            dynamic_outputs.append(o)
            if logger:
                logger.info(f"Output is not a fixed size tensor: {' '.join(str(o).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    # check we have value info.
    # special case some test graphs with a single node which only have graph input and output values, and
    # a model where all inputs are dynamic (results in no value_info)
    if not graph.value_info and not (len(graph.node) == 1 or len(dynamic_inputs) == len(graph.input)):
        # BUG FIX: this warning (and the messages below) previously used `logger` unconditionally,
        # raising AttributeError when logger is None. All diagnostics now guard on `if logger:` to
        # match the rest of the function.
        if logger:
            logger.warning(
                "Unable to check shapes within model. "
                "ONNX shape inferencing should be run on the model prior to checking."
            )

    for vi in graph.value_info:
        if is_fixed_size_tensor(vi):
            num_fixed_values += 1
        else:
            num_dynamic_values += 1

    if logger:
        logger.info(
            f"Num values with fixed shape={num_fixed_values}. Num values with dynamic shape={num_dynamic_values}"
        )

        if dynamic_inputs:
            if dynamic_outputs:
                logger.info(
                    "Model has dynamic inputs and outputs. Consider re-exporting model with fixed sizes "
                    "if NNAPI or CoreML can be used with this model."
                )
            else:
                logger.info(
                    """Model has dynamically sized inputs but fixed sized outputs.
                       If the sizes become fixed early in the model (e.g. pre-processing of a dynamic input size
                       results in a fixed input size for the majority of the model) performance with NNAPI and CoreML,
                       if applicable, should not be significantly impacted."""
                )

    return dynamic_inputs, num_dynamic_values
|
||||
|
||||
|
||||
def checker(model_path: pathlib.Path, logger: logging.Logger):
    """
    Check whether the model is likely to work well with the NNAPI and CoreML EPs and log the analysis.

    :param model_path: Path to the ONNX model. Shape info is loaded/created via ModelProtoWithShapeInfo.
    :param logger: Logger for output.
    :return: True if at least one EP's suitability is not NO (i.e. worth trying).
    """
    model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path)
    model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info

    dynamic_inputs, num_dynamic_values = check_shapes(model_with_shape_info.graph)

    def check_ep(ep_name, checker_func):
        # run checker_func for one EP, log the analysis, and return the best suitability found
        # (checking as-is first, then what it would be if dynamic shapes were made fixed).
        logger.info(f"Checking {ep_name}")

        # check with shape info first so supported nodes takes into account values with dynamic shapes
        require_fixed_input_sizes = True
        partition_info = checker_func(model_with_shape_info, require_fixed_input_sizes)
        if logger.getEffectiveLevel() <= logging.INFO:
            partition_info.print_analysis(logger, ep_name)

        suitability = partition_info.suitability()
        logger.info(f"Model should perform well with {ep_name} as is: {suitability.name}")

        if suitability != PartitioningInfo.TryWithEP.YES and dynamic_inputs:
            logger.info("--------")
            logger.info("Checking if model will perform better if the dynamic shapes are fixed...")
            require_fixed_input_sizes = False
            partition_info_with_fixed_shapes = checker_func(model_with_shape_info, require_fixed_input_sizes)

            if logger.getEffectiveLevel() <= logging.INFO:
                # analyze and log detailed info
                logger.info("Partition information if the model was updated to make the shapes fixed:")
                partition_info_with_fixed_shapes.print_analysis(logger, ep_name)

            fixed_shape_suitability = partition_info_with_fixed_shapes.suitability()
            logger.info(
                f"Model should perform well with {ep_name} if modified to have fixed input shapes: "
                f"{fixed_shape_suitability.name}"
            )

            if fixed_shape_suitability != PartitioningInfo.TryWithEP.NO:
                logger.info("Shapes can be altered using python -m onnxruntime.tools.make_dynamic_shape_fixed")

            # report the better of the two results
            if fixed_shape_suitability.value > suitability.value:
                suitability = fixed_shape_suitability

        logger.info("================")
        logger.info("")

        return suitability

    nnapi_suitability = check_ep("NNAPI", check_nnapi_partitions)

    # Check for NeuralNetwork CoreML model
    def check_nn_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_neuralnetwork_ops.md")

    # Check for MLProgram CoreML model
    def check_mlprogram_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_mlprogram_ops.md")

    coreml_nn_suitability = check_ep("CoreML NeuralNetwork", check_nn_coreml)
    coreml_mlprogram_suitability = check_ep("CoreML MLProgram", check_mlprogram_coreml)

    if (
        nnapi_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.YES
    ) and logger.getEffectiveLevel() > logging.INFO:
        logger.info("Re-run with log level of INFO for more details on the NNAPI/CoreML issues.")

    # the model is worth trying if any EP did not come back as a clear NO
    return (
        nnapi_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.NO
    )
|
||||
|
||||
|
||||
def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: logging.Logger | None = None):
    """
    Analyze the provided model to determine if it's likely to work well with the NNAPI or CoreML Execution Providers
    :param model_path: Model to analyze.
    :param skip_optimize: Skip optimizing to BASIC level before checking. When exporting to ORT format we will do this
                          optimization..
    :param logger: Logger for output
    :return: True if either the NNAPI or CoreML Execution Providers may work well with this model.
    """
    if not logger:
        logger = logging.getLogger("usability_checker")
        logger.setLevel(logging.INFO)

    logger.info(f"Checking {model_path} for usability with ORT Mobile.")

    with tempfile.TemporaryDirectory() as tmp_dir:
        if not skip_optimize:
            # optimize into a temporary copy so the original model is untouched
            optimized_path = pathlib.Path(tmp_dir) / model_path.name
            optimize_model(model_path, optimized_path, use_external_initializers=True)
            model_path = optimized_path

        return checker(model_path.resolve(strict=True), logger)
|
||||
|
||||
|
||||
def parse_args():
    """Build the command line parser for the analyzer and parse the arguments."""
    arg_parser = argparse.ArgumentParser(
        os.path.basename(__file__), description="""Analyze an ONNX model for usage with the ORT mobile"""
    )

    arg_parser.add_argument("--log_level", choices=["debug", "info"], default="info", help="Logging level")
    arg_parser.add_argument(
        "--skip_optimize",
        action="store_true",
        help="Don't optimize the model to BASIC level prior to analyzing. "
        "Optimization will occur when exporting the model to ORT format, so in general "
        "should not be skipped unless you have a specific reason to do so.",
    )
    arg_parser.add_argument("model_path", type=pathlib.Path, help="Provide path to ONNX model")

    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def run_analyze_model():
    """Command line entry point: parse arguments, configure logging, and run the analysis."""
    args = parse_args()
    logger = logging.getLogger("default")

    # map the requested level name to a logging level; anything unrecognized falls back to ERROR
    level_by_name = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
    }
    logger.setLevel(level_by_name.get(args.log_level, logging.ERROR))

    analyze_model(args.model_path.resolve(), args.skip_optimize, logger)
|
||||
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from pprint import pprint
|
||||
from typing import Any
|
||||
|
||||
import onnx
|
||||
|
||||
TuningResults = dict[str, Any]
|
||||
|
||||
_TUNING_RESULTS_KEY = "tuning_results"
|
||||
|
||||
|
||||
def _find_tuning_results_in_props(metadata_props):
    """Return the index of the tuning-results entry in metadata_props, or -1 if not present."""
    return next(
        (idx for idx, prop in enumerate(metadata_props) if prop.key == _TUNING_RESULTS_KEY),
        -1,
    )
|
||||
|
||||
|
||||
def extract(model: onnx.ModelProto):
    """Return the tuning results embedded in the model's metadata, or None if there are none."""
    idx = _find_tuning_results_in_props(model.metadata_props)
    if idx < 0:
        return None

    return json.loads(model.metadata_props[idx].value)
|
||||
|
||||
|
||||
def embed(model: onnx.ModelProto, tuning_results: list[TuningResults], overwrite=False):
    """Embed *tuning_results* into *model* as a JSON metadata_props entry.

    :param model: Model to update in place.
    :param tuning_results: Tuning results to serialize into the model.
    :param overwrite: If True, replace any existing embedded tuning results.
    :return: The updated model (the same object that was passed in).
    """
    idx = _find_tuning_results_in_props(model.metadata_props)
    # Bug fix: the original condition was `idx <= 0`, which incorrectly passed when
    # existing tuning results were stored at index 0, silently overwriting them
    # without the overwrite flag. Found-vs-not-found is idx >= 0 vs idx < 0.
    assert overwrite or idx < 0, "the supplied onnx file already have tuning results embedded!"

    if idx >= 0:
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _TUNING_RESULTS_KEY
    entry.value = json.dumps(tuning_results)
    return model
|
||||
|
||||
|
||||
class Merger:
    """Merges multiple TuningResults lists, de-duplicated by (ep, validators) identity."""

    class EpAndValidators:
        """Hashable identity for a tuning results set: execution provider name + validators."""

        def __init__(self, ep: str, validators: dict[str, str]):
            self.ep = ep
            self.validators = copy.deepcopy(validators)
            # key is hashable: the validators dict flattened to a sorted tuple of pairs
            self.key = (ep, tuple(sorted(validators.items())))

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            # NOTE(review): self.key already embeds ep, so the ep comparison is redundant;
            # kept as-is to preserve behavior. No isinstance check - assumes same-type compare.
            return self.ep == other.ep and self.key == other.key

    def __init__(self):
        # map of EpAndValidators -> {(op_sig, params_sig): kernel_id}
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: list[TuningResults]):
        """Merge a list of TuningResults dicts into the accumulated state."""
        for trs in tuning_results:
            self._merge_one(trs)

    def get_merged(self):
        """Return the merged state as a list of TuningResults dicts."""
        tuning_results = []
        for ev, flat_results in self.ev_to_results.items():
            results = {}
            trs = {
                "ep": ev.ep,
                "validators": ev.validators,
                "results": results,
            }
            # unflatten {(op_sig, params_sig): kernel_id} back into nested dicts
            for (op_sig, params_sig), kernel_id in flat_results.items():
                kernel_map = results.setdefault(op_sig, {})
                kernel_map[params_sig] = kernel_id
            tuning_results.append(trs)
        return tuning_results

    def _merge_one(self, trs: TuningResults):
        # First occurrence of an (op_sig, params_sig) pair wins; later duplicates are ignored.
        ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
        flat_results = self.ev_to_results.setdefault(ev, {})
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                if (op_sig, params_sig) not in flat_results:
                    flat_results[(op_sig, params_sig)] = kernel_id
|
||||
|
||||
|
||||
def parse_args():
    """Parse command line arguments for the tuning results tool.

    :return: Parsed arguments. Prints help and exits with -1 if no sub-command was given.
    """
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args()
    # Bug fix: the namespace always contains at least the "cmd" key, so the original
    # `len(vars(args)) == 0` check could never fire. Detect a missing sub-command by
    # checking cmd directly, and use sys.exit rather than the interactive exit().
    if args.cmd is None:
        parser.print_help()
        sys.exit(-1)
    return args
|
||||
|
||||
|
||||
def main():
    """Dispatch the extract/embed/merge/pprint sub-commands for tuning results files."""
    args = parse_args()
    if args.cmd == "extract":
        tuning_results = extract(onnx.load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            sys.exit(-1)
        # use context managers so file handles are closed deterministically
        with open(args.output_json, "w") as out_file:
            json.dump(tuning_results, out_file)
    elif args.cmd == "embed":
        model = onnx.load_model(args.input_onnx)
        merger = Merger()
        for input_file in args.input_json:
            with open(input_file) as in_file:
                merger.merge(json.load(in_file))
        model = embed(model, merger.get_merged(), args.force)
        onnx.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        merger = Merger()
        for input_file in args.input_json:
            with open(input_file) as in_file:
                merger.merge(json.load(in_file))
        with open(args.output_json, "w") as out_file:
            json.dump(merger.get_merged(), out_file)
    elif args.cmd == "pprint":
        tuning_results = None
        try:
            with open(args.json_or_onnx) as in_file:
                tuning_results = json.load(in_file)
        except Exception:
            # it might be an onnx file otherwise, try that below
            pass

        if tuning_results is None:
            try:
                model = onnx.load_model(args.json_or_onnx)
                tuning_results = extract(model)
                if tuning_results is None:
                    # Bug fix: this branch referenced args.input_onnx, which does not exist
                    # for the pprint sub-command. The resulting AttributeError was silently
                    # swallowed by the except below, hiding the intended error message.
                    sys.stderr.write(f"{args.json_or_onnx} does not have tuning results embedded!\n")
                    sys.exit(-1)
            except Exception:
                pass

        if tuning_results is None:
            sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!")
            sys.exit(-1)

        pprint(tuning_results)
    else:
        # invalid choice will be handled by the parser
        pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()
|
||||
@@ -0,0 +1,416 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import onnx
|
||||
from onnx import version_converter
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
def iterate_graph_per_node_func(graph, per_node_func, **func_args):
    """Call *per_node_func* for every node in *graph*, recursing into subgraphs.

    :param graph: Graph to iterate.
    :param per_node_func: Callable invoked as fn(node: onnx.NodeProto, **func_args) for each node.
    :param func_args: Keyword arguments passed through to per_node_func.
    """
    for current_node in graph.node:
        per_node_func(current_node, **func_args)
        # Control flow nodes (Scan/Loop/If) carry nested graphs in their attributes.
        nested_graphs = (attr.g for attr in current_node.attribute if attr.HasField("g"))
        for subgraph in nested_graphs:
            iterate_graph_per_node_func(subgraph, per_node_func, **func_args)
|
||||
|
||||
|
||||
def iterate_graph_per_graph_func(graph, per_graph_func, **func_args):
    """Call *per_graph_func* for *graph* and every nested subgraph, in pre-order.

    :param graph: Graph to iterate.
    :param per_graph_func: Callable invoked as fn(graph: onnx.GraphProto, **func_args) for each graph.
    :param func_args: Keyword arguments passed through to per_graph_func.
    """
    per_graph_func(graph, **func_args)

    # Subgraphs live on the attributes of control flow nodes (Scan/Loop/If).
    nested_graphs = (attr.g for node in graph.node for attr in node.attribute if attr.HasField("g"))
    for subgraph in nested_graphs:
        iterate_graph_per_graph_func(subgraph, per_graph_func, **func_args)
|
||||
|
||||
|
||||
def get_opsets_imported(model: onnx.ModelProto):
    """
    Get the opsets imported by the model.
    :param model: Model to check.
    :return: Map of domain to opset version.
    """
    # An empty domain string denotes the default "ai.onnx" domain.
    return {(entry.domain or "ai.onnx"): entry.version for entry in model.opset_import}
|
||||
|
||||
|
||||
def update_onnx_opset(
    model_path: pathlib.Path,
    opset: int,
    out_path: pathlib.Path | None = None,
    logger: logging.Logger | None = None,
):
    """
    Helper to update the opset of a model using onnx version_converter. Target opset must be greater than current opset.
    :param model_path: Path to model to update
    :param opset: Opset to update model to
    :param out_path: Optional output path for updated model to be saved to.
    :param logger: Optional logger for diagnostic output
    :returns: Updated onnx.ModelProto
    """

    # strict=True raises FileNotFoundError if the model file does not exist
    model_path_str = str(model_path.resolve(strict=True))
    if logger:
        logger.info("Updating %s to opset %d", model_path_str, opset)

    model = onnx.load(model_path_str)

    new_model = version_converter.convert_version(model, opset)

    # only write to disk when an output path was supplied; the converted model is returned either way
    if out_path:
        onnx.save(new_model, str(out_path))
        if logger:
            logger.info("Saved updated model to %s", out_path)

    return new_model
|
||||
|
||||
|
||||
def optimize_model(
    model_path: pathlib.Path,
    output_path: pathlib.Path,
    level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
    log_level: int = 3,
    use_external_initializers: bool = False,
):
    """
    Optimize an ONNX model using ONNX Runtime to the specified level.
    :param model_path: Path to ONNX model.
    :param output_path: Path to save optimized model to.
    :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC.
    :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed.
                      Warning (2) or Info (1) may be desirable in some scenarios.
    :param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB.
                                      Requires onnxruntime 1.17+
    :raises ValueError: If use_external_initializers is set but the ORT version is older than 1.17.
    """
    so = ort.SessionOptions()
    so.optimized_model_filepath = str(output_path.resolve())
    so.graph_optimization_level = level
    so.log_severity_level = log_level

    # save using external initializers so models > 2 GB are handled
    if use_external_initializers:
        # Bug fix: the original `major, minor, rest = ort.__version__.split(".", 3)` required
        # exactly three dot-separated components and raised ValueError for versions such as
        # "1.18.0.dev20240101" (four components). Only the first two components matter here.
        major, minor = ort.__version__.split(".")[:2]
        if (int(major), int(minor)) >= (1, 17):
            so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb")
        else:
            raise ValueError(
                "ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. "
                f"Current ONNX Runtime version is {ort.__version__}"
            )

    # create session to optimize. this will write the updated model to output_path
    _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"])
|
||||
|
||||
|
||||
def _replace_symbolic_dim_value(graph: onnx.GraphProto, **kwargs):
    """Replace a named symbolic dimension (dim_param) with a fixed value throughout *graph*.

    Expected kwargs: dim_param (name to replace) and value (int to substitute).
    Designed to be driven by iterate_graph_per_graph_func so subgraphs are covered too.
    """
    param_to_replace = kwargs["dim_param"]
    value = kwargs["value"]

    def update_dim_values(value_infos):
        for vi in value_infos:
            if vi.type.HasField("tensor_type"):
                shape = vi.type.tensor_type.shape
                if shape:
                    for dim in shape.dim:
                        if dim.HasField("dim_param") and dim.dim_param == param_to_replace:
                            # Clear() removes the dim_param entry before the fixed value is set
                            dim.Clear()
                            dim.dim_value = value

    # inputs, outputs and intermediate values can all carry the symbolic dim
    update_dim_values(graph.input)
    update_dim_values(graph.output)
    update_dim_values(graph.value_info)
|
||||
|
||||
|
||||
def _remove_invalid_dim_values_impl(graph: onnx.GraphProto):
    """Unset any dim_value < 1 in the inputs/outputs/value_info of *graph* (this graph only)."""

    def clear_invalid_values(value):
        if value.type.HasField("tensor_type"):
            shape = value.type.tensor_type.shape
            if shape:
                for dim in shape.dim:
                    # a dim_value < 1 (typically -1) is a converter artifact meaning "dynamic"
                    if dim.HasField("dim_value") and dim.dim_value < 1:
                        dim.Clear()

    for i in graph.input:
        clear_invalid_values(i)

    for o in graph.output:
        clear_invalid_values(o)

    for vi in graph.value_info:
        clear_invalid_values(vi)
|
||||
|
||||
|
||||
def remove_invalid_dim_values(graph: onnx.GraphProto):
    """
    Iterate the graph and subgraphs, unsetting any dim_value entries that have a value of less than 1.
    These are typically erroneously inserted by a converter to represent a dynamic dimension.
    :param graph: GraphProto to update
    """
    # the per-graph iterator applies the implementation to this graph and all nested subgraphs
    iterate_graph_per_graph_func(graph, _remove_invalid_dim_values_impl)
|
||||
|
||||
|
||||
def make_dim_param_fixed(graph: onnx.GraphProto, param_name: str, value: int):
    """
    Iterate all values in the graph, replacing dim_param in a tensor shape with the provided value.
    :param graph: GraphProto to update
    :param param_name: dim_param to set
    :param value: value to use
    """
    # applies to the graph and all nested subgraphs via the per-graph iterator
    iterate_graph_per_graph_func(graph, _replace_symbolic_dim_value, dim_param=param_name, value=value)
|
||||
|
||||
|
||||
def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: list[int]):
    """
    Update the named graph input to set shape to the provided value. This can be used to set unknown dims as well
    as to replace dim values.
    If setting the input shape replaces a dim_param, update any other values in the graph that use the dim_param.
    :param graph: Graph to update
    :param input_name: Name of graph input to update.
    :param fixed_shape: Shape to use.
    :raises ValueError: If the input is not found, is not a tensor, or the rank/existing fixed dims conflict.
    """

    # remove any invalid dim values first. typically this is a dim_value of -1.
    remove_invalid_dim_values(graph)

    for i in graph.input:
        if i.name == input_name:
            if not i.type.HasField("tensor_type"):
                raise ValueError(f"Input {input_name} is not a tensor")

            # graph inputs are required to have a shape to provide the rank
            shape = i.type.tensor_type.shape
            if len(shape.dim) != len(fixed_shape):
                raise ValueError(f"Rank mismatch. Existing:{len(shape.dim)} Replacement:{len(fixed_shape)}")

            for idx, dim in enumerate(shape.dim):
                # check any existing fixed dims match
                if dim.HasField("dim_value"):
                    if dim.dim_value != fixed_shape[idx]:
                        raise ValueError(
                            f"Can't replace existing fixed size of {dim.dim_value} with {fixed_shape[idx]} "
                            f"for dimension {idx + 1}"
                        )
                elif dim.HasField("dim_param"):
                    # replacing a dim_param so have to do that through the entire graph
                    make_dim_param_fixed(graph, dim.dim_param, fixed_shape[idx])
                else:
                    # replacing an unknown dim
                    dim.Clear()
                    dim.dim_value = fixed_shape[idx]

            return

    raise ValueError(
        f"Input {input_name} was not found in graph inputs. "
        f"Valid input names are: {','.join([i.name for i in graph.input])}"
    )
|
||||
|
||||
|
||||
def fix_output_shapes(model: onnx.ModelProto):
    """
    Update the output shapes of a model where the input shape/s were made fixed, if possible.
    This is mainly to make the model usage clearer if the output shapes can be inferred from the new input shapes.
    :param model: Model that had input shapes fixed.
    """

    # get a version of the model with shape inferencing info in it. this will provide fixed output shapes if possible.
    m2 = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(m2)

    # copy any newly-fixed output shapes from the inferred model back onto the original
    for idx, o in enumerate(model.graph.output):
        if not is_fixed_size_tensor(o):
            new_o = m2.graph.output[idx]
            if is_fixed_size_tensor(new_o):
                o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape)
|
||||
|
||||
|
||||
def _create_producer_consumer_link(
|
||||
node_to_producers: dict, node_to_consumers: dict, producer: onnx.NodeProto, consumer: onnx.NodeProto
|
||||
):
|
||||
"""
|
||||
Create links between two nodes for a value produced by one and consumed by the other.
|
||||
:param node_to_producers: Map of NodeProto to set of nodes that produce values the node consumes as inputs.
|
||||
:param node_to_consumers: Map of NodeProto to set of nodes that consume values the node produces as outputs.
|
||||
:param producer: Producer node
|
||||
:param consumer: Consumer node
|
||||
"""
|
||||
|
||||
if consumer not in node_to_producers:
|
||||
node_to_producers[consumer] = set()
|
||||
|
||||
if producer not in node_to_consumers:
|
||||
node_to_consumers[producer] = set()
|
||||
|
||||
# add entry mapping this node to the producer of this input
|
||||
node_to_producers[consumer].add(producer)
|
||||
node_to_consumers[producer].add(consumer)
|
||||
|
||||
|
||||
def _map_node_dependencies(graph: onnx.GraphProto, node_to_producers: dict, node_to_consumers: dict):
    """Populate the producer/consumer maps for *graph*, recursing into subgraphs.

    :param graph: Graph to process.
    :param node_to_producers: Map updated in place: node -> set of nodes producing its inputs.
    :param node_to_consumers: Map updated in place: node -> set of nodes consuming its outputs.
    :return: Set of implicit inputs - value names consumed in this graph (or its subgraphs)
             that are not produced locally and are not graph inputs/initializers, so must
             come from an enclosing scope.
    """
    graph_inputs = {i.name for i in graph.input}
    initializers = {i.name for i in graph.initializer}

    # map of value name to node that creates it. copy parent values but override if values get shadowed
    producers = {}

    implicit_inputs = set()

    def is_local_value(value):
        return value in producers or value in initializers or value in graph_inputs

    for node in graph.node:
        inputs = list(node.input)

        # A subgraph on a control flow node may consume values from this scope; treat its
        # implicit inputs as additional inputs of the owning node.
        for attr in node.attribute:
            if attr.HasField("g"):
                subgraph_implicit_inputs = _map_node_dependencies(attr.g, node_to_producers, node_to_consumers)
                inputs += subgraph_implicit_inputs

        for i in inputs:
            if not i:
                # missing optional input
                continue

            if is_local_value(i):
                # only link when another node produced the value; graph inputs and
                # initializers have no producer node
                if i in producers:
                    producer = producers[i]
                    _create_producer_consumer_link(node_to_producers, node_to_consumers, producer, node)
            else:
                implicit_inputs.add(i)

        # record this node as the producer of its outputs for subsequent nodes
        for o in node.output:
            producers[o] = node

    return implicit_inputs
|
||||
|
||||
|
||||
def get_producer_consumer_maps(graph: onnx.GraphProto):
    """
    Get maps for connections between the node that produces each value and the nodes that consume the value.
    Processing includes subgraphs. As the map key is a Node instance from the Graph there should be no ambiguity.
    :param graph: Graph to process.
    :return: Tuple with two maps.
             First is node_to_producers map of a node to set of all nodes producing input it consumes.
             Second is node_to_consumers map of a node to set of all nodes consuming output it creates.
             e.g. NodeA and NodeB provide inputs to NodeC. NodeC provides input to NodeD
             node_to_consumers[NodeA] = set([NodeC])
             node_to_consumers[NodeB] = set([NodeC])
             node_to_producers[NodeC] = set([NodeA, NodeB])
             node_to_consumers[NodeC] = set([NodeD])
             node_to_producers[NodeD] = set([NodeC])
    :raises ValueError: If the top level graph has implicit inputs (invalid model).
    """

    # use a hash of the object id for NodeProto.
    # we need this for the partitioning checker where we keep maps with nodes as the key.
    # NOTE(review): this monkey-patches NodeProto globally for the whole process; any later
    # code relying on protobuf's default hashing behavior will see identity hashing instead.
    onnx.NodeProto.__hash__ = lambda self: id(self)

    node_to_producers = {}  # map of node instance to nodes producing input values it consumes
    node_to_consumers = {}  # map of node instance to nodes consuming output values it produces

    implicit_inputs = _map_node_dependencies(graph, node_to_producers, node_to_consumers)

    # top level graph should have no implicit inputs
    if implicit_inputs:
        raise ValueError(
            f"This appears to be an invalid model with missing inputs of {','.join(sorted(implicit_inputs))}"
        )

    return node_to_producers, node_to_consumers
|
||||
|
||||
|
||||
def is_fixed_size_tensor(value: onnx.ValueInfoProto):
    """
    Check if value is a tensor with a fixed shape.
    :param value: onnx.ValueInfoProto to check
    :return: True if value is a tensor, with a shape, where all dimensions have fixed values.
    """
    if not value.type.HasField("tensor_type"):
        return False

    shape = value.type.tensor_type.shape
    if not shape:
        return False

    # a scalar has no dims, so all() over an empty sequence correctly yields True
    return all(dim.HasField("dim_value") and dim.dim_value > 0 for dim in shape.dim)
|
||||
|
||||
|
||||
def get_optimization_level(level):
    """Convert a level name string to the corresponding ort.GraphOptimizationLevel."""
    # "basic": constant folding and other optimizations that only use ONNX operators
    # "extended": optimizations using custom operators, excluding NCHWc and NHWC layout optimizers
    # "layout": NCHWc and NHWC layout optimizers
    attr_by_level = {
        "disable": "ORT_DISABLE_ALL",
        "basic": "ORT_ENABLE_BASIC",
        "extended": "ORT_ENABLE_EXTENDED",
        "layout": "ORT_ENABLE_LAYOUT",
        "all": "ORT_ENABLE_ALL",
    }
    if level not in attr_by_level:
        raise ValueError("Invalid optimization level of " + level)

    return getattr(ort.GraphOptimizationLevel, attr_by_level[level])
|
||||
|
||||
|
||||
class ModelProtoWithShapeInfo:
    """
    Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo.
    The model_with_shape_info property will contain the updated model.
    If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully.
    This helper class handles automatic removal of the temporary file.
    """

    def __init__(self, model_path: pathlib.Path):
        """
        :param model_path: Path to ONNX model to load and run shape inferencing on.
        """

        self.model_path = model_path

        model = onnx.load(str(model_path))
        self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True)

        # ONNX has a silent failure from the call to infer_shapes when the model is > 2GB.
        # We detect that by checking the nodes in the returned model.
        self._tmp_model_path = None
        if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0:
            # fall back to the path-based API, which supports large models by writing to a temp file
            self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx")
            onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True)
            self.model_with_shape_info = onnx.load(str(self._tmp_model_path))

    def __del__(self):
        # NOTE(review): if __init__ raised before _tmp_model_path was assigned this would
        # raise AttributeError; harmless in practice since __del__ errors are suppressed.
        if self._tmp_model_path:
            self._tmp_model_path.unlink(missing_ok=True)
|
||||
@@ -0,0 +1,85 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# An offline standalone script to declassify an ONNX model by randomizing the tensor data in initializers.
|
||||
# The ORT Performance may change especially on generative models.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from onnx import load_model, numpy_helper, onnx_pb, save_model
|
||||
|
||||
# An experimental small value for differentiating shape data and weights.
|
||||
# The tensor data with larger size can't be shape data.
|
||||
# User may adjust this value as needed.
|
||||
SIZE_THRESHOLD = 10
|
||||
|
||||
|
||||
def graph_iterator(model, func):
    """Apply *func* to the model's main graph and every nested subgraph, breadth-first.

    :param model: ModelProto whose graphs are visited.
    :param func: Callable invoked as func(graph) for each GraphProto.
    """
    from collections import deque

    # Perf fix: deque.popleft() is O(1); the original list.pop(0) shifts the whole
    # queue on every pop, making deep traversals quadratic. Visit order is unchanged.
    graph_queue = deque([model.graph])
    while graph_queue:
        graph = graph_queue.popleft()
        func(graph)
        for node in graph.node:
            for attr in node.attribute:
                # single-graph attribute (e.g. If/Loop/Scan body)
                if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPH:
                    assert isinstance(attr.g, onnx_pb.GraphProto)
                    graph_queue.append(attr.g)
                # multi-graph attribute
                if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPHS:
                    for g in attr.graphs:
                        assert isinstance(g, onnx_pb.GraphProto)
                        graph_queue.append(g)
|
||||
|
||||
|
||||
def randomize_graph_initializer(graph):
    """Replace each sufficiently-large initializer tensor in *graph* with uniform random data.

    Values are drawn from [min, max] of the original tensor so dtype and value range are
    preserved. Tensors with size <= SIZE_THRESHOLD are left untouched as they are assumed
    to be shape data rather than weights.
    """
    for i_tensor in graph.initializer:
        array = numpy_helper.to_array(i_tensor)
        # TODO: need to find a better way to differentiate shape data and weights.
        if array.size > SIZE_THRESHOLD:
            random_array = np.random.uniform(array.min(), array.max(), size=array.shape).astype(array.dtype)
            o_tensor = numpy_helper.from_array(random_array, i_tensor.name)
            # overwrite the initializer's data in place
            i_tensor.CopyFrom(o_tensor)
|
||||
|
||||
|
||||
def main():
    """Parse CLI args, randomize the model's initializer weights, and save the result."""
    parser = argparse.ArgumentParser(description="Randomize the weights of an ONNX model")
    parser.add_argument("-m", type=str, required=True, help="input onnx model path")
    parser.add_argument("-o", type=str, required=True, help="output onnx model path")
    parser.add_argument(
        "--use_external_data_format",
        required=False,
        action="store_true",
        help="Store or Save in external data format",
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        required=False,
        action="store_true",
        help="Save all tensors to one file",
    )
    args = parser.parse_args()

    data_path = None
    if args.use_external_data_format:
        # writing into the input model's directory could clobber its external data files
        if Path(args.m).parent == Path(args.o).parent:
            raise RuntimeError("Please specify output directory with different parent path to input directory.")
        if args.all_tensors_to_one_file:
            # external data file is named after the output model, in the same directory
            data_path = Path(args.o).name + ".data"

    Path(args.o).parent.mkdir(parents=True, exist_ok=True)
    onnx_model = load_model(args.m, load_external_data=args.use_external_data_format)
    graph_iterator(onnx_model, randomize_graph_initializer)
    save_model(
        onnx_model,
        args.o,
        save_as_external_data=args.use_external_data_format,
        all_tensors_to_one_file=args.all_tensors_to_one_file,
        location=data_path,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()
|
||||
@@ -0,0 +1,164 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime as onnxrt
|
||||
|
||||
# Map ONNX tensor element type strings to numpy dtype names for random feed generation.
float_dict = {
    "tensor(float16)": "float16",
    "tensor(float)": "float32",
    "tensor(double)": "float64",
}

integer_dict = {
    "tensor(int32)": "int32",
    "tensor(int8)": "int8",
    "tensor(uint8)": "uint8",
    "tensor(int16)": "int16",
    "tensor(uint16)": "uint16",
    "tensor(int64)": "int64",
    "tensor(uint64)": "uint64",
}


def generate_feeds(sess, symbolic_dims: dict | None = None):
    """Create random input feeds for every input of an inference session.

    :param sess: Session (anything with get_inputs()) to generate feeds for.
    :param symbolic_dims: Optional map of symbolic dimension name -> value. Unlisted
                          symbolic dimensions and unknown dimensions default to 1.
    :return: dict of input name -> numpy array. Exits the process on an unsupported type.
    """
    symbolic_dims = symbolic_dims or {}
    feeds = {}
    for input_meta in sess.get_inputs():
        # resolve every dimension of the declared shape to a concrete int
        resolved_shape = []
        for dim in input_meta.shape:
            if not dim:
                # unknown dim
                resolved_shape.append(1)
            elif isinstance(dim, str):
                # symbolic dim: use the supplied value if we have one, otherwise 1
                resolved_shape.append(int(symbolic_dims.get(dim, 1)))
            else:
                resolved_shape.append(dim)

        if input_meta.type in float_dict:
            feeds[input_meta.name] = np.random.rand(*resolved_shape).astype(float_dict[input_meta.type])
        elif input_meta.type in integer_dict:
            feeds[input_meta.name] = np.random.uniform(high=1000, size=tuple(resolved_shape)).astype(
                integer_dict[input_meta.type]
            )
        elif input_meta.type == "tensor(bool)":
            feeds[input_meta.name] = np.random.randint(2, size=tuple(resolved_shape)).astype("bool")
        else:
            print(f"unsupported input type {input_meta.type} for input {input_meta.name}")
            sys.exit(-1)
    return feeds
|
||||
|
||||
|
||||
# simple test program for loading onnx model, feeding all inputs and running the model num_iters times.
|
||||
def run_model(
    model_path,
    num_iters=1,
    debug=None,
    profile=None,
    symbolic_dims=None,
    feeds=None,
    override_initializers=True,
):
    """Load an ONNX model, build input feeds, and run it num_iters times, printing timing info.

    :param model_path: Path of the model to run.
    :param num_iters: Number of inference runs to time.
    :param debug: If truthy, pause so a debugger can attach before running.
    :param profile: If truthy, enable ORT profiling and write a trace file.
    :param symbolic_dims: Optional map of symbolic dim name -> value used when generating feeds.
    :param feeds: Optional pre-built feeds; random feeds are generated when not supplied.
    :param override_initializers: Also feed random values for overridable initializers (IR4+).
    :return: Tuple of (exit code 0, feeds used, outputs of the last run, or False if num_iters < 1).
    """
    symbolic_dims = symbolic_dims or {}
    if debug:
        print(f"Pausing execution ready for debugger to attach to pid: {os.getpid()}")
        print("Press key to continue.")
        sys.stdin.read(1)

    sess_options = None
    if profile:
        sess_options = onnxrt.SessionOptions()
        sess_options.enable_profiling = True
        sess_options.profile_file_prefix = os.path.basename(model_path)

    sess = onnxrt.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=onnxrt.get_available_providers(),
    )
    meta = sess.get_modelmeta()

    if not feeds:
        feeds = generate_feeds(sess, symbolic_dims)

    if override_initializers:
        # Starting with IR4 some initializers provide default values
        # and can be overridden (available in IR4). For IR < 4 models
        # the list would be empty
        for initializer in sess.get_overridable_initializers():
            # unknown dims default to 1, mirroring generate_feeds
            shape = [dim if dim else 1 for dim in initializer.shape]
            if initializer.type in float_dict:
                feeds[initializer.name] = np.random.rand(*shape).astype(float_dict[initializer.type])
            elif initializer.type in integer_dict:
                feeds[initializer.name] = np.random.uniform(high=1000, size=tuple(shape)).astype(
                    integer_dict[initializer.type]
                )
            elif initializer.type == "tensor(bool)":
                feeds[initializer.name] = np.random.randint(2, size=tuple(shape)).astype("bool")
            else:
                print(f"unsupported initializer type {initializer.type} for initializer {initializer.name}")
                sys.exit(-1)

    # NOTE(review): num_iters == 0 would raise ZeroDivisionError in the latency print below.
    start = timer()
    for _i in range(num_iters):
        outputs = sess.run([], feeds)  # fetch all outputs
    end = timer()

    print(f"model: {meta.graph_name}")
    print(f"version: {meta.version}")
    print(f"iterations: {num_iters}")
    print(f"avg latency: {((end - start) * 1000) / num_iters} ms")

    if profile:
        trace_file = sess.end_profiling()
        print(f"trace file written to: {trace_file}")

    # `outputs` is only bound when num_iters >= 1; the short-circuit `and` avoids
    # evaluating it (NameError) when num_iters < 1.
    return 0, feeds, num_iters > 0 and outputs
|
||||
|
||||
|
||||
def main():
    """CLI entry point for the simple ONNX Runtime test tool: parse args and run the model."""
    parser = argparse.ArgumentParser(description="Simple ONNX Runtime Test Tool.")
    parser.add_argument("model_path", help="model path")
    parser.add_argument(
        "num_iters",
        nargs="?",
        type=int,
        default=1000,
        help="model run iterations. default=1000",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="pause execution to allow attaching a debugger.",
    )
    parser.add_argument("--profile", action="store_true", help="enable chrome timeline trace profiling.")
    parser.add_argument(
        "--symbolic_dims",
        default={},
        # parses "name1=value1,name2=value2" into a dict
        type=lambda s: dict(x.split("=") for x in s.split(",")),
        help="Comma separated name=value pairs for any symbolic dimensions in the model input. "
        "e.g. --symbolic_dims batch=1,seqlen=5. "
        "If not provided, the value of 1 will be used for all symbolic dimensions.",
    )

    args = parser.parse_args()
    exit_code, _, _ = run_model(args.model_path, args.num_iters, args.debug, args.profile, args.symbolic_dims)
    sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from .onnx_model_utils import get_optimization_level, optimize_model
|
||||
|
||||
|
||||
def optimize_model_helper():
    """Command line entry point: optimize an ONNX model with ONNX Runtime to a user-selected level."""
    parser = argparse.ArgumentParser(
        f"{os.path.basename(__file__)}:{optimize_model_helper.__name__}",
        description="""
Optimize an ONNX model using ONNX Runtime to the specified level.
See https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for more
details of the optimization levels.""",
    )

    parser.add_argument(
        "--opt_level",
        default="basic",
        choices=["disable", "basic", "extended", "layout", "all"],
        help="Optimization level to use.",
    )
    parser.add_argument(
        "--log_level",
        choices=["debug", "info", "warning", "error"],
        type=str,
        required=False,
        default="error",
        help="Log level. Defaults to Error so we don't get output about unused initializers "
        "being removed. Warning or Info may be desirable in some scenarios.",
    )

    parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write optimized ONNX model to.")

    args = parser.parse_args()

    # Map the textual level to the ORT logging severity value
    # (0 = verbose, 1 = info, 2 = warning, 3 = error). argparse `choices`
    # guarantees the lookup cannot fail.
    ort_log_level = {"debug": 0, "info": 1, "warning": 2, "error": 3}[args.log_level]

    optimize_model(args.input_model, args.output_model, get_optimization_level(args.opt_level), ort_log_level)


if __name__ == "__main__":
    optimize_model_helper()
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
import sys

# need to add the path to the ORT flatbuffers python module before we import anything else here.
# we also auto-magically adjust to whether we're running from the ORT repo, or from within the ORT python package
script_dir = os.path.dirname(os.path.realpath(__file__))
fbs_py_schema_dirname = "ort_flatbuffers_py"
if os.path.isdir(os.path.join(script_dir, fbs_py_schema_dirname)):
    # fbs bindings are in this directory, so we're running in the ORT python package
    ort_fbs_py_parent_dir = script_dir
else:
    # running directly from ORT repo, so fbs bindings are under onnxruntime/core/flatbuffers
    ort_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..", ".."))
    ort_fbs_py_parent_dir = os.path.join(ort_root, "onnxruntime", "core", "flatbuffers")

# make `import ort_flatbuffers_py...` resolvable for the modules imported below
sys.path.append(ort_fbs_py_parent_dir)

# E402 suppressed intentionally: these imports must come after the sys.path manipulation above.
from .operator_type_usage_processors import (  # noqa: E402
    GloballyAllowedTypesOpTypeImplFilter,  # noqa: F401
    OperatorTypeUsageManager,  # noqa: F401
    OpTypeImplFilterInterface,  # noqa: F401
)
from .ort_model_processor import OrtFormatModelProcessor  # noqa: E402, F401
from .utils import create_config_from_models  # noqa: E402, F401
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ort_flatbuffers_py.fbs as fbs
|
||||
|
||||
from .types import FbsTypeInfo, value_name_to_typestr
|
||||
|
||||
|
||||
def _create_op_key(domain: str, optype: str):
|
||||
return f"{domain}:{optype}"
|
||||
|
||||
|
||||
def _ort_constant_for_domain(domain: str):
|
||||
"""
|
||||
Map a string domain value to the internal ONNX Runtime constant for that domain.
|
||||
:param domain: Domain string to map.
|
||||
:return: Internal ONNX Runtime constant
|
||||
"""
|
||||
|
||||
# constants are defined in <ORT root>/include/onnxruntime/core/graph/constants.h
|
||||
# This list is limited to just the domains we have processors for
|
||||
domain_to_constant_map = {"ai.onnx": "kOnnxDomain", "ai.onnx.ml": "kMLDomain", "com.microsoft": "kMSDomain"}
|
||||
|
||||
if domain not in domain_to_constant_map:
|
||||
raise ValueError(f"Domain {domain} not found in map to ONNX Runtime constant. Please update map.")
|
||||
|
||||
return domain_to_constant_map[domain]
|
||||
|
||||
|
||||
def _reg_type_to_cpp_type(reg_type: str):
|
||||
if reg_type == "string":
|
||||
return "std::string"
|
||||
return reg_type
|
||||
|
||||
|
||||
def _split_reg_types(reg_types_str: str):
|
||||
"""
|
||||
Split on underscores but append "_t" to the previous element.
|
||||
"""
|
||||
tokens = reg_types_str.split("_")
|
||||
reg_types = []
|
||||
for token in tokens:
|
||||
if token == "t" and len(reg_types) > 0:
|
||||
reg_types[-1] += "_t"
|
||||
else:
|
||||
reg_types += [token]
|
||||
return reg_types
|
||||
|
||||
|
||||
class TypeUsageProcessor(ABC):
    """
    Abstract base class for processors which implement operator specific logic to determine the type or types required.
    """

    def __init__(self, domain: str, optype: str):
        self.domain = domain
        self.optype = optype
        self.name = _create_op_key(domain, optype)

    @abstractmethod
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record type information for a single node. Implemented by derived classes."""

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """
        Given the string from a kernel registration, determine if the registration is required or not.
        :param type_in_registration: Type string from kernel registration
        :param globally_allowed_types: Optional set of globally allowed types. If provided, these types take precedence
                                       in determining the required types.
        :return: True is required. False if not.
        """
        # Not all operators have typed registrations, so this is optionally implemented by derived classes
        raise RuntimeError(f"Did not expect processor for {self.name} to have typed registrations.")

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with any applicable C++ code for this operator's required types. One line per entry.
        """
        # Not applicable for some ops, so return no lines by default.
        return []

    @abstractmethod
    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information.
        """

    @abstractmethod
    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        NOTE: Any existing type information should be cleared prior to re-creating from a config file entry.
        :param entry: Configuration file entry
        """
|
||||
|
||||
|
||||
class DefaultTypeUsageProcessor(TypeUsageProcessor):
    """
    Operator processor which tracks the types used for selected input/s and/or output/s.
    """

    def __init__(
        self,
        domain: str,
        optype: str,
        inputs: list[int] | None = None,
        outputs: list[int] | None = None,
        required_input_types: dict[int, set[str]] | None = None,
        required_output_types: dict[int, set[str]] | None = None,
    ):
        """
        Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor.
        The default is to track the types required for input 0, as this is the most common use case in ONNX.

        Required input and output types may be specified. These are only applicable to is_typed_registration_needed().
        If a registration type matches a required type, the typed registration is needed.
        There is a separate mechanism for specifying required types from C++ for kernels with untyped registration.

        :param domain: Operator domain.
        :param optype: Operator name.
        :param inputs: Inputs to track. Zero based index. May be empty. Defaults to [0].
        :param outputs: Outputs to track. Zero based index. May be empty. Defaults to [].
        :param required_input_types: Required input types. May be empty. Defaults to {}.
        :param required_output_types: Required output types. May be empty. Defaults to {}.
        """
        super().__init__(domain, optype)

        # Use None sentinels rather than mutable default arguments ([0], [], {}) so that
        # all instances don't silently share — and risk corrupting — the same default objects.
        inputs = [0] if inputs is None else inputs
        outputs = [] if outputs is None else outputs

        self._input_types = {}
        self._output_types = {}

        for i in inputs:
            self._input_types[i] = set()

        for o in outputs:
            self._output_types[o] = set()

        if not inputs and not outputs:
            raise ValueError("At least one input or output must be tracked")

        self._required_input_types = required_input_types if required_input_types is not None else {}
        self._required_output_types = required_output_types if required_output_types is not None else {}

    def _is_type_enabled(self, reg_type, index, required_types, allowed_type_set):
        # A type is enabled when it is explicitly required for this index, or appears
        # in the set of allowed types (seen in models, or globally allowed).
        cpp_type = _reg_type_to_cpp_type(reg_type)
        return cpp_type in required_types.get(index, set()) or cpp_type in allowed_type_set

    def is_input_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether input type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._input_types[index]
        return self._is_type_enabled(reg_type, index, self._required_input_types, allowed_type_set)

    def is_output_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether output type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._output_types[index]
        return self._is_type_enabled(reg_type, index, self._required_output_types, allowed_type_set)

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type seen for each tracked input/output of `node`."""
        for i in self._input_types:
            if i >= node.InputsLength():
                # Some operators have fewer inputs in earlier versions where data that was as an attribute
                # become an input in later versions to allow it to be dynamically provided. Allow for that.
                # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs
                pass
            else:
                type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
                self._input_types[i].add(type_str)

        for o in self._output_types:
            # Don't know of any ops where the number of outputs changed across versions, so require a valid length
            if o >= node.OutputsLength():
                raise RuntimeError(
                    f"Node has {node.OutputsLength()} outputs. Tracker for {self.name} incorrectly configured as it requires {o}."
                )

            type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo)
            self._output_types[o].add(type_str)

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        if 0 not in self._input_types:
            # currently all standard typed registrations are for input 0.
            # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
            raise RuntimeError(f"Expected typed registration to use type from input 0. Node:{self.name}")

        return self.is_input_type_enabled(type_in_registration, 0, globally_allowed_types)

    def get_cpp_entry(self):
        """One ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES line per tracked input/output with recorded types."""
        entries = []
        domain = _ort_constant_for_domain(self.domain)
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});".format(
                        domain, self.optype, i, ", ".join(sorted(self._input_types[i]))
                    )
                )

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});".format(
                        domain, self.optype, o, ", ".join(sorted(self._output_types[o]))
                    )
                )

        return entries

    def to_config_entry(self):
        """Serialize the recorded input/output types to a JSON string, or None if nothing was recorded."""
        # convert the sets of types to lists so they can easily written out using the json model
        aggregate_info = {"inputs": {}, "outputs": {}}

        # filter out empty entries and sort the types
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                aggregate_info["inputs"][i] = sorted(self._input_types[i])

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                aggregate_info["outputs"][o] = sorted(self._output_types[o])

        # remove any empty keys
        if not aggregate_info["inputs"]:
            aggregate_info.pop("inputs")
        if not aggregate_info["outputs"]:
            aggregate_info.pop("outputs")

        entry = json.dumps(aggregate_info) if aggregate_info else None
        return entry

    def from_config_entry(self, entry: str):
        """Replace the recorded types with the contents of a JSON entry created by to_config_entry."""
        self._input_types.clear()
        self._output_types.clear()

        aggregate_info = json.loads(entry)
        if "inputs" in aggregate_info:
            for i_str, values in aggregate_info["inputs"].items():
                # JSON object keys are strings; indices are ints internally
                self._input_types[int(i_str)] = set(values)

        if "outputs" in aggregate_info:
            for o_str, values in aggregate_info["outputs"].items():
                self._output_types[int(o_str)] = set(values)
|
||||
|
||||
|
||||
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the second input type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # only input 1 is relevant for these operators, so that is all we track
        super().__init__(domain, optype, inputs=[1], outputs=[])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The type in the kernel registration comes from input 1 for these operators."""
        return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
|
||||
|
||||
|
||||
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the first output type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # only output 0 is relevant for these operators, so that is all we track
        super().__init__(domain, optype, inputs=[], outputs=[0])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The type in the kernel registration comes from output 0 for these operators."""
        return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types)
|
||||
|
||||
|
||||
class OneHotProcessor(TypeUsageProcessor):
    """
    Processor for the OneHot operator, which requires custom logic as the type registration key is a concatenation of
    the three types involved instead of a single type name.
    """

    def __init__(self):
        super().__init__("ai.onnx", "OneHot")
        # set of (input 0 type, input 2 type, input 1 type) tuples seen in models
        self._triples = set()

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        in0 = value_name_to_typestr(node.Inputs(0), value_name_to_typeinfo)
        in1 = value_name_to_typestr(node.Inputs(1), value_name_to_typeinfo)
        in2 = value_name_to_typestr(node.Inputs(2), value_name_to_typeinfo)
        # types in kernel registration are ordered this way: input (T1), output (T3), depth (T2)
        self._triples.add((in0, in2, in1))

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        # the OneHot registration involves a concatenation of the 3 types involved
        reg_types = tuple(_reg_type_to_cpp_type(reg_type) for reg_type in _split_reg_types(type_in_registration))
        if globally_allowed_types is None:
            return reg_types in self._triples

        return all(reg_type in globally_allowed_types for reg_type in reg_types)

    def to_config_entry(self):
        """Serialize the recorded type triples to JSON, or None if nothing was recorded."""
        if not self._triples:
            return None

        return json.dumps({"custom": sorted(self._triples)})

    def from_config_entry(self, entry: str):
        """Replace the recorded triples with the contents of a JSON entry created by to_config_entry."""
        self._triples.clear()
        aggregate_info = json.loads(entry)
        if "custom" in aggregate_info:
            # JSON round-trips tuples as lists, so convert back
            self._triples = {tuple(triple) for triple in aggregate_info["custom"]}
|
||||
|
||||
|
||||
def _create_operator_type_usage_processors():
    """
    Create a set of processors that determine the required types for all enabled operators.
    :return: Dictionary of operator key to processor. Key is 'domain:operator (e.g. ai.onnx:Cast)'.
    """
    operator_processors = {}

    def add(processor):
        # registration helper; guards against two processors being registered for one operator
        if processor.name in operator_processors:
            raise RuntimeError("Duplicate processor for " + processor.name)

        operator_processors[processor.name] = processor

    # Starting with ops from:
    # - Priority 1P models
    # - Mobilenet + SSD Mobilenet + MobileBert
    # - some known large kernels
    #
    # Ops we are ignoring currently so as not to produce meaningless/unused output:
    # - Implementation is type agnostic:
    #   ai.onnx: If, Loop, Reshape, Scan, Shape, Squeeze, Tile, Unsqueeze
    #   com.microsoft: DynamicQuantizeMatMul, MatMulIntegerToFloat
    # - Only one type supported in the ORT implementation:
    #   ai.onnx: NonMaxSuppression
    #   com.microsoft: FusedConv, FusedGemm, FusedMatMul
    # - Implementation does not have any significant type specific code:
    #   ai.onnx: Concat, Flatten, Not, Reshape, Shape, Squeeze, Unsqueeze
    #
    default_processor_onnx_ops = [
        "Abs",
        "ArgMax",
        "ArgMin",
        "AveragePool",
        "BatchNormalization",
        "BitShift",
        "Ceil",
        "Clip",
        "Conv",
        "CumSum",
        "Exp",
        "Expand",
        "Floor",
        "Gemm",
        "IsNaN",
        "Log",
        "LogSoftmax",
        "LpNormalization",
        "MatMul",
        "Max",
        "MaxPool",
        "Mean",
        "Min",
        "NonZero",
        "Pad",
        "QLinearConv",
        "QLinearMatMul",
        "Range",
        "Reciprocal",
        "ReduceL1",
        "ReduceL2",
        "ReduceLogSum",
        "ReduceLogSumExp",
        "ReduceMax",
        "ReduceMean",
        "ReduceMin",
        "ReduceProd",
        "ReduceSum",
        "ReduceSumSquare",
        "Relu",
        "Resize",
        "ReverseSequence",
        "RoiAlign",
        "Round",
        "Scatter",
        "ScatterElements",
        "ScatterND",
        "Shrink",
        "Sigmoid",
        "Sign",
        "Sin",
        "Softmax",
        "Split",
        "SplitToSequence",
        "Sqrt",
        "Sum",
        "Tanh",
        "TopK",
        "Transpose",
        "Unique",
    ]

    # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
    default_processor_onnx_ops_requiring_ints_for_input_0 = [
        "Add",
        "Concat",
        "Div",
        "Equal",
        "Greater",
        "Less",
        "Mul",
        "Neg",  # used in tflite TransposeConv conversion
        "Sub",
    ]

    # NOTE: QLinearConv has ONNX and internal implementations
    internal_ops = ["QLinearAdd", "QLinearMul", "QLinearConv"]

    # TODO - review and add ML ops as needed
    # ML Op notes.
    #  CastMap: Switch on value type of input map type, and output type
    #  DictVectorizer: Templatized on key+value of input so need to handle like OneHot with custom processor
    #  LabelEncoder: Implementation switches on input and output types (only supports string and int64 in T1 and T2)
    #  LinearClassifier: Internal switch on input type and also switch on output type
    #  SVMClassifier: ditto
    #  TreeEnsembleClassifier: Templatized on input type and also switch on output type
    #  ZipMap: Switch on output type (derived from attributes)
    default_processor_onnxml_ops = []

    # Plain `for` loops rather than the previous list comprehensions: the comprehensions were
    # executed purely for their side effects and their resulting lists were discarded.
    for op in default_processor_onnx_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op))

    for op in default_processor_onnx_ops_requiring_ints_for_input_0:
        add(DefaultTypeUsageProcessor("ai.onnx", op, required_input_types={0: {"int32_t", "int64_t"}}))

    for op in default_processor_onnxml_ops:
        add(DefaultTypeUsageProcessor("ai.onnx.ml", op))

    for op in internal_ops:
        add(DefaultTypeUsageProcessor("com.microsoft", op))

    #
    # Operators that require custom handling
    #

    # Cast switches on types of input 0 and output 0
    add(DefaultTypeUsageProcessor("ai.onnx", "Cast", inputs=[0], outputs=[0]))

    # Operators that switch on the type of input 0 and 1
    add(DefaultTypeUsageProcessor("ai.onnx", "Gather", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "GatherElements", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Pow", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Slice", inputs=[0, 1]))

    # Operators that switch on output type
    add(DefaultTypeUsageProcessor("ai.onnx", "ConstantOfShape", inputs=[], outputs=[0]))

    # Random generator ops produce new data so we track the output type
    onnx_random_ops = ["RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike", "Multinomial"]
    for op in onnx_random_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op, inputs=[], outputs=[0]))

    # Where always has a boolean first input so track the second input type for typed registration
    add(Input1TypedRegistrationProcessor("ai.onnx", "Where"))

    # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
    # as that's what is used in the typed registration
    add(Output0TypedRegistrationProcessor("ai.onnx", "QuantizeLinear"))
    add(Output0TypedRegistrationProcessor("ai.onnx", "DynamicQuantizeLinear"))

    # make sure all the dequantize types are enabled. we use int32_t for parts of GEMM and Conv so just
    # enabling int8 and uint8 is not enough.
    # TODO: Only apply required types to the global type list and ignore if it's model based per-op type reduction
    add(
        DefaultTypeUsageProcessor(
            "ai.onnx", "DequantizeLinear", inputs=[0], required_input_types={0: {"int8_t", "uint8_t", "int32_t"}}
        )
    )

    # OneHot concatenates type strings into a triple in the typed registration
    # e.g. float_int64_t_int64_t
    add(OneHotProcessor())

    return operator_processors
|
||||
|
||||
|
||||
class OpTypeImplFilterInterface(ABC):
    """
    Interface for classes that filter operator implementations based on type.
    """

    @abstractmethod
    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """
        Given the string from a kernel registration, determine if the registration is required or not.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param type_registration_str: Type string from kernel registration
        :return: True is required. False if not.
        """

    @abstractmethod
    def get_cpp_entries(self):
        """
        Get the C++ code that specifies the operator types to enable.
        :return: List of strings. One line of C++ code per entry.
        """
|
||||
|
||||
|
||||
class OperatorTypeUsageManager:
    """
    Class to manage the operator type usage processors.
    TODO: Currently the type tracking is not specific to a version of the operator.
    It's unclear how/where version specific logic could/should be added, and it would add significant complexity
    to track types on a per-version basis. Not clear there's enough benefit from doing so either.
    """

    def __init__(self):
        self._all_operator_processors = _create_operator_type_usage_processors()  # all possible processors
        self._operator_processors = {}  # processors we have actually used so we can limit output to be meaningful

    def _get_op_processor(self, key):
        "Add the processor to _operator_processors as it is about to be used."
        processor = None
        if key in self._all_operator_processors:
            if key not in self._operator_processors:
                self._operator_processors[key] = self._all_operator_processors[key]

            processor = self._operator_processors[key]

        return processor

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """
        Process a Node and record info on the types used.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx

        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.process_node(node, value_name_to_typeinfo)

    def get_config_entry(self, domain: str, optype: str):
        """
        Get the config entry specifying the types for this operator.
        :param domain: Operator domain.
        :param optype: Operator type.
        :return: JSON string with type info if available, else None
        """
        key = _create_op_key(domain, optype)
        config_str = None
        if key in self._operator_processors:
            config_str = self._operator_processors[key].to_config_entry()

        return config_str

    def restore_from_config_entry(self, domain: str, optype: str, config_entry: str):
        """
        Restore the per-operator type information from a configuration file entry.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param config_entry: JSON string with type info as created by get_config_entry
        """
        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.from_config_entry(config_entry)

    def debug_dump(self):
        """Print the C++ entries and config entries for the processors used so far."""
        print("C++ code that will be emitted:")
        # BUG FIX: this previously called self.get_cpp_entries(), which this class does not
        # define (only the filter returned by make_op_type_impl_filter() has it), so it would
        # raise AttributeError. Route through the filter instead. Loop replaces the list
        # comprehension that was run purely for its print side effects.
        for cpp_line in self.make_op_type_impl_filter().get_cpp_entries():
            print(cpp_line)

        print("Config file type information that will be returned by get_config_entry:")
        for key in sorted(self._operator_processors.keys()):
            entry = self._operator_processors[key].to_config_entry()
            if entry:
                print(f"{key} -> {entry}")

                # roundtrip test to validate that we can initialize the processor from the entry and get the
                # same values back
                self._operator_processors[key].from_config_entry(entry)
                assert entry == self._operator_processors[key].to_config_entry()

    class _OpTypeImplFilter(OpTypeImplFilterInterface):
        # Filter implementation backed by the type usage recorded in an OperatorTypeUsageManager.
        def __init__(self, manager):
            self._manager = manager

        def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
            needed = True  # we keep the registration unless the per-operator processor says not to
            key = _create_op_key(domain, optype)
            if key in self._manager._operator_processors:
                needed = self._manager._operator_processors[key].is_typed_registration_needed(
                    type_in_registration=type_registration_str, globally_allowed_types=None
                )

            return needed

        def get_cpp_entries(self):
            entries = []
            for key in sorted(self._manager._operator_processors.keys()):
                entries.extend(self._manager._operator_processors[key].get_cpp_entry())

            return entries

    def make_op_type_impl_filter(self):
        """
        Creates an OpTypeImplFilterInterface instance from this manager.
        Filtering uses the manager's operator type usage processor state.
        """
        return OperatorTypeUsageManager._OpTypeImplFilter(self)
|
||||
|
||||
|
||||
class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface):
    """
    Operator implementation filter which uses globally allowed types.
    """

    _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values())  # noqa: RUF012

    def __init__(self, globally_allowed_types: set[str]):
        self._operator_processors = _create_operator_type_usage_processors()

        # reject any type name that is not a known tensor data type string
        if not globally_allowed_types.issubset(self._valid_allowed_types):
            raise ValueError(
                f"Globally allowed types must all be valid. Invalid types: {sorted(globally_allowed_types - self._valid_allowed_types)}"
            )

        self._globally_allowed_types = globally_allowed_types

    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """A registration is needed if the op's processor (or, failing that, the global type list) allows it."""
        key = _create_op_key(domain, optype)
        processor = self._operator_processors.get(key)
        if processor is not None:
            return processor.is_typed_registration_needed(
                type_in_registration=type_registration_str, globally_allowed_types=self._globally_allowed_types
            )

        return _reg_type_to_cpp_type(type_registration_str) in self._globally_allowed_types

    def get_cpp_entries(self):
        """Single C++ line declaring the globally allowed types."""
        types_csv = ", ".join(sorted(self._globally_allowed_types))
        return [f"ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({types_csv});"]

    def global_type_list(self):
        """Return the set of globally allowed type names."""
        return self._globally_allowed_types
|
||||
@@ -0,0 +1,7 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class ArgType(object):
    """FlatBuffers-generated enum-like constants: whether an arg is a node input or output."""

    INPUT = 0
    OUTPUT = 1
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class ArgTypeAndIndex(object):
    # FlatBuffers-generated accessor for an (arg type, index) table in an ORT format model.
    # Generated code: do not hand-edit the logic.
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Read the root table offset from `buf` and wrap it in an accessor instance.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = ArgTypeAndIndex()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsArgTypeAndIndex(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ArgTypeAndIndexBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier ("ORTM") for ORT format buffers.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # ArgTypeAndIndex
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # ArgTypeAndIndex
    def ArgType(self):
        # int8 field at vtable slot 4; returns 0 when the field is absent
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # ArgTypeAndIndex
    def Index(self):
        # uint32 field at vtable slot 6; returns 0 when the field is absent
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
        return 0
|
||||
|
||||
# FlatBuffers-generated builder helpers for serializing an ArgTypeAndIndex table.
# The short-named variants (Start/AddArgType/AddIndex/End) are aliases for use via the module namespace.
def ArgTypeAndIndexStart(builder):
    # begin a table with 2 fields
    builder.StartObject(2)

def Start(builder):
    ArgTypeAndIndexStart(builder)

def ArgTypeAndIndexAddArgType(builder, argType):
    # slot 0, int8, default 0
    builder.PrependInt8Slot(0, argType, 0)

def AddArgType(builder, argType):
    ArgTypeAndIndexAddArgType(builder, argType)

def ArgTypeAndIndexAddIndex(builder, index):
    # slot 1, uint32, default 0
    builder.PrependUint32Slot(1, index, 0)

def AddIndex(builder, index):
    ArgTypeAndIndexAddIndex(builder, index)

def ArgTypeAndIndexEnd(builder):
    return builder.EndObject()

def End(builder):
    return ArgTypeAndIndexEnd(builder)
|
||||
@@ -0,0 +1,337 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Attribute(object):
    """flatc-generated read-only view of an Attribute table.

    Carries a name, doc string, and type tag plus optional scalar payloads
    (F/I/S/T/G) and list payloads (Floats/Ints/Strings/Tensors/Graphs).
    All fields are optional FlatBuffers table fields with the usual defaults.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Attribute()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsAttribute(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def AttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # Attribute
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # Attribute
    def Name(self):
        # string at vtable offset 4; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # Attribute
    def DocString(self):
        # string at vtable offset 6; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # Attribute
    def Type(self):
        # int32 at vtable offset 8 (an AttributeType value); default 0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # Attribute
    def F(self):
        # float32 scalar payload at vtable offset 10; default 0.0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0

    # Attribute
    def I(self):
        # int64 scalar payload at vtable offset 12; default 0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
        return 0

    # Attribute
    def S(self):
        # string payload at vtable offset 14; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # Attribute
    def T(self):
        # Tensor sub-table at vtable offset 16; import is local to avoid cycles.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Attribute
    def G(self):
        # Graph sub-table at vtable offset 18.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.Graph import Graph
            obj = Graph()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Attribute
    def Floats(self, j):
        # j-th element of the float32 vector at vtable offset 20 (4-byte stride).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0

    # Attribute
    def FloatsAsNumpy(self):
        # Whole vector as a numpy array (zero-copy view); 0 when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
        return 0

    # Attribute
    def FloatsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Attribute
    def FloatsIsNone(self):
        # True when the vector field is absent (distinguishes absent from empty).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        return o == 0

    # Attribute
    def Ints(self, j):
        # j-th element of the int64 vector at vtable offset 22 (8-byte stride).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Int64Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # Attribute
    def IntsAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
        return 0

    # Attribute
    def IntsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Attribute
    def IntsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        return o == 0

    # Attribute
    def Strings(self, j):
        # j-th element of the string vector at vtable offset 24 (4-byte offsets).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # Attribute
    def StringsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Attribute
    def StringsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        return o == 0

    # Attribute
    def Tensors(self, j):
        # j-th Tensor sub-table in the vector at vtable offset 26.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Attribute
    def TensorsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Attribute
    def TensorsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        return o == 0

    # Attribute
    def Graphs(self, j):
        # j-th Graph sub-table in the vector at vtable offset 28.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.Graph import Graph
            obj = Graph()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Attribute
    def GraphsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Attribute
    def GraphsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        return o == 0


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases. Slot numbers map to vtable
# offsets 4 + 2*slot.

def AttributeStart(builder):
    # The table declares 13 fields.
    builder.StartObject(13)


def Start(builder):
    AttributeStart(builder)


def AttributeAddName(builder, name):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)


def AddName(builder, name):
    AttributeAddName(builder, name)


def AttributeAddDocString(builder, docString):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)


def AddDocString(builder, docString):
    AttributeAddDocString(builder, docString)


def AttributeAddType(builder, type):
    builder.PrependInt32Slot(2, type, 0)


def AddType(builder, type):
    AttributeAddType(builder, type)


def AttributeAddF(builder, f):
    builder.PrependFloat32Slot(3, f, 0.0)


def AddF(builder, f):
    AttributeAddF(builder, f)


def AttributeAddI(builder, i):
    builder.PrependInt64Slot(4, i, 0)


def AddI(builder, i):
    AttributeAddI(builder, i)


def AttributeAddS(builder, s):
    builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(s), 0)


def AddS(builder, s):
    AttributeAddS(builder, s)


def AttributeAddT(builder, t):
    builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(t), 0)


def AddT(builder, t):
    AttributeAddT(builder, t)


def AttributeAddG(builder, g):
    builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(g), 0)


def AddG(builder, g):
    AttributeAddG(builder, g)


def AttributeAddFloats(builder, floats):
    builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)


def AddFloats(builder, floats):
    AttributeAddFloats(builder, floats)


def AttributeStartFloatsVector(builder, numElems):
    # float32 vector: 4-byte elements, 4-byte alignment.
    return builder.StartVector(4, numElems, 4)


def StartFloatsVector(builder, numElems: int) -> int:
    return AttributeStartFloatsVector(builder, numElems)


def AttributeAddInts(builder, ints):
    builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)


def AddInts(builder, ints):
    AttributeAddInts(builder, ints)


def AttributeStartIntsVector(builder, numElems):
    # int64 vector: 8-byte elements, 8-byte alignment.
    return builder.StartVector(8, numElems, 8)


def StartIntsVector(builder, numElems: int) -> int:
    return AttributeStartIntsVector(builder, numElems)


def AttributeAddStrings(builder, strings):
    builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)


def AddStrings(builder, strings):
    AttributeAddStrings(builder, strings)


def AttributeStartStringsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def StartStringsVector(builder, numElems: int) -> int:
    return AttributeStartStringsVector(builder, numElems)


def AttributeAddTensors(builder, tensors):
    builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)


def AddTensors(builder, tensors):
    AttributeAddTensors(builder, tensors)


def AttributeStartTensorsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def StartTensorsVector(builder, numElems: int) -> int:
    return AttributeStartTensorsVector(builder, numElems)


def AttributeAddGraphs(builder, graphs):
    builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(graphs), 0)


def AddGraphs(builder, graphs):
    AttributeAddGraphs(builder, graphs)


def AttributeStartGraphsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)


def StartGraphsVector(builder, numElems: int) -> int:
    return AttributeStartGraphsVector(builder, numElems)


def AttributeEnd(builder):
    return builder.EndObject()


def End(builder):
    return AttributeEnd(builder)
|
||||
@@ -0,0 +1,18 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class AttributeType(object):
    """flatc-generated enum holder: tag for which Attribute payload field is set.

    Values parallel the scalar/list accessors on Attribute (F/I/S/T/G and
    Floats/Ints/Strings/Tensors/Graphs).
    """
    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
|
||||
@@ -0,0 +1,125 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Checkpoint(object):
    """flatc-generated read-only view of a training Checkpoint table.

    Holds a format version, module state, optimizer groups, and a property bag.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Checkpoint()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsCheckpoint(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x44\x54\x43" is the 4-byte file identifier "ODTC" — note this
        # differs from the "ORTM" identifier used by the model tables.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # Checkpoint
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # Checkpoint
    def Version(self):
        # int32 at vtable offset 4; default 0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # Checkpoint
    def ModuleState(self):
        # ModuleState sub-table at vtable offset 6; import is local to avoid cycles.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.ModuleState import ModuleState
            obj = ModuleState()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroups(self, j):
        # j-th OptimizerGroup sub-table in the vector at vtable offset 8.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.OptimizerGroup import OptimizerGroup
            obj = OptimizerGroup()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroupsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Checkpoint
    def OptimizerGroupsIsNone(self):
        # True when the vector field is absent (distinguishes absent from empty).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

    # Checkpoint
    def PropertyBag(self):
        # PropertyBag sub-table at vtable offset 10.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.PropertyBag import PropertyBag
            obj = PropertyBag()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases.

def CheckpointStart(builder):
    # The table declares 4 fields.
    builder.StartObject(4)


def Start(builder):
    CheckpointStart(builder)


def CheckpointAddVersion(builder, version):
    builder.PrependInt32Slot(0, version, 0)


def AddVersion(builder, version):
    CheckpointAddVersion(builder, version)


def CheckpointAddModuleState(builder, moduleState):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)


def AddModuleState(builder, moduleState):
    CheckpointAddModuleState(builder, moduleState)


def CheckpointAddOptimizerGroups(builder, optimizerGroups):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)


def AddOptimizerGroups(builder, optimizerGroups):
    CheckpointAddOptimizerGroups(builder, optimizerGroups)


def CheckpointStartOptimizerGroupsVector(builder, numElems):
    # Vector of table offsets: 4-byte elements, 4-byte alignment.
    return builder.StartVector(4, numElems, 4)


def StartOptimizerGroupsVector(builder, numElems: int) -> int:
    return CheckpointStartOptimizerGroupsVector(builder, numElems)


def CheckpointAddPropertyBag(builder, propertyBag):
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)


def AddPropertyBag(builder, propertyBag):
    CheckpointAddPropertyBag(builder, propertyBag)


def CheckpointEnd(builder):
    return builder.EndObject()


def End(builder):
    return CheckpointEnd(builder)
|
||||
@@ -0,0 +1,120 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# deprecated: no longer using kernel def hashes
|
||||
# deprecated: no longer using kernel def hashes
class DeprecatedKernelCreateInfos(object):
    """flatc-generated read-only view of a DeprecatedKernelCreateInfos table.

    Kept only for reading old ORT format files: parallel vectors of node
    indices (uint32) and kernel def hashes (uint64).
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DeprecatedKernelCreateInfos()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedKernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # DeprecatedKernelCreateInfos
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedKernelCreateInfos
    def NodeIndices(self, j):
        # j-th element of the uint32 vector at vtable offset 4 (4-byte stride).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0

    # DeprecatedKernelCreateInfos
    def NodeIndicesAsNumpy(self):
        # Whole vector as a numpy array (zero-copy view); 0 when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0

    # DeprecatedKernelCreateInfos
    def NodeIndicesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # DeprecatedKernelCreateInfos
    def NodeIndicesIsNone(self):
        # True when the vector field is absent (distinguishes absent from empty).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0

    # DeprecatedKernelCreateInfos
    def KernelDefHashes(self, j):
        # j-th element of the uint64 vector at vtable offset 6 (8-byte stride).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint64Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
        return 0

    # DeprecatedKernelCreateInfos
    def KernelDefHashesAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint64Flags, o)
        return 0

    # DeprecatedKernelCreateInfos
    def KernelDefHashesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # DeprecatedKernelCreateInfos
    def KernelDefHashesIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases.

def DeprecatedKernelCreateInfosStart(builder):
    # The table declares 2 fields.
    builder.StartObject(2)


def Start(builder):
    DeprecatedKernelCreateInfosStart(builder)


def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)


def AddNodeIndices(builder, nodeIndices):
    DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices)


def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems):
    # uint32 vector: 4-byte elements, 4-byte alignment.
    return builder.StartVector(4, numElems, 4)


def StartNodeIndicesVector(builder, numElems: int) -> int:
    return DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems)


def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes), 0)


def AddKernelDefHashes(builder, kernelDefHashes):
    DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes)


def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems):
    # uint64 vector: 8-byte elements, 8-byte alignment.
    return builder.StartVector(8, numElems, 8)


def StartKernelDefHashesVector(builder, numElems: int) -> int:
    return DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems)


def DeprecatedKernelCreateInfosEnd(builder):
    return builder.EndObject()


def End(builder):
    return DeprecatedKernelCreateInfosEnd(builder)
|
||||
@@ -0,0 +1,68 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# deprecated: no longer using kernel def hashes
|
||||
# deprecated: no longer using kernel def hashes
class DeprecatedNodeIndexAndKernelDefHash(object):
    """flatc-generated read-only view pairing a node index with a kernel def hash.

    Kept only for reading old ORT format files.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DeprecatedNodeIndexAndKernelDefHash()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDeprecatedNodeIndexAndKernelDefHash(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedNodeIndexAndKernelDefHashBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # DeprecatedNodeIndexAndKernelDefHash
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedNodeIndexAndKernelDefHash
    def NodeIndex(self):
        # uint32 scalar at vtable offset 4; default 0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
        return 0

    # DeprecatedNodeIndexAndKernelDefHash
    def KernelDefHash(self):
        # uint64 scalar at vtable offset 6; default 0.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
        return 0


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases.

def DeprecatedNodeIndexAndKernelDefHashStart(builder):
    # The table declares 2 fields.
    builder.StartObject(2)


def Start(builder):
    DeprecatedNodeIndexAndKernelDefHashStart(builder)


def DeprecatedNodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex):
    builder.PrependUint32Slot(0, nodeIndex, 0)


def AddNodeIndex(builder, nodeIndex):
    DeprecatedNodeIndexAndKernelDefHashAddNodeIndex(builder, nodeIndex)


def DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash):
    builder.PrependUint64Slot(1, kernelDefHash, 0)


def AddKernelDefHash(builder, kernelDefHash):
    DeprecatedNodeIndexAndKernelDefHashAddKernelDefHash(builder, kernelDefHash)


def DeprecatedNodeIndexAndKernelDefHashEnd(builder):
    return builder.EndObject()


def End(builder):
    return DeprecatedNodeIndexAndKernelDefHashEnd(builder)
|
||||
@@ -0,0 +1,96 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# deprecated: no longer using kernel def hashes
|
||||
# deprecated: no longer using kernel def hashes
class DeprecatedSessionState(object):
    """flatc-generated read-only view of a DeprecatedSessionState table.

    Kept only for reading old ORT format files: kernel create infos plus
    session states for nested subgraphs.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DeprecatedSessionState()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDeprecatedSessionState(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # DeprecatedSessionState
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedSessionState
    def Kernels(self):
        # DeprecatedKernelCreateInfos sub-table at vtable offset 4;
        # import is local to avoid cycles.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.DeprecatedKernelCreateInfos import DeprecatedKernelCreateInfos
            obj = DeprecatedKernelCreateInfos()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # DeprecatedSessionState
    def SubGraphSessionStates(self, j):
        # j-th DeprecatedSubGraphSessionState sub-table in the vector at offset 6.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.DeprecatedSubGraphSessionState import DeprecatedSubGraphSessionState
            obj = DeprecatedSubGraphSessionState()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # DeprecatedSessionState
    def SubGraphSessionStatesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # DeprecatedSessionState
    def SubGraphSessionStatesIsNone(self):
        # True when the vector field is absent (distinguishes absent from empty).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases.

def DeprecatedSessionStateStart(builder):
    # The table declares 2 fields.
    builder.StartObject(2)


def Start(builder):
    DeprecatedSessionStateStart(builder)


def DeprecatedSessionStateAddKernels(builder, kernels):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernels), 0)


def AddKernels(builder, kernels):
    DeprecatedSessionStateAddKernels(builder, kernels)


def DeprecatedSessionStateAddSubGraphSessionStates(builder, subGraphSessionStates):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(subGraphSessionStates), 0)


def AddSubGraphSessionStates(builder, subGraphSessionStates):
    DeprecatedSessionStateAddSubGraphSessionStates(builder, subGraphSessionStates)


def DeprecatedSessionStateStartSubGraphSessionStatesVector(builder, numElems):
    # Vector of table offsets: 4-byte elements, 4-byte alignment.
    return builder.StartVector(4, numElems, 4)


def StartSubGraphSessionStatesVector(builder, numElems: int) -> int:
    return DeprecatedSessionStateStartSubGraphSessionStatesVector(builder, numElems)


def DeprecatedSessionStateEnd(builder):
    return builder.EndObject()


def End(builder):
    return DeprecatedSessionStateEnd(builder)
|
||||
@@ -0,0 +1,72 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# deprecated: no longer using kernel def hashes
|
||||
# deprecated: no longer using kernel def hashes
class DeprecatedSubGraphSessionState(object):
    """flatc-generated read-only view mapping a subgraph id to its session state.

    Kept only for reading old ORT format files.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of the buffer starting at `offset`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DeprecatedSubGraphSessionState()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDeprecatedSubGraphSessionState(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedSubGraphSessionStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x4F\x52\x54\x4D" is the 4-byte file identifier "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # DeprecatedSubGraphSessionState
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # DeprecatedSubGraphSessionState
    def GraphId(self):
        # string at vtable offset 4; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # DeprecatedSubGraphSessionState
    def SessionState(self):
        # DeprecatedSessionState sub-table at vtable offset 6 (recursive nesting);
        # import is local to avoid cycles.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.DeprecatedSessionState import DeprecatedSessionState
            obj = DeprecatedSessionState()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None


# Module-level builder helpers (flatc convention): long-named functions do the
# work; short names are newer-style aliases.

def DeprecatedSubGraphSessionStateStart(builder):
    # The table declares 2 fields.
    builder.StartObject(2)


def Start(builder):
    DeprecatedSubGraphSessionStateStart(builder)


def DeprecatedSubGraphSessionStateAddGraphId(builder, graphId):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(graphId), 0)


def AddGraphId(builder, graphId):
    DeprecatedSubGraphSessionStateAddGraphId(builder, graphId)


def DeprecatedSubGraphSessionStateAddSessionState(builder, sessionState):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(sessionState), 0)


def AddSessionState(builder, sessionState):
    DeprecatedSubGraphSessionStateAddSessionState(builder, sessionState)


def DeprecatedSubGraphSessionStateEnd(builder):
    return builder.EndObject()


def End(builder):
    return DeprecatedSubGraphSessionStateEnd(builder)
|
||||
@@ -0,0 +1,71 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Dimension(object):
    """FlatBuffers reader for the 'Dimension' table (namespace: fbs).

    Wraps a flatbuffers.table.Table positioned at a Dimension record and
    exposes typed accessors for its two fields.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a Dimension wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Dimension()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDimension(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DimensionBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        # b"\x4F\x52\x54\x4D" decodes to "ORTM".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # Dimension
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # Dimension
    def Value(self):
        """Return the nested DimensionValue table (vtable slot 4), or None if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            # Local import: generated modules import siblings lazily.
            from ort_flatbuffers_py.fbs.DimensionValue import DimensionValue
            obj = DimensionValue()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Dimension
    def Denotation(self):
        """Return the denotation string (vtable slot 6), or None if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
|
||||
|
||||
def DimensionStart(builder):
    """Begin a Dimension table (2 fields)."""
    builder.StartObject(2)


def Start(builder):
    """Namespace-free alias for DimensionStart."""
    DimensionStart(builder)


def DimensionAddValue(builder, value):
    """Store the value table offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(value)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddValue(builder, value):
    """Namespace-free alias for DimensionAddValue."""
    DimensionAddValue(builder, value)


def DimensionAddDenotation(builder, denotation):
    """Store the denotation string offset in slot 1."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(denotation)
    builder.PrependUOffsetTRelativeSlot(1, off, 0)


def AddDenotation(builder, denotation):
    """Namespace-free alias for DimensionAddDenotation."""
    DimensionAddDenotation(builder, denotation)


def DimensionEnd(builder):
    """Finish the Dimension table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for DimensionEnd."""
    return DimensionEnd(builder)
|
||||
@@ -0,0 +1,80 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class DimensionValue(object):
    """FlatBuffers reader for the 'DimensionValue' table (namespace: fbs).

    Holds a dimension either as a concrete int64 value (DimValue) or as a
    named string parameter (DimParam), discriminated by DimType.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a DimensionValue wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DimensionValue()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDimensionValue(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DimensionValueBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # DimensionValue
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # DimensionValue
    def DimType(self):
        """Return the int8 dimension kind (vtable slot 4); 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # DimensionValue
    def DimValue(self):
        """Return the int64 concrete dimension value (vtable slot 6); 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
        return 0

    # DimensionValue
    def DimParam(self):
        """Return the symbolic dimension name string (vtable slot 8), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None
|
||||
|
||||
def DimensionValueStart(builder):
    """Begin a DimensionValue table (3 fields)."""
    builder.StartObject(3)


def Start(builder):
    """Namespace-free alias for DimensionValueStart."""
    DimensionValueStart(builder)


def DimensionValueAddDimType(builder, dimType):
    """Store the int8 dimension kind in slot 0 (default 0)."""
    builder.PrependInt8Slot(0, dimType, 0)


def AddDimType(builder, dimType):
    """Namespace-free alias for DimensionValueAddDimType."""
    DimensionValueAddDimType(builder, dimType)


def DimensionValueAddDimValue(builder, dimValue):
    """Store the int64 dimension value in slot 1 (default 0)."""
    builder.PrependInt64Slot(1, dimValue, 0)


def AddDimValue(builder, dimValue):
    """Namespace-free alias for DimensionValueAddDimValue."""
    DimensionValueAddDimValue(builder, dimValue)


def DimensionValueAddDimParam(builder, dimParam):
    """Store the dimParam string offset in slot 2."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(dimParam)
    builder.PrependUOffsetTRelativeSlot(2, off, 0)


def AddDimParam(builder, dimParam):
    """Namespace-free alias for DimensionValueAddDimParam."""
    DimensionValueAddDimParam(builder, dimParam)


def DimensionValueEnd(builder):
    """Finish the DimensionValue table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for DimensionValueEnd."""
    return DimensionValueEnd(builder)
|
||||
@@ -0,0 +1,8 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class DimensionValueType(object):
    """Enum-like constants describing how a DimensionValue is expressed.

    VALUE corresponds to the DimValue (int64) field and PARAM to the
    DimParam (string) field of the DimensionValue table.
    """
    UNKNOWN, VALUE, PARAM = 0, 1, 2
|
||||
@@ -0,0 +1,32 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class EdgeEnd(object):
    """FlatBuffers reader for the fixed-size 'EdgeEnd' struct (namespace: fbs).

    Layout (12 bytes): uint32 node_index at +0, int32 src_arg_index at +4,
    int32 dst_arg_index at +8.
    """
    __slots__ = ['_tab']

    @classmethod
    def SizeOf(cls):
        """Return the struct size in bytes."""
        return 12

    # EdgeEnd
    def Init(self, buf, pos):
        """Attach this reader to *buf* at struct position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # EdgeEnd
    def NodeIndex(self): return self._tab.Get(flatbuffers.number_types.Uint32Flags, self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(0))
    # EdgeEnd
    def SrcArgIndex(self): return self._tab.Get(flatbuffers.number_types.Int32Flags, self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(4))
    # EdgeEnd
    def DstArgIndex(self): return self._tab.Get(flatbuffers.number_types.Int32Flags, self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(8))
|
||||
|
||||
def CreateEdgeEnd(builder, nodeIndex, srcArgIndex, dstArgIndex):
    """Write an EdgeEnd struct (12 bytes, 4-byte aligned) and return its offset.

    Struct fields are prepended in reverse declaration order, as the
    FlatBuffers builder writes back-to-front.
    """
    builder.Prep(4, 12)
    for prepend, field in (
        (builder.PrependInt32, dstArgIndex),
        (builder.PrependInt32, srcArgIndex),
        (builder.PrependUint32, nodeIndex),
    ):
        prepend(field)
    return builder.Offset()
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class FloatProperty(object):
    """FlatBuffers reader for the 'FloatProperty' table: a named float32 value."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a FloatProperty wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = FloatProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsFloatProperty(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def FloatPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ODTC".

        NOTE: this identifier differs from the b"ORTM" used by the model
        tables in this package.
        """
        # b"\x4F\x44\x54\x43" decodes to "ODTC".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # FloatProperty
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # FloatProperty
    def Name(self):
        """Return the property name string (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # FloatProperty
    def Value(self):
        """Return the float32 value (vtable slot 6); 0.0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0
|
||||
|
||||
def FloatPropertyStart(builder):
    """Begin a FloatProperty table (2 fields)."""
    builder.StartObject(2)


def Start(builder):
    """Namespace-free alias for FloatPropertyStart."""
    FloatPropertyStart(builder)


def FloatPropertyAddName(builder, name):
    """Store the name string offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(name)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddName(builder, name):
    """Namespace-free alias for FloatPropertyAddName."""
    FloatPropertyAddName(builder, name)


def FloatPropertyAddValue(builder, value):
    """Store the float32 value in slot 1 (default 0.0)."""
    builder.PrependFloat32Slot(1, value, 0.0)


def AddValue(builder, value):
    """Namespace-free alias for FloatPropertyAddValue."""
    FloatPropertyAddValue(builder, value)


def FloatPropertyEnd(builder):
    """Finish the FloatProperty table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for FloatPropertyEnd."""
    return FloatPropertyEnd(builder)
|
||||
@@ -0,0 +1,320 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Graph(object):
    """FlatBuffers reader for the 'Graph' table (namespace: fbs).

    Exposes vectors of initializers (Tensor), node args (ValueInfo), nodes
    (Node), node edges (NodeEdge), input/output names (strings), sparse
    initializers (SparseTensor), plus a scalar max node index and an
    optional RuntimeOptimizations sub-table.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a Graph wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Graph()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsGraph(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GraphBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # Graph
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # Graph
    def Initializers(self, j):
        """Return the j-th initializer Tensor (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Vector(o)
            # Vector of table offsets: 4 bytes per element, then indirect.
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.Tensor import Tensor
            obj = Tensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Graph
    def InitializersLength(self):
        """Return the number of initializers; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def InitializersIsNone(self):
        """Return True if the initializers vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0

    # Graph
    def NodeArgs(self, j):
        """Return the j-th ValueInfo node arg (vtable slot 6), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.ValueInfo import ValueInfo
            obj = ValueInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Graph
    def NodeArgsLength(self):
        """Return the number of node args; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def NodeArgsIsNone(self):
        """Return True if the node args vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

    # Graph
    def Nodes(self, j):
        """Return the j-th Node (vtable slot 8), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.Node import Node
            obj = Node()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Graph
    def NodesLength(self):
        """Return the number of nodes; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def NodesIsNone(self):
        """Return True if the nodes vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

    # Graph
    def MaxNodeIndex(self):
        """Return the uint32 max node index (vtable slot 10); 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
        return 0

    # Graph
    def NodeEdges(self, j):
        """Return the j-th NodeEdge (vtable slot 12), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.NodeEdge import NodeEdge
            obj = NodeEdge()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Graph
    def NodeEdgesLength(self):
        """Return the number of node edges; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def NodeEdgesIsNone(self):
        """Return True if the node edges vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        return o == 0

    # Graph
    def Inputs(self, j):
        """Return the j-th graph input name string (vtable slot 14); "" if vector absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # Graph
    def InputsLength(self):
        """Return the number of graph inputs; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def InputsIsNone(self):
        """Return True if the inputs vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        return o == 0

    # Graph
    def Outputs(self, j):
        """Return the j-th graph output name string (vtable slot 16); "" if vector absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # Graph
    def OutputsLength(self):
        """Return the number of graph outputs; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def OutputsIsNone(self):
        """Return True if the outputs vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        return o == 0

    # Graph
    def SparseInitializers(self, j):
        """Return the j-th SparseTensor initializer (vtable slot 18), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.SparseTensor import SparseTensor
            obj = SparseTensor()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Graph
    def SparseInitializersLength(self):
        """Return the number of sparse initializers; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Graph
    def SparseInitializersIsNone(self):
        """Return True if the sparse initializers vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        return o == 0

    # Graph
    def RuntimeOptimizations(self):
        """Return the RuntimeOptimizations sub-table (vtable slot 20), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.RuntimeOptimizations import RuntimeOptimizations
            obj = RuntimeOptimizations()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
|
||||
|
||||
def GraphStart(builder):
    """Begin a Graph table (9 fields)."""
    builder.StartObject(9)


def Start(builder):
    """Namespace-free alias for GraphStart."""
    GraphStart(builder)


def GraphAddInitializers(builder, initializers):
    """Store the initializers vector offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(initializers)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddInitializers(builder, initializers):
    """Namespace-free alias for GraphAddInitializers."""
    GraphAddInitializers(builder, initializers)


def GraphStartInitializersVector(builder, numElems):
    """Begin the initializers vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartInitializersVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartInitializersVector."""
    return GraphStartInitializersVector(builder, numElems)


def GraphAddNodeArgs(builder, nodeArgs):
    """Store the node args vector offset in slot 1."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(nodeArgs)
    builder.PrependUOffsetTRelativeSlot(1, off, 0)


def AddNodeArgs(builder, nodeArgs):
    """Namespace-free alias for GraphAddNodeArgs."""
    GraphAddNodeArgs(builder, nodeArgs)


def GraphStartNodeArgsVector(builder, numElems):
    """Begin the node args vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartNodeArgsVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartNodeArgsVector."""
    return GraphStartNodeArgsVector(builder, numElems)


def GraphAddNodes(builder, nodes):
    """Store the nodes vector offset in slot 2."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(nodes)
    builder.PrependUOffsetTRelativeSlot(2, off, 0)


def AddNodes(builder, nodes):
    """Namespace-free alias for GraphAddNodes."""
    GraphAddNodes(builder, nodes)


def GraphStartNodesVector(builder, numElems):
    """Begin the nodes vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartNodesVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartNodesVector."""
    return GraphStartNodesVector(builder, numElems)


def GraphAddMaxNodeIndex(builder, maxNodeIndex):
    """Store the uint32 max node index in slot 3 (default 0)."""
    builder.PrependUint32Slot(3, maxNodeIndex, 0)


def AddMaxNodeIndex(builder, maxNodeIndex):
    """Namespace-free alias for GraphAddMaxNodeIndex."""
    GraphAddMaxNodeIndex(builder, maxNodeIndex)


def GraphAddNodeEdges(builder, nodeEdges):
    """Store the node edges vector offset in slot 4."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(nodeEdges)
    builder.PrependUOffsetTRelativeSlot(4, off, 0)


def AddNodeEdges(builder, nodeEdges):
    """Namespace-free alias for GraphAddNodeEdges."""
    GraphAddNodeEdges(builder, nodeEdges)


def GraphStartNodeEdgesVector(builder, numElems):
    """Begin the node edges vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartNodeEdgesVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartNodeEdgesVector."""
    return GraphStartNodeEdgesVector(builder, numElems)


def GraphAddInputs(builder, inputs):
    """Store the inputs (string vector) offset in slot 5."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(inputs)
    builder.PrependUOffsetTRelativeSlot(5, off, 0)


def AddInputs(builder, inputs):
    """Namespace-free alias for GraphAddInputs."""
    GraphAddInputs(builder, inputs)


def GraphStartInputsVector(builder, numElems):
    """Begin the inputs vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartInputsVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartInputsVector."""
    return GraphStartInputsVector(builder, numElems)


def GraphAddOutputs(builder, outputs):
    """Store the outputs (string vector) offset in slot 6."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(outputs)
    builder.PrependUOffsetTRelativeSlot(6, off, 0)


def AddOutputs(builder, outputs):
    """Namespace-free alias for GraphAddOutputs."""
    GraphAddOutputs(builder, outputs)


def GraphStartOutputsVector(builder, numElems):
    """Begin the outputs vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartOutputsVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartOutputsVector."""
    return GraphStartOutputsVector(builder, numElems)


def GraphAddSparseInitializers(builder, sparseInitializers):
    """Store the sparse initializers vector offset in slot 7."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(sparseInitializers)
    builder.PrependUOffsetTRelativeSlot(7, off, 0)


def AddSparseInitializers(builder, sparseInitializers):
    """Namespace-free alias for GraphAddSparseInitializers."""
    GraphAddSparseInitializers(builder, sparseInitializers)


def GraphStartSparseInitializersVector(builder, numElems):
    """Begin the sparse initializers vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartSparseInitializersVector(builder, numElems: int) -> int:
    """Namespace-free alias for GraphStartSparseInitializersVector."""
    return GraphStartSparseInitializersVector(builder, numElems)


def GraphAddRuntimeOptimizations(builder, runtimeOptimizations):
    """Store the RuntimeOptimizations table offset in slot 8."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(runtimeOptimizations)
    builder.PrependUOffsetTRelativeSlot(8, off, 0)


def AddRuntimeOptimizations(builder, runtimeOptimizations):
    """Namespace-free alias for GraphAddRuntimeOptimizations."""
    GraphAddRuntimeOptimizations(builder, runtimeOptimizations)


def GraphEnd(builder):
    """Finish the Graph table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for GraphEnd."""
    return GraphEnd(builder)
|
||||
@@ -0,0 +1,88 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class InferenceSession(object):
    """FlatBuffers reader for the 'InferenceSession' root table (namespace: fbs).

    Carries the producing ORT version string, the serialized Model, and an
    optional KernelTypeStrResolver.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an InferenceSession wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = InferenceSession()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsInferenceSession(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def InferenceSessionBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # InferenceSession
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # InferenceSession
    def OrtVersion(self):
        """Return the ORT version string (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # InferenceSession
    def Model(self):
        """Return the nested Model table (vtable slot 6), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.Model import Model
            obj = Model()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # InferenceSession
    def KernelTypeStrResolver(self):
        """Return the KernelTypeStrResolver table (vtable slot 10), or None.

        NOTE: vtable slot 8 (field index 2) has no accessor here — the table
        declares 4 fields (see InferenceSessionStart) but only three are read.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.KernelTypeStrResolver import KernelTypeStrResolver
            obj = KernelTypeStrResolver()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
|
||||
|
||||
def InferenceSessionStart(builder):
    """Begin an InferenceSession table (4 fields; field index 2 is unused here)."""
    builder.StartObject(4)


def Start(builder):
    """Namespace-free alias for InferenceSessionStart."""
    InferenceSessionStart(builder)


def InferenceSessionAddOrtVersion(builder, ortVersion):
    """Store the ORT version string offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(ortVersion)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddOrtVersion(builder, ortVersion):
    """Namespace-free alias for InferenceSessionAddOrtVersion."""
    InferenceSessionAddOrtVersion(builder, ortVersion)


def InferenceSessionAddModel(builder, model):
    """Store the Model table offset in slot 1."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(model)
    builder.PrependUOffsetTRelativeSlot(1, off, 0)


def AddModel(builder, model):
    """Namespace-free alias for InferenceSessionAddModel."""
    InferenceSessionAddModel(builder, model)


def InferenceSessionAddKernelTypeStrResolver(builder, kernelTypeStrResolver):
    """Store the KernelTypeStrResolver table offset in slot 3."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStrResolver)
    builder.PrependUOffsetTRelativeSlot(3, off, 0)


def AddKernelTypeStrResolver(builder, kernelTypeStrResolver):
    """Namespace-free alias for InferenceSessionAddKernelTypeStrResolver."""
    InferenceSessionAddKernelTypeStrResolver(builder, kernelTypeStrResolver)


def InferenceSessionEnd(builder):
    """Finish the InferenceSession table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for InferenceSessionEnd."""
    return InferenceSessionEnd(builder)
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class IntProperty(object):
    """FlatBuffers reader for the 'IntProperty' table: a named int64 value."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an IntProperty wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = IntProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsIntProperty(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def IntPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ODTC".

        NOTE: same non-"ORTM" identifier as FloatProperty in this package.
        """
        # b"\x4F\x44\x54\x43" decodes to "ODTC".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # IntProperty
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # IntProperty
    def Name(self):
        """Return the property name string (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # IntProperty
    def Value(self):
        """Return the int64 value (vtable slot 6); 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
        return 0
|
||||
|
||||
def IntPropertyStart(builder):
    """Begin an IntProperty table (2 fields)."""
    builder.StartObject(2)


def Start(builder):
    """Namespace-free alias for IntPropertyStart."""
    IntPropertyStart(builder)


def IntPropertyAddName(builder, name):
    """Store the name string offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(name)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddName(builder, name):
    """Namespace-free alias for IntPropertyAddName."""
    IntPropertyAddName(builder, name)


def IntPropertyAddValue(builder, value):
    """Store the int64 value in slot 1 (default 0)."""
    builder.PrependInt64Slot(1, value, 0)


def AddValue(builder, value):
    """Namespace-free alias for IntPropertyAddValue."""
    IntPropertyAddValue(builder, value)


def IntPropertyEnd(builder):
    """Finish the IntProperty table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for IntPropertyEnd."""
    return IntPropertyEnd(builder)
|
||||
@@ -0,0 +1,91 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class KernelTypeStrArgsEntry(object):
    """FlatBuffers reader for 'KernelTypeStrArgsEntry': a kernel type string
    mapped to a vector of ArgTypeAndIndex entries."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a KernelTypeStrArgsEntry wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = KernelTypeStrArgsEntry()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsKernelTypeStrArgsEntry(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def KernelTypeStrArgsEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # KernelTypeStrArgsEntry
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # KernelTypeStrArgsEntry
    def KernelTypeStr(self):
        """Return the kernel type string (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # KernelTypeStrArgsEntry
    def Args(self, j):
        """Return the j-th ArgTypeAndIndex entry (vtable slot 6), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            # Vector of table offsets: 4 bytes per element, then indirect.
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.ArgTypeAndIndex import ArgTypeAndIndex
            obj = ArgTypeAndIndex()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # KernelTypeStrArgsEntry
    def ArgsLength(self):
        """Return the number of args; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # KernelTypeStrArgsEntry
    def ArgsIsNone(self):
        """Return True if the args vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0
|
||||
|
||||
def KernelTypeStrArgsEntryStart(builder):
    """Begin a KernelTypeStrArgsEntry table (2 fields)."""
    builder.StartObject(2)


def Start(builder):
    """Namespace-free alias for KernelTypeStrArgsEntryStart."""
    KernelTypeStrArgsEntryStart(builder)


def KernelTypeStrArgsEntryAddKernelTypeStr(builder, kernelTypeStr):
    """Store the kernel type string offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStr)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddKernelTypeStr(builder, kernelTypeStr):
    """Namespace-free alias for KernelTypeStrArgsEntryAddKernelTypeStr."""
    KernelTypeStrArgsEntryAddKernelTypeStr(builder, kernelTypeStr)


def KernelTypeStrArgsEntryAddArgs(builder, args):
    """Store the args vector offset in slot 1."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(args)
    builder.PrependUOffsetTRelativeSlot(1, off, 0)


def AddArgs(builder, args):
    """Namespace-free alias for KernelTypeStrArgsEntryAddArgs."""
    KernelTypeStrArgsEntryAddArgs(builder, args)


def KernelTypeStrArgsEntryStartArgsVector(builder, numElems):
    """Begin the args vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartArgsVector(builder, numElems: int) -> int:
    """Namespace-free alias for KernelTypeStrArgsEntryStartArgsVector."""
    return KernelTypeStrArgsEntryStartArgsVector(builder, numElems)


def KernelTypeStrArgsEntryEnd(builder):
    """Finish the KernelTypeStrArgsEntry table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for KernelTypeStrArgsEntryEnd."""
    return KernelTypeStrArgsEntryEnd(builder)
|
||||
@@ -0,0 +1,78 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class KernelTypeStrResolver(object):
    """FlatBuffers reader for 'KernelTypeStrResolver': a vector of
    OpIdKernelTypeStrArgsEntry records."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a KernelTypeStrResolver wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = KernelTypeStrResolver()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsKernelTypeStrResolver(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def KernelTypeStrResolverBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # KernelTypeStrResolver
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # KernelTypeStrResolver
    def OpKernelTypeStrArgs(self, j):
        """Return the j-th OpIdKernelTypeStrArgsEntry (vtable slot 4), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            x = self._tab.Vector(o)
            # Vector of table offsets: 4 bytes per element, then indirect.
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.OpIdKernelTypeStrArgsEntry import OpIdKernelTypeStrArgsEntry
            obj = OpIdKernelTypeStrArgsEntry()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # KernelTypeStrResolver
    def OpKernelTypeStrArgsLength(self):
        """Return the number of entries; 0 if the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # KernelTypeStrResolver
    def OpKernelTypeStrArgsIsNone(self):
        """Return True if the entries vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0
|
||||
|
||||
def KernelTypeStrResolverStart(builder):
    """Begin a KernelTypeStrResolver table (1 field)."""
    builder.StartObject(1)


def Start(builder):
    """Namespace-free alias for KernelTypeStrResolverStart."""
    KernelTypeStrResolverStart(builder)


def KernelTypeStrResolverAddOpKernelTypeStrArgs(builder, opKernelTypeStrArgs):
    """Store the entries vector offset in slot 0."""
    off = flatbuffers.number_types.UOffsetTFlags.py_type(opKernelTypeStrArgs)
    builder.PrependUOffsetTRelativeSlot(0, off, 0)


def AddOpKernelTypeStrArgs(builder, opKernelTypeStrArgs):
    """Namespace-free alias for KernelTypeStrResolverAddOpKernelTypeStrArgs."""
    KernelTypeStrResolverAddOpKernelTypeStrArgs(builder, opKernelTypeStrArgs)


def KernelTypeStrResolverStartOpKernelTypeStrArgsVector(builder, numElems):
    """Begin the entries vector (4-byte offsets, 4-byte aligned)."""
    return builder.StartVector(4, numElems, 4)


def StartOpKernelTypeStrArgsVector(builder, numElems: int) -> int:
    """Namespace-free alias for KernelTypeStrResolverStartOpKernelTypeStrArgsVector."""
    return KernelTypeStrResolverStartOpKernelTypeStrArgsVector(builder, numElems)


def KernelTypeStrResolverEnd(builder):
    """Finish the KernelTypeStrResolver table and return its offset."""
    return builder.EndObject()


def End(builder):
    """Namespace-free alias for KernelTypeStrResolverEnd."""
    return KernelTypeStrResolverEnd(builder)
|
||||
@@ -0,0 +1,71 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class MapType(object):
    """FlatBuffers reader for the 'MapType' table: an int32 key type code and
    a TypeInfo value type."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a MapType wrapping the root table of *buf* at *offset*."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = MapType()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsMapType(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MapTypeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if *buf* carries the file identifier b"ORTM"."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)

    # MapType
    def Init(self, buf, pos):
        """Attach this reader to *buf* at table position *pos*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # MapType
    def KeyType(self):
        """Return the int32 key type code (vtable slot 4); 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # MapType
    def ValueType(self):
        """Return the nested TypeInfo table (vtable slot 6), or None."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
            obj = TypeInfo()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
|
||||
|
||||
def MapTypeStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
MapTypeStart(builder)
|
||||
|
||||
def MapTypeAddKeyType(builder, keyType):
|
||||
builder.PrependInt32Slot(0, keyType, 0)
|
||||
|
||||
def AddKeyType(builder, keyType):
|
||||
MapTypeAddKeyType(builder, keyType)
|
||||
|
||||
def MapTypeAddValueType(builder, valueType):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(valueType), 0)
|
||||
|
||||
def AddValueType(builder, valueType):
|
||||
MapTypeAddValueType(builder, valueType)
|
||||
|
||||
def MapTypeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return MapTypeEnd(builder)
|
||||
@@ -0,0 +1,223 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Model(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = Model()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsModel(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def ModelBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# Model
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# Model
|
||||
def IrVersion(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Model
|
||||
def OpsetImport(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.OperatorSetId import OperatorSetId
|
||||
obj = OperatorSetId()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# Model
|
||||
def OpsetImportLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Model
|
||||
def OpsetImportIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
# Model
|
||||
def ProducerName(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Model
|
||||
def ProducerVersion(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Model
|
||||
def Domain(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Model
|
||||
def ModelVersion(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Model
|
||||
def DocString(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Model
|
||||
def Graph(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.Graph import Graph
|
||||
obj = Graph()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# Model
|
||||
def GraphDocString(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Model
|
||||
def MetadataProps(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.StringStringEntry import StringStringEntry
|
||||
obj = StringStringEntry()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# Model
|
||||
def MetadataPropsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Model
|
||||
def MetadataPropsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
return o == 0
|
||||
|
||||
def ModelStart(builder):
|
||||
builder.StartObject(10)
|
||||
|
||||
def Start(builder):
|
||||
ModelStart(builder)
|
||||
|
||||
def ModelAddIrVersion(builder, irVersion):
|
||||
builder.PrependInt64Slot(0, irVersion, 0)
|
||||
|
||||
def AddIrVersion(builder, irVersion):
|
||||
ModelAddIrVersion(builder, irVersion)
|
||||
|
||||
def ModelAddOpsetImport(builder, opsetImport):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(opsetImport), 0)
|
||||
|
||||
def AddOpsetImport(builder, opsetImport):
|
||||
ModelAddOpsetImport(builder, opsetImport)
|
||||
|
||||
def ModelStartOpsetImportVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartOpsetImportVector(builder, numElems: int) -> int:
|
||||
return ModelStartOpsetImportVector(builder, numElems)
|
||||
|
||||
def ModelAddProducerName(builder, producerName):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(producerName), 0)
|
||||
|
||||
def AddProducerName(builder, producerName):
|
||||
ModelAddProducerName(builder, producerName)
|
||||
|
||||
def ModelAddProducerVersion(builder, producerVersion):
|
||||
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(producerVersion), 0)
|
||||
|
||||
def AddProducerVersion(builder, producerVersion):
|
||||
ModelAddProducerVersion(builder, producerVersion)
|
||||
|
||||
def ModelAddDomain(builder, domain):
|
||||
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(domain), 0)
|
||||
|
||||
def AddDomain(builder, domain):
|
||||
ModelAddDomain(builder, domain)
|
||||
|
||||
def ModelAddModelVersion(builder, modelVersion):
|
||||
builder.PrependInt64Slot(5, modelVersion, 0)
|
||||
|
||||
def AddModelVersion(builder, modelVersion):
|
||||
ModelAddModelVersion(builder, modelVersion)
|
||||
|
||||
def ModelAddDocString(builder, docString):
|
||||
builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
|
||||
|
||||
def AddDocString(builder, docString):
|
||||
ModelAddDocString(builder, docString)
|
||||
|
||||
def ModelAddGraph(builder, graph):
|
||||
builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(graph), 0)
|
||||
|
||||
def AddGraph(builder, graph):
|
||||
ModelAddGraph(builder, graph)
|
||||
|
||||
def ModelAddGraphDocString(builder, graphDocString):
|
||||
builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(graphDocString), 0)
|
||||
|
||||
def AddGraphDocString(builder, graphDocString):
|
||||
ModelAddGraphDocString(builder, graphDocString)
|
||||
|
||||
def ModelAddMetadataProps(builder, metadataProps):
|
||||
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(metadataProps), 0)
|
||||
|
||||
def AddMetadataProps(builder, metadataProps):
|
||||
ModelAddMetadataProps(builder, metadataProps)
|
||||
|
||||
def ModelStartMetadataPropsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartMetadataPropsVector(builder, numElems: int) -> int:
|
||||
return ModelStartMetadataPropsVector(builder, numElems)
|
||||
|
||||
def ModelEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return ModelEnd(builder)
|
||||
@@ -0,0 +1,141 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class ModuleState(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = ModuleState()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsModuleState(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def ModuleStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)
|
||||
|
||||
# ModuleState
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# ModuleState
|
||||
def RequiresGradParams(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.Tensor import Tensor
|
||||
obj = Tensor()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# ModuleState
|
||||
def RequiresGradParamsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# ModuleState
|
||||
def RequiresGradParamsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
return o == 0
|
||||
|
||||
# ModuleState
|
||||
def FrozenParams(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.Tensor import Tensor
|
||||
obj = Tensor()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# ModuleState
|
||||
def FrozenParamsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# ModuleState
|
||||
def FrozenParamsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
# ModuleState
|
||||
def IsNominalState(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
||||
return False
|
||||
|
||||
# ModuleState
|
||||
def HasExternalData(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
||||
return False
|
||||
|
||||
def ModuleStateStart(builder):
|
||||
builder.StartObject(4)
|
||||
|
||||
def Start(builder):
|
||||
ModuleStateStart(builder)
|
||||
|
||||
def ModuleStateAddRequiresGradParams(builder, requiresGradParams):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0)
|
||||
|
||||
def AddRequiresGradParams(builder, requiresGradParams):
|
||||
ModuleStateAddRequiresGradParams(builder, requiresGradParams)
|
||||
|
||||
def ModuleStateStartRequiresGradParamsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartRequiresGradParamsVector(builder, numElems: int) -> int:
|
||||
return ModuleStateStartRequiresGradParamsVector(builder, numElems)
|
||||
|
||||
def ModuleStateAddFrozenParams(builder, frozenParams):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0)
|
||||
|
||||
def AddFrozenParams(builder, frozenParams):
|
||||
ModuleStateAddFrozenParams(builder, frozenParams)
|
||||
|
||||
def ModuleStateStartFrozenParamsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartFrozenParamsVector(builder, numElems: int) -> int:
|
||||
return ModuleStateStartFrozenParamsVector(builder, numElems)
|
||||
|
||||
def ModuleStateAddIsNominalState(builder, isNominalState):
|
||||
builder.PrependBoolSlot(2, isNominalState, 0)
|
||||
|
||||
def AddIsNominalState(builder, isNominalState):
|
||||
ModuleStateAddIsNominalState(builder, isNominalState)
|
||||
|
||||
def ModuleStateAddHasExternalData(builder, hasExternalData):
|
||||
builder.PrependBoolSlot(3, hasExternalData, 0)
|
||||
|
||||
def AddHasExternalData(builder, hasExternalData):
|
||||
ModuleStateAddHasExternalData(builder, hasExternalData)
|
||||
|
||||
def ModuleStateEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return ModuleStateEnd(builder)
|
||||
@@ -0,0 +1,317 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Node(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = Node()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsNode(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def NodeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# Node
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# Node
|
||||
def Name(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Node
|
||||
def DocString(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Node
|
||||
def Domain(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Node
|
||||
def SinceVersion(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def Index(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def OpType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Node
|
||||
def Type(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def ExecutionProviderType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Node
|
||||
def Inputs(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# Node
|
||||
def InputsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def InputsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
return o == 0
|
||||
|
||||
# Node
|
||||
def Outputs(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# Node
|
||||
def OutputsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def OutputsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
return o == 0
|
||||
|
||||
# Node
|
||||
def Attributes(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.Attribute import Attribute
|
||||
obj = Attribute()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# Node
|
||||
def AttributesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def AttributesIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
|
||||
return o == 0
|
||||
|
||||
# Node
|
||||
def InputArgCounts(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def InputArgCountsAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def InputArgCountsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def InputArgCountsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
|
||||
return o == 0
|
||||
|
||||
# Node
|
||||
def ImplicitInputs(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# Node
|
||||
def ImplicitInputsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Node
|
||||
def ImplicitInputsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
|
||||
return o == 0
|
||||
|
||||
def NodeStart(builder):
|
||||
builder.StartObject(13)
|
||||
|
||||
def Start(builder):
|
||||
NodeStart(builder)
|
||||
|
||||
def NodeAddName(builder, name):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
|
||||
def AddName(builder, name):
|
||||
NodeAddName(builder, name)
|
||||
|
||||
def NodeAddDocString(builder, docString):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
|
||||
|
||||
def AddDocString(builder, docString):
|
||||
NodeAddDocString(builder, docString)
|
||||
|
||||
def NodeAddDomain(builder, domain):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(domain), 0)
|
||||
|
||||
def AddDomain(builder, domain):
|
||||
NodeAddDomain(builder, domain)
|
||||
|
||||
def NodeAddSinceVersion(builder, sinceVersion):
|
||||
builder.PrependInt32Slot(3, sinceVersion, 0)
|
||||
|
||||
def AddSinceVersion(builder, sinceVersion):
|
||||
NodeAddSinceVersion(builder, sinceVersion)
|
||||
|
||||
def NodeAddIndex(builder, index):
|
||||
builder.PrependUint32Slot(4, index, 0)
|
||||
|
||||
def AddIndex(builder, index):
|
||||
NodeAddIndex(builder, index)
|
||||
|
||||
def NodeAddOpType(builder, opType):
|
||||
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(opType), 0)
|
||||
|
||||
def AddOpType(builder, opType):
|
||||
NodeAddOpType(builder, opType)
|
||||
|
||||
def NodeAddType(builder, type):
|
||||
builder.PrependInt32Slot(6, type, 0)
|
||||
|
||||
def AddType(builder, type):
|
||||
NodeAddType(builder, type)
|
||||
|
||||
def NodeAddExecutionProviderType(builder, executionProviderType):
|
||||
builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(executionProviderType), 0)
|
||||
|
||||
def AddExecutionProviderType(builder, executionProviderType):
|
||||
NodeAddExecutionProviderType(builder, executionProviderType)
|
||||
|
||||
def NodeAddInputs(builder, inputs):
|
||||
builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
|
||||
|
||||
def AddInputs(builder, inputs):
|
||||
NodeAddInputs(builder, inputs)
|
||||
|
||||
def NodeStartInputsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartInputsVector(builder, numElems: int) -> int:
|
||||
return NodeStartInputsVector(builder, numElems)
|
||||
|
||||
def NodeAddOutputs(builder, outputs):
|
||||
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
|
||||
|
||||
def AddOutputs(builder, outputs):
|
||||
NodeAddOutputs(builder, outputs)
|
||||
|
||||
def NodeStartOutputsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartOutputsVector(builder, numElems: int) -> int:
|
||||
return NodeStartOutputsVector(builder, numElems)
|
||||
|
||||
def NodeAddAttributes(builder, attributes):
|
||||
builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(attributes), 0)
|
||||
|
||||
def AddAttributes(builder, attributes):
|
||||
NodeAddAttributes(builder, attributes)
|
||||
|
||||
def NodeStartAttributesVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartAttributesVector(builder, numElems: int) -> int:
|
||||
return NodeStartAttributesVector(builder, numElems)
|
||||
|
||||
def NodeAddInputArgCounts(builder, inputArgCounts):
|
||||
builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(inputArgCounts), 0)
|
||||
|
||||
def AddInputArgCounts(builder, inputArgCounts):
|
||||
NodeAddInputArgCounts(builder, inputArgCounts)
|
||||
|
||||
def NodeStartInputArgCountsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartInputArgCountsVector(builder, numElems: int) -> int:
|
||||
return NodeStartInputArgCountsVector(builder, numElems)
|
||||
|
||||
def NodeAddImplicitInputs(builder, implicitInputs):
|
||||
builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(implicitInputs), 0)
|
||||
|
||||
def AddImplicitInputs(builder, implicitInputs):
|
||||
NodeAddImplicitInputs(builder, implicitInputs)
|
||||
|
||||
def NodeStartImplicitInputsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartImplicitInputsVector(builder, numElems: int) -> int:
|
||||
return NodeStartImplicitInputsVector(builder, numElems)
|
||||
|
||||
def NodeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return NodeEnd(builder)
|
||||
@@ -0,0 +1,126 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class NodeEdge(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = NodeEdge()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsNodeEdge(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def NodeEdgeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# NodeEdge
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# NodeEdge
|
||||
def NodeIndex(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# NodeEdge
|
||||
def InputEdges(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 12
|
||||
from ort_flatbuffers_py.fbs.EdgeEnd import EdgeEnd
|
||||
obj = EdgeEnd()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# NodeEdge
|
||||
def InputEdgesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# NodeEdge
|
||||
def InputEdgesIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
# NodeEdge
|
||||
def OutputEdges(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 12
|
||||
from ort_flatbuffers_py.fbs.EdgeEnd import EdgeEnd
|
||||
obj = EdgeEnd()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# NodeEdge
|
||||
def OutputEdgesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# NodeEdge
|
||||
def OutputEdgesIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
return o == 0
|
||||
|
||||
def NodeEdgeStart(builder):
|
||||
builder.StartObject(3)
|
||||
|
||||
def Start(builder):
|
||||
NodeEdgeStart(builder)
|
||||
|
||||
def NodeEdgeAddNodeIndex(builder, nodeIndex):
|
||||
builder.PrependUint32Slot(0, nodeIndex, 0)
|
||||
|
||||
def AddNodeIndex(builder, nodeIndex):
|
||||
NodeEdgeAddNodeIndex(builder, nodeIndex)
|
||||
|
||||
def NodeEdgeAddInputEdges(builder, inputEdges):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(inputEdges), 0)
|
||||
|
||||
def AddInputEdges(builder, inputEdges):
|
||||
NodeEdgeAddInputEdges(builder, inputEdges)
|
||||
|
||||
def NodeEdgeStartInputEdgesVector(builder, numElems):
|
||||
return builder.StartVector(12, numElems, 4)
|
||||
|
||||
def StartInputEdgesVector(builder, numElems: int) -> int:
|
||||
return NodeEdgeStartInputEdgesVector(builder, numElems)
|
||||
|
||||
def NodeEdgeAddOutputEdges(builder, outputEdges):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(outputEdges), 0)
|
||||
|
||||
def AddOutputEdges(builder, outputEdges):
|
||||
NodeEdgeAddOutputEdges(builder, outputEdges)
|
||||
|
||||
def NodeEdgeStartOutputEdgesVector(builder, numElems):
|
||||
return builder.StartVector(12, numElems, 4)
|
||||
|
||||
def StartOutputEdgesVector(builder, numElems: int) -> int:
|
||||
return NodeEdgeStartOutputEdgesVector(builder, numElems)
|
||||
|
||||
def NodeEdgeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return NodeEdgeEnd(builder)
|
||||
@@ -0,0 +1,7 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class NodeType(object):
|
||||
Primitive = 0
|
||||
Fused = 1
|
||||
@@ -0,0 +1,160 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# nodes to consider for a runtime optimization
|
||||
# see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
|
||||
class NodesToOptimizeIndices(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = NodesToOptimizeIndices()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsNodesToOptimizeIndices(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def NodesToOptimizeIndicesBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NodeIndices(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NodeIndicesAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NodeIndicesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NodeIndicesIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
return o == 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NumInputs(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NumOutputs(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def HasVariadicInput(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
||||
return False
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def HasVariadicOutput(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
||||
return False
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NumVariadicInputs(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# NodesToOptimizeIndices
|
||||
def NumVariadicOutputs(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
def NodesToOptimizeIndicesStart(builder):
|
||||
builder.StartObject(7)
|
||||
|
||||
def Start(builder):
|
||||
NodesToOptimizeIndicesStart(builder)
|
||||
|
||||
def NodesToOptimizeIndicesAddNodeIndices(builder, nodeIndices):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices), 0)
|
||||
|
||||
def AddNodeIndices(builder, nodeIndices):
|
||||
NodesToOptimizeIndicesAddNodeIndices(builder, nodeIndices)
|
||||
|
||||
def NodesToOptimizeIndicesStartNodeIndicesVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartNodeIndicesVector(builder, numElems: int) -> int:
|
||||
return NodesToOptimizeIndicesStartNodeIndicesVector(builder, numElems)
|
||||
|
||||
def NodesToOptimizeIndicesAddNumInputs(builder, numInputs):
|
||||
builder.PrependUint32Slot(1, numInputs, 0)
|
||||
|
||||
def AddNumInputs(builder, numInputs):
|
||||
NodesToOptimizeIndicesAddNumInputs(builder, numInputs)
|
||||
|
||||
def NodesToOptimizeIndicesAddNumOutputs(builder, numOutputs):
|
||||
builder.PrependUint32Slot(2, numOutputs, 0)
|
||||
|
||||
def AddNumOutputs(builder, numOutputs):
|
||||
NodesToOptimizeIndicesAddNumOutputs(builder, numOutputs)
|
||||
|
||||
def NodesToOptimizeIndicesAddHasVariadicInput(builder, hasVariadicInput):
|
||||
builder.PrependBoolSlot(3, hasVariadicInput, 0)
|
||||
|
||||
def AddHasVariadicInput(builder, hasVariadicInput):
|
||||
NodesToOptimizeIndicesAddHasVariadicInput(builder, hasVariadicInput)
|
||||
|
||||
def NodesToOptimizeIndicesAddHasVariadicOutput(builder, hasVariadicOutput):
|
||||
builder.PrependBoolSlot(4, hasVariadicOutput, 0)
|
||||
|
||||
def AddHasVariadicOutput(builder, hasVariadicOutput):
|
||||
NodesToOptimizeIndicesAddHasVariadicOutput(builder, hasVariadicOutput)
|
||||
|
||||
def NodesToOptimizeIndicesAddNumVariadicInputs(builder, numVariadicInputs):
|
||||
builder.PrependUint32Slot(5, numVariadicInputs, 0)
|
||||
|
||||
def AddNumVariadicInputs(builder, numVariadicInputs):
|
||||
NodesToOptimizeIndicesAddNumVariadicInputs(builder, numVariadicInputs)
|
||||
|
||||
def NodesToOptimizeIndicesAddNumVariadicOutputs(builder, numVariadicOutputs):
|
||||
builder.PrependUint32Slot(6, numVariadicOutputs, 0)
|
||||
|
||||
def AddNumVariadicOutputs(builder, numVariadicOutputs):
|
||||
NodesToOptimizeIndicesAddNumVariadicOutputs(builder, numVariadicOutputs)
|
||||
|
||||
def NodesToOptimizeIndicesEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return NodesToOptimizeIndicesEnd(builder)
|
||||
@@ -0,0 +1,91 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class OpIdKernelTypeStrArgsEntry(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = OpIdKernelTypeStrArgsEntry()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsOpIdKernelTypeStrArgsEntry(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def OpIdKernelTypeStrArgsEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# OpIdKernelTypeStrArgsEntry
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# OpIdKernelTypeStrArgsEntry
|
||||
def OpId(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# OpIdKernelTypeStrArgsEntry
|
||||
def KernelTypeStrArgs(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.KernelTypeStrArgsEntry import KernelTypeStrArgsEntry
|
||||
obj = KernelTypeStrArgsEntry()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# OpIdKernelTypeStrArgsEntry
|
||||
def KernelTypeStrArgsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# OpIdKernelTypeStrArgsEntry
|
||||
def KernelTypeStrArgsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
def OpIdKernelTypeStrArgsEntryStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
OpIdKernelTypeStrArgsEntryStart(builder)
|
||||
|
||||
def OpIdKernelTypeStrArgsEntryAddOpId(builder, opId):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(opId), 0)
|
||||
|
||||
def AddOpId(builder, opId):
|
||||
OpIdKernelTypeStrArgsEntryAddOpId(builder, opId)
|
||||
|
||||
def OpIdKernelTypeStrArgsEntryAddKernelTypeStrArgs(builder, kernelTypeStrArgs):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernelTypeStrArgs), 0)
|
||||
|
||||
def AddKernelTypeStrArgs(builder, kernelTypeStrArgs):
|
||||
OpIdKernelTypeStrArgsEntryAddKernelTypeStrArgs(builder, kernelTypeStrArgs)
|
||||
|
||||
def OpIdKernelTypeStrArgsEntryStartKernelTypeStrArgsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartKernelTypeStrArgsVector(builder, numElems: int) -> int:
|
||||
return OpIdKernelTypeStrArgsEntryStartKernelTypeStrArgsVector(builder, numElems)
|
||||
|
||||
def OpIdKernelTypeStrArgsEntryEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return OpIdKernelTypeStrArgsEntryEnd(builder)
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class OperatorSetId(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = OperatorSetId()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsOperatorSetId(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def OperatorSetIdBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# OperatorSetId
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# OperatorSetId
|
||||
def Domain(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# OperatorSetId
|
||||
def Version(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
def OperatorSetIdStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
OperatorSetIdStart(builder)
|
||||
|
||||
def OperatorSetIdAddDomain(builder, domain):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(domain), 0)
|
||||
|
||||
def AddDomain(builder, domain):
|
||||
OperatorSetIdAddDomain(builder, domain)
|
||||
|
||||
def OperatorSetIdAddVersion(builder, version):
|
||||
builder.PrependInt64Slot(1, version, 0)
|
||||
|
||||
def AddVersion(builder, version):
|
||||
OperatorSetIdAddVersion(builder, version)
|
||||
|
||||
def OperatorSetIdEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return OperatorSetIdEnd(builder)
|
||||
@@ -0,0 +1,117 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class OptimizerGroup(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = OptimizerGroup()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsOptimizerGroup(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def OptimizerGroupBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)
|
||||
|
||||
# OptimizerGroup
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# OptimizerGroup
|
||||
def GroupName(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# OptimizerGroup
|
||||
def Step(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# OptimizerGroup
|
||||
def InitialLearningRate(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
|
||||
return 0.0
|
||||
|
||||
# OptimizerGroup
|
||||
def OptimizerStates(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.ParameterOptimizerState import ParameterOptimizerState
|
||||
obj = ParameterOptimizerState()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# OptimizerGroup
|
||||
def OptimizerStatesLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# OptimizerGroup
|
||||
def OptimizerStatesIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
return o == 0
|
||||
|
||||
def OptimizerGroupStart(builder):
|
||||
builder.StartObject(4)
|
||||
|
||||
def Start(builder):
|
||||
OptimizerGroupStart(builder)
|
||||
|
||||
def OptimizerGroupAddGroupName(builder, groupName):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(groupName), 0)
|
||||
|
||||
def AddGroupName(builder, groupName):
|
||||
OptimizerGroupAddGroupName(builder, groupName)
|
||||
|
||||
def OptimizerGroupAddStep(builder, step):
|
||||
builder.PrependInt64Slot(1, step, 0)
|
||||
|
||||
def AddStep(builder, step):
|
||||
OptimizerGroupAddStep(builder, step)
|
||||
|
||||
def OptimizerGroupAddInitialLearningRate(builder, initialLearningRate):
|
||||
builder.PrependFloat32Slot(2, initialLearningRate, 0.0)
|
||||
|
||||
def AddInitialLearningRate(builder, initialLearningRate):
|
||||
OptimizerGroupAddInitialLearningRate(builder, initialLearningRate)
|
||||
|
||||
def OptimizerGroupAddOptimizerStates(builder, optimizerStates):
|
||||
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerStates), 0)
|
||||
|
||||
def AddOptimizerStates(builder, optimizerStates):
|
||||
OptimizerGroupAddOptimizerStates(builder, optimizerStates)
|
||||
|
||||
def OptimizerGroupStartOptimizerStatesVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartOptimizerStatesVector(builder, numElems: int) -> int:
|
||||
return OptimizerGroupStartOptimizerStatesVector(builder, numElems)
|
||||
|
||||
def OptimizerGroupEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return OptimizerGroupEnd(builder)
|
||||
@@ -0,0 +1,91 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class ParameterOptimizerState(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = ParameterOptimizerState()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsParameterOptimizerState(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def ParameterOptimizerStateBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)
|
||||
|
||||
# ParameterOptimizerState
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# ParameterOptimizerState
|
||||
def ParamName(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# ParameterOptimizerState
|
||||
def Momentums(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.Tensor import Tensor
|
||||
obj = Tensor()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# ParameterOptimizerState
|
||||
def MomentumsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# ParameterOptimizerState
|
||||
def MomentumsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
def ParameterOptimizerStateStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
ParameterOptimizerStateStart(builder)
|
||||
|
||||
def ParameterOptimizerStateAddParamName(builder, paramName):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(paramName), 0)
|
||||
|
||||
def AddParamName(builder, paramName):
|
||||
ParameterOptimizerStateAddParamName(builder, paramName)
|
||||
|
||||
def ParameterOptimizerStateAddMomentums(builder, momentums):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(momentums), 0)
|
||||
|
||||
def AddMomentums(builder, momentums):
|
||||
ParameterOptimizerStateAddMomentums(builder, momentums)
|
||||
|
||||
def ParameterOptimizerStateStartMomentumsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartMomentumsVector(builder, numElems: int) -> int:
|
||||
return ParameterOptimizerStateStartMomentumsVector(builder, numElems)
|
||||
|
||||
def ParameterOptimizerStateEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return ParameterOptimizerStateEnd(builder)
|
||||
@@ -0,0 +1,152 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class PropertyBag(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = PropertyBag()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsPropertyBag(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def PropertyBagBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)
|
||||
|
||||
# PropertyBag
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# PropertyBag
|
||||
def Ints(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.IntProperty import IntProperty
|
||||
obj = IntProperty()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# PropertyBag
|
||||
def IntsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# PropertyBag
|
||||
def IntsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
return o == 0
|
||||
|
||||
# PropertyBag
|
||||
def Floats(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.FloatProperty import FloatProperty
|
||||
obj = FloatProperty()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# PropertyBag
|
||||
def FloatsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# PropertyBag
|
||||
def FloatsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
# PropertyBag
|
||||
def Strings(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.StringProperty import StringProperty
|
||||
obj = StringProperty()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# PropertyBag
|
||||
def StringsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# PropertyBag
|
||||
def StringsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
return o == 0
|
||||
|
||||
def PropertyBagStart(builder):
|
||||
builder.StartObject(3)
|
||||
|
||||
def Start(builder):
|
||||
PropertyBagStart(builder)
|
||||
|
||||
def PropertyBagAddInts(builder, ints):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(ints), 0)
|
||||
|
||||
def AddInts(builder, ints):
|
||||
PropertyBagAddInts(builder, ints)
|
||||
|
||||
def PropertyBagStartIntsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartIntsVector(builder, numElems: int) -> int:
|
||||
return PropertyBagStartIntsVector(builder, numElems)
|
||||
|
||||
def PropertyBagAddFloats(builder, floats):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(floats), 0)
|
||||
|
||||
def AddFloats(builder, floats):
|
||||
PropertyBagAddFloats(builder, floats)
|
||||
|
||||
def PropertyBagStartFloatsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartFloatsVector(builder, numElems: int) -> int:
|
||||
return PropertyBagStartFloatsVector(builder, numElems)
|
||||
|
||||
def PropertyBagAddStrings(builder, strings):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(strings), 0)
|
||||
|
||||
def AddStrings(builder, strings):
|
||||
PropertyBagAddStrings(builder, strings)
|
||||
|
||||
def PropertyBagStartStringsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartStringsVector(builder, numElems: int) -> int:
|
||||
return PropertyBagStartStringsVector(builder, numElems)
|
||||
|
||||
def PropertyBagEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return PropertyBagEnd(builder)
|
||||
@@ -0,0 +1,105 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
# a single runtime optimization
|
||||
# see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
|
||||
class RuntimeOptimizationRecord(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = RuntimeOptimizationRecord()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsRuntimeOptimizationRecord(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def RuntimeOptimizationRecordBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def ActionId(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def NodesToOptimizeIndices(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.NodesToOptimizeIndices import NodesToOptimizeIndices
|
||||
obj = NodesToOptimizeIndices()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def ProducedOpIds(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def ProducedOpIdsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# RuntimeOptimizationRecord
|
||||
def ProducedOpIdsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
return o == 0
|
||||
|
||||
def RuntimeOptimizationRecordStart(builder):
|
||||
builder.StartObject(4)
|
||||
|
||||
def Start(builder):
|
||||
RuntimeOptimizationRecordStart(builder)
|
||||
|
||||
def RuntimeOptimizationRecordAddActionId(builder, actionId):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(actionId), 0)
|
||||
|
||||
def AddActionId(builder, actionId):
|
||||
RuntimeOptimizationRecordAddActionId(builder, actionId)
|
||||
|
||||
def RuntimeOptimizationRecordAddNodesToOptimizeIndices(builder, nodesToOptimizeIndices):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(nodesToOptimizeIndices), 0)
|
||||
|
||||
def AddNodesToOptimizeIndices(builder, nodesToOptimizeIndices):
|
||||
RuntimeOptimizationRecordAddNodesToOptimizeIndices(builder, nodesToOptimizeIndices)
|
||||
|
||||
def RuntimeOptimizationRecordAddProducedOpIds(builder, producedOpIds):
|
||||
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(producedOpIds), 0)
|
||||
|
||||
def AddProducedOpIds(builder, producedOpIds):
|
||||
RuntimeOptimizationRecordAddProducedOpIds(builder, producedOpIds)
|
||||
|
||||
def RuntimeOptimizationRecordStartProducedOpIdsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartProducedOpIdsVector(builder, numElems: int) -> int:
|
||||
return RuntimeOptimizationRecordStartProducedOpIdsVector(builder, numElems)
|
||||
|
||||
def RuntimeOptimizationRecordEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return RuntimeOptimizationRecordEnd(builder)
|
||||
@@ -0,0 +1,91 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class RuntimeOptimizationRecordContainerEntry(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = RuntimeOptimizationRecordContainerEntry()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsRuntimeOptimizationRecordContainerEntry(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def RuntimeOptimizationRecordContainerEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# RuntimeOptimizationRecordContainerEntry
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# RuntimeOptimizationRecordContainerEntry
|
||||
def OptimizerName(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# RuntimeOptimizationRecordContainerEntry
|
||||
def RuntimeOptimizationRecords(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.RuntimeOptimizationRecord import RuntimeOptimizationRecord
|
||||
obj = RuntimeOptimizationRecord()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# RuntimeOptimizationRecordContainerEntry
|
||||
def RuntimeOptimizationRecordsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# RuntimeOptimizationRecordContainerEntry
|
||||
def RuntimeOptimizationRecordsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
return o == 0
|
||||
|
||||
def RuntimeOptimizationRecordContainerEntryStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
RuntimeOptimizationRecordContainerEntryStart(builder)
|
||||
|
||||
def RuntimeOptimizationRecordContainerEntryAddOptimizerName(builder, optimizerName):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerName), 0)
|
||||
|
||||
def AddOptimizerName(builder, optimizerName):
|
||||
RuntimeOptimizationRecordContainerEntryAddOptimizerName(builder, optimizerName)
|
||||
|
||||
def RuntimeOptimizationRecordContainerEntryAddRuntimeOptimizationRecords(builder, runtimeOptimizationRecords):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(runtimeOptimizationRecords), 0)
|
||||
|
||||
def AddRuntimeOptimizationRecords(builder, runtimeOptimizationRecords):
|
||||
RuntimeOptimizationRecordContainerEntryAddRuntimeOptimizationRecords(builder, runtimeOptimizationRecords)
|
||||
|
||||
def RuntimeOptimizationRecordContainerEntryStartRuntimeOptimizationRecordsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartRuntimeOptimizationRecordsVector(builder, numElems: int) -> int:
|
||||
return RuntimeOptimizationRecordContainerEntryStartRuntimeOptimizationRecordsVector(builder, numElems)
|
||||
|
||||
def RuntimeOptimizationRecordContainerEntryEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return RuntimeOptimizationRecordContainerEntryEnd(builder)
|
||||
@@ -0,0 +1,79 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class RuntimeOptimizations(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = RuntimeOptimizations()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsRuntimeOptimizations(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def RuntimeOptimizationsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# RuntimeOptimizations
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# mapping from optimizer name to [RuntimeOptimizationRecord]
|
||||
# RuntimeOptimizations
|
||||
def Records(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.RuntimeOptimizationRecordContainerEntry import RuntimeOptimizationRecordContainerEntry
|
||||
obj = RuntimeOptimizationRecordContainerEntry()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# RuntimeOptimizations
|
||||
def RecordsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# RuntimeOptimizations
|
||||
def RecordsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
return o == 0
|
||||
|
||||
def RuntimeOptimizationsStart(builder):
|
||||
builder.StartObject(1)
|
||||
|
||||
def Start(builder):
|
||||
RuntimeOptimizationsStart(builder)
|
||||
|
||||
def RuntimeOptimizationsAddRecords(builder, records):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(records), 0)
|
||||
|
||||
def AddRecords(builder, records):
|
||||
RuntimeOptimizationsAddRecords(builder, records)
|
||||
|
||||
def RuntimeOptimizationsStartRecordsVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartRecordsVector(builder, numElems: int) -> int:
|
||||
return RuntimeOptimizationsStartRecordsVector(builder, numElems)
|
||||
|
||||
def RuntimeOptimizationsEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return RuntimeOptimizationsEnd(builder)
|
||||
@@ -0,0 +1,58 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class SequenceType(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = SequenceType()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsSequenceType(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def SequenceTypeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# SequenceType
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# SequenceType
|
||||
def ElemType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
|
||||
obj = TypeInfo()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
def SequenceTypeStart(builder):
|
||||
builder.StartObject(1)
|
||||
|
||||
def Start(builder):
|
||||
SequenceTypeStart(builder)
|
||||
|
||||
def SequenceTypeAddElemType(builder, elemType):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(elemType), 0)
|
||||
|
||||
def AddElemType(builder, elemType):
|
||||
SequenceTypeAddElemType(builder, elemType)
|
||||
|
||||
def SequenceTypeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return SequenceTypeEnd(builder)
|
||||
@@ -0,0 +1,78 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Shape(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = Shape()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsShape(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def ShapeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# Shape
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# Shape
|
||||
def Dim(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Vector(o)
|
||||
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
||||
x = self._tab.Indirect(x)
|
||||
from ort_flatbuffers_py.fbs.Dimension import Dimension
|
||||
obj = Dimension()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# Shape
|
||||
def DimLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Shape
|
||||
def DimIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
return o == 0
|
||||
|
||||
def ShapeStart(builder):
|
||||
builder.StartObject(1)
|
||||
|
||||
def Start(builder):
|
||||
ShapeStart(builder)
|
||||
|
||||
def ShapeAddDim(builder, dim):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dim), 0)
|
||||
|
||||
def AddDim(builder, dim):
|
||||
ShapeAddDim(builder, dim)
|
||||
|
||||
def ShapeStartDimVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartDimVector(builder, numElems: int) -> int:
|
||||
return ShapeStartDimVector(builder, numElems)
|
||||
|
||||
def ShapeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return ShapeEnd(builder)
|
||||
@@ -0,0 +1,114 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class SparseTensor(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = SparseTensor()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsSparseTensor(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def SparseTensorBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# SparseTensor
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# SparseTensor
|
||||
def Values(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.Tensor import Tensor
|
||||
obj = Tensor()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# SparseTensor
|
||||
def Indices(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.Tensor import Tensor
|
||||
obj = Tensor()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
# SparseTensor
|
||||
def Dims(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
|
||||
return 0
|
||||
|
||||
# SparseTensor
|
||||
def DimsAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
|
||||
return 0
|
||||
|
||||
# SparseTensor
|
||||
def DimsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# SparseTensor
|
||||
def DimsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
return o == 0
|
||||
|
||||
def SparseTensorStart(builder):
|
||||
builder.StartObject(3)
|
||||
|
||||
def Start(builder):
|
||||
SparseTensorStart(builder)
|
||||
|
||||
def SparseTensorAddValues(builder, values):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0)
|
||||
|
||||
def AddValues(builder, values):
|
||||
SparseTensorAddValues(builder, values)
|
||||
|
||||
def SparseTensorAddIndices(builder, indices):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(indices), 0)
|
||||
|
||||
def AddIndices(builder, indices):
|
||||
SparseTensorAddIndices(builder, indices)
|
||||
|
||||
def SparseTensorAddDims(builder, dims):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dims), 0)
|
||||
|
||||
def AddDims(builder, dims):
|
||||
SparseTensorAddDims(builder, dims)
|
||||
|
||||
def SparseTensorStartDimsVector(builder, numElems):
|
||||
return builder.StartVector(8, numElems, 8)
|
||||
|
||||
def StartDimsVector(builder, numElems: int) -> int:
|
||||
return SparseTensorStartDimsVector(builder, numElems)
|
||||
|
||||
def SparseTensorEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return SparseTensorEnd(builder)
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class StringProperty(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = StringProperty()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsStringProperty(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def StringPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)
|
||||
|
||||
# StringProperty
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# StringProperty
|
||||
def Name(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# StringProperty
|
||||
def Value(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
def StringPropertyStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
StringPropertyStart(builder)
|
||||
|
||||
def StringPropertyAddName(builder, name):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
|
||||
def AddName(builder, name):
|
||||
StringPropertyAddName(builder, name)
|
||||
|
||||
def StringPropertyAddValue(builder, value):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)
|
||||
|
||||
def AddValue(builder, value):
|
||||
StringPropertyAddValue(builder, value)
|
||||
|
||||
def StringPropertyEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return StringPropertyEnd(builder)
|
||||
@@ -0,0 +1,67 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class StringStringEntry(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = StringStringEntry()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsStringStringEntry(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def StringStringEntryBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# StringStringEntry
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# StringStringEntry
|
||||
def Key(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# StringStringEntry
|
||||
def Value(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
def StringStringEntryStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
StringStringEntryStart(builder)
|
||||
|
||||
def StringStringEntryAddKey(builder, key):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0)
|
||||
|
||||
def AddKey(builder, key):
|
||||
StringStringEntryAddKey(builder, key)
|
||||
|
||||
def StringStringEntryAddValue(builder, value):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)
|
||||
|
||||
def AddValue(builder, value):
|
||||
StringStringEntryAddValue(builder, value)
|
||||
|
||||
def StringStringEntryEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return StringStringEntryEnd(builder)
|
||||
@@ -0,0 +1,203 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class Tensor(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = Tensor()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsTensor(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def TensorBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# Tensor
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# Tensor
|
||||
def Name(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Tensor
|
||||
def DocString(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# Tensor
|
||||
def Dims(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def DimsAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int64Flags, o)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def DimsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def DimsIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
return o == 0
|
||||
|
||||
# Tensor
|
||||
def DataType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def RawData(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Uint8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def RawDataAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def RawDataLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def RawDataIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
return o == 0
|
||||
|
||||
# Tensor
|
||||
def StringData(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# Tensor
|
||||
def StringDataLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# Tensor
|
||||
def StringDataIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
|
||||
return o == 0
|
||||
|
||||
# Tensor
|
||||
def ExternalDataOffset(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
|
||||
return -1
|
||||
|
||||
def TensorStart(builder):
|
||||
builder.StartObject(7)
|
||||
|
||||
def Start(builder):
|
||||
TensorStart(builder)
|
||||
|
||||
def TensorAddName(builder, name):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
|
||||
def AddName(builder, name):
|
||||
TensorAddName(builder, name)
|
||||
|
||||
def TensorAddDocString(builder, docString):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
|
||||
|
||||
def AddDocString(builder, docString):
|
||||
TensorAddDocString(builder, docString)
|
||||
|
||||
def TensorAddDims(builder, dims):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dims), 0)
|
||||
|
||||
def AddDims(builder, dims):
|
||||
TensorAddDims(builder, dims)
|
||||
|
||||
def TensorStartDimsVector(builder, numElems):
|
||||
return builder.StartVector(8, numElems, 8)
|
||||
|
||||
def StartDimsVector(builder, numElems: int) -> int:
|
||||
return TensorStartDimsVector(builder, numElems)
|
||||
|
||||
def TensorAddDataType(builder, dataType):
|
||||
builder.PrependInt32Slot(3, dataType, 0)
|
||||
|
||||
def AddDataType(builder, dataType):
|
||||
TensorAddDataType(builder, dataType)
|
||||
|
||||
def TensorAddRawData(builder, rawData):
|
||||
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(rawData), 0)
|
||||
|
||||
def AddRawData(builder, rawData):
|
||||
TensorAddRawData(builder, rawData)
|
||||
|
||||
def TensorStartRawDataVector(builder, numElems):
|
||||
return builder.StartVector(1, numElems, 1)
|
||||
|
||||
def StartRawDataVector(builder, numElems: int) -> int:
|
||||
return TensorStartRawDataVector(builder, numElems)
|
||||
|
||||
def TensorAddStringData(builder, stringData):
|
||||
builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(stringData), 0)
|
||||
|
||||
def AddStringData(builder, stringData):
|
||||
TensorAddStringData(builder, stringData)
|
||||
|
||||
def TensorStartStringDataVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def StartStringDataVector(builder, numElems: int) -> int:
|
||||
return TensorStartStringDataVector(builder, numElems)
|
||||
|
||||
def TensorAddExternalDataOffset(builder, externalDataOffset):
|
||||
builder.PrependInt64Slot(6, externalDataOffset, -1)
|
||||
|
||||
def AddExternalDataOffset(builder, externalDataOffset):
|
||||
TensorAddExternalDataOffset(builder, externalDataOffset)
|
||||
|
||||
def TensorEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return TensorEnd(builder)
|
||||
@@ -0,0 +1,26 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class TensorDataType(object):
|
||||
UNDEFINED = 0
|
||||
FLOAT = 1
|
||||
UINT8 = 2
|
||||
INT8 = 3
|
||||
UINT16 = 4
|
||||
INT16 = 5
|
||||
INT32 = 6
|
||||
INT64 = 7
|
||||
STRING = 8
|
||||
BOOL = 9
|
||||
FLOAT16 = 10
|
||||
DOUBLE = 11
|
||||
UINT32 = 12
|
||||
UINT64 = 13
|
||||
COMPLEX64 = 14
|
||||
COMPLEX128 = 15
|
||||
BFLOAT16 = 16
|
||||
FLOAT8E4M3FN = 17
|
||||
FLOAT8E4M3FNUZ = 18
|
||||
FLOAT8E5M2 = 19
|
||||
FLOAT8E5M2FNUZ = 20
|
||||
@@ -0,0 +1,71 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class TensorTypeAndShape(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = TensorTypeAndShape()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsTensorTypeAndShape(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def TensorTypeAndShapeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# TensorTypeAndShape
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# TensorTypeAndShape
|
||||
def ElemType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# TensorTypeAndShape
|
||||
def Shape(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.Shape import Shape
|
||||
obj = Shape()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
def TensorTypeAndShapeStart(builder):
|
||||
builder.StartObject(2)
|
||||
|
||||
def Start(builder):
|
||||
TensorTypeAndShapeStart(builder)
|
||||
|
||||
def TensorTypeAndShapeAddElemType(builder, elemType):
|
||||
builder.PrependInt32Slot(0, elemType, 0)
|
||||
|
||||
def AddElemType(builder, elemType):
|
||||
TensorTypeAndShapeAddElemType(builder, elemType)
|
||||
|
||||
def TensorTypeAndShapeAddShape(builder, shape):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
|
||||
|
||||
def AddShape(builder, shape):
|
||||
TensorTypeAndShapeAddShape(builder, shape)
|
||||
|
||||
def TensorTypeAndShapeEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return TensorTypeAndShapeEnd(builder)
|
||||
@@ -0,0 +1,83 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class TypeInfo(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = TypeInfo()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsTypeInfo(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def TypeInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# TypeInfo
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# TypeInfo
|
||||
def Denotation(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# TypeInfo
|
||||
def ValueType(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
# TypeInfo
|
||||
def Value(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
from flatbuffers.table import Table
|
||||
obj = Table(bytearray(), 0)
|
||||
self._tab.Union(obj, o)
|
||||
return obj
|
||||
return None
|
||||
|
||||
def TypeInfoStart(builder):
|
||||
builder.StartObject(3)
|
||||
|
||||
def Start(builder):
|
||||
TypeInfoStart(builder)
|
||||
|
||||
def TypeInfoAddDenotation(builder, denotation):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(denotation), 0)
|
||||
|
||||
def AddDenotation(builder, denotation):
|
||||
TypeInfoAddDenotation(builder, denotation)
|
||||
|
||||
def TypeInfoAddValueType(builder, valueType):
|
||||
builder.PrependUint8Slot(1, valueType, 0)
|
||||
|
||||
def AddValueType(builder, valueType):
|
||||
TypeInfoAddValueType(builder, valueType)
|
||||
|
||||
def TypeInfoAddValue(builder, value):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)
|
||||
|
||||
def AddValue(builder, value):
|
||||
TypeInfoAddValue(builder, value)
|
||||
|
||||
def TypeInfoEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return TypeInfoEnd(builder)
|
||||
@@ -0,0 +1,9 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
class TypeInfoValue(object):
|
||||
NONE = 0
|
||||
tensor_type = 1
|
||||
sequence_type = 2
|
||||
map_type = 3
|
||||
@@ -0,0 +1,84 @@
|
||||
# automatically generated by the FlatBuffers compiler, do not modify
|
||||
|
||||
# namespace: fbs
|
||||
|
||||
import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class ValueInfo(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
def GetRootAs(cls, buf, offset=0):
|
||||
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
||||
x = ValueInfo()
|
||||
x.Init(buf, n + offset)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def GetRootAsValueInfo(cls, buf, offset=0):
|
||||
"""This method is deprecated. Please switch to GetRootAs."""
|
||||
return cls.GetRootAs(buf, offset)
|
||||
@classmethod
|
||||
def ValueInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
|
||||
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed)
|
||||
|
||||
# ValueInfo
|
||||
def Init(self, buf, pos):
|
||||
self._tab = flatbuffers.table.Table(buf, pos)
|
||||
|
||||
# ValueInfo
|
||||
def Name(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# ValueInfo
|
||||
def DocString(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# ValueInfo
|
||||
def Type(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
if o != 0:
|
||||
x = self._tab.Indirect(o + self._tab.Pos)
|
||||
from ort_flatbuffers_py.fbs.TypeInfo import TypeInfo
|
||||
obj = TypeInfo()
|
||||
obj.Init(self._tab.Bytes, x)
|
||||
return obj
|
||||
return None
|
||||
|
||||
def ValueInfoStart(builder):
|
||||
builder.StartObject(3)
|
||||
|
||||
def Start(builder):
|
||||
ValueInfoStart(builder)
|
||||
|
||||
def ValueInfoAddName(builder, name):
|
||||
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
|
||||
def AddName(builder, name):
|
||||
ValueInfoAddName(builder, name)
|
||||
|
||||
def ValueInfoAddDocString(builder, docString):
|
||||
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(docString), 0)
|
||||
|
||||
def AddDocString(builder, docString):
|
||||
ValueInfoAddDocString(builder, docString)
|
||||
|
||||
def ValueInfoAddType(builder, type):
|
||||
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(type), 0)
|
||||
|
||||
def AddType(builder, type):
|
||||
ValueInfoAddType(builder, type)
|
||||
|
||||
def ValueInfoEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
def End(builder):
|
||||
return ValueInfoEnd(builder)
|
||||
@@ -0,0 +1,6 @@
|
||||
from os.path import dirname, basename, isfile, join, splitext
|
||||
import glob
|
||||
modules = glob.glob(join(dirname(__file__), "*.py"))
|
||||
__all__ = [splitext(basename(f))[0] for f in modules if isfile(f) and not f.endswith('__init__.py')]
|
||||
|
||||
from . import *
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import ort_flatbuffers_py.fbs as fbs
|
||||
|
||||
from .operator_type_usage_processors import OperatorTypeUsageManager
|
||||
|
||||
|
||||
class OrtFormatModelProcessor:
|
||||
"Class to process an ORT format model and determine required operators and types."
|
||||
|
||||
def __init__(self, model_path: str, required_ops: dict, processors: OperatorTypeUsageManager):
|
||||
"""
|
||||
Initialize ORT format model processor
|
||||
:param model_path: Path to model to load
|
||||
:param required_ops: Dictionary required operator information will be added to.
|
||||
:param processors: Operator type usage processors which will be called for each matching Node.
|
||||
"""
|
||||
self._required_ops = required_ops # dictionary of {domain: {opset:[operators]}}
|
||||
self._file = open(model_path, "rb").read() # noqa: SIM115
|
||||
self._buffer = bytearray(self._file)
|
||||
if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
|
||||
raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")
|
||||
self._model = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0).Model()
|
||||
self._op_type_processors = processors
|
||||
|
||||
@staticmethod
|
||||
def _setup_type_info(graph: fbs.Graph, outer_scope_value_typeinfo={}): # noqa: B006
|
||||
"""
|
||||
Setup the node args for this level of Graph.
|
||||
We copy the current list which represents the outer scope values, and add the local node args to that
|
||||
to create the valid list of values for the current Graph.
|
||||
:param graph: Graph to create NodeArg list for
|
||||
:param outer_scope_value_typeinfo: TypeInfo for outer scope values. Empty for the top-level graph in a model.
|
||||
:return: Dictionary of NodeArg name to TypeInfo
|
||||
"""
|
||||
value_name_to_typeinfo = outer_scope_value_typeinfo.copy()
|
||||
for j in range(graph.NodeArgsLength()):
|
||||
n = graph.NodeArgs(j)
|
||||
value_name_to_typeinfo[n.Name()] = n.Type() # TypeInfo for this NodeArg's name
|
||||
|
||||
return value_name_to_typeinfo
|
||||
|
||||
def _add_required_op(self, domain: str, opset: int, op_type: str):
|
||||
if domain not in self._required_ops:
|
||||
self._required_ops[domain] = {opset: {op_type}}
|
||||
elif opset not in self._required_ops[domain]:
|
||||
self._required_ops[domain][opset] = {op_type}
|
||||
else:
|
||||
self._required_ops[domain][opset].add(op_type)
|
||||
|
||||
def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
|
||||
"""
|
||||
Process one level of the Graph, descending into any subgraphs when they are found
|
||||
:param outer_scope_value_typeinfo: Outer scope NodeArg dictionary from ancestor graphs
|
||||
"""
|
||||
# Merge the TypeInfo for all values in this level of the graph with the outer scope value TypeInfo.
|
||||
value_name_to_typeinfo = OrtFormatModelProcessor._setup_type_info(graph, outer_scope_value_typeinfo)
|
||||
|
||||
for i in range(graph.NodesLength()):
|
||||
node = graph.Nodes(i)
|
||||
|
||||
optype = node.OpType().decode()
|
||||
domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx
|
||||
|
||||
self._add_required_op(domain, node.SinceVersion(), optype)
|
||||
|
||||
if self._op_type_processors:
|
||||
self._op_type_processors.process_node(node, value_name_to_typeinfo)
|
||||
|
||||
# Read all the attributes
|
||||
for j in range(node.AttributesLength()):
|
||||
attr = node.Attributes(j)
|
||||
attr_type = attr.Type()
|
||||
if attr_type == fbs.AttributeType.AttributeType.GRAPH:
|
||||
self._process_graph(attr.G(), value_name_to_typeinfo)
|
||||
elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
|
||||
# the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
|
||||
# so entering this 'elif' isn't currently possible
|
||||
for k in range(attr.GraphsLength()):
|
||||
self._process_graph(attr.Graphs(k), value_name_to_typeinfo)
|
||||
|
||||
def process(self):
|
||||
graph = self._model.Graph()
|
||||
outer_scope_value_typeinfo = {} # no outer scope values for the main graph
|
||||
self._process_graph(graph, outer_scope_value_typeinfo)
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import ort_flatbuffers_py.fbs as fbs
|
||||
|
||||
|
||||
class FbsTypeInfo:
|
||||
"Class to provide conversion between ORT flatbuffers schema values and C++ types"
|
||||
|
||||
tensordatatype_to_string = { # noqa: RUF012
|
||||
fbs.TensorDataType.TensorDataType.FLOAT: "float",
|
||||
fbs.TensorDataType.TensorDataType.UINT8: "uint8_t",
|
||||
fbs.TensorDataType.TensorDataType.INT8: "int8_t",
|
||||
fbs.TensorDataType.TensorDataType.UINT16: "uint16_t",
|
||||
fbs.TensorDataType.TensorDataType.INT16: "int16_t",
|
||||
fbs.TensorDataType.TensorDataType.INT32: "int32_t",
|
||||
fbs.TensorDataType.TensorDataType.INT64: "int64_t",
|
||||
fbs.TensorDataType.TensorDataType.STRING: "std::string",
|
||||
fbs.TensorDataType.TensorDataType.BOOL: "bool",
|
||||
fbs.TensorDataType.TensorDataType.FLOAT16: "MLFloat16",
|
||||
fbs.TensorDataType.TensorDataType.DOUBLE: "double",
|
||||
fbs.TensorDataType.TensorDataType.UINT32: "uint32_t",
|
||||
fbs.TensorDataType.TensorDataType.UINT64: "uint64_t",
|
||||
# fbs.TensorDataType.TensorDataType.COMPLEX64: 'complex64 is not supported',
|
||||
# fbs.TensorDataType.TensorDataType.COMPLEX128: 'complex128 is not supported',
|
||||
fbs.TensorDataType.TensorDataType.BFLOAT16: "BFloat16",
|
||||
fbs.TensorDataType.TensorDataType.FLOAT8E4M3FN: "Float8E4M3FN",
|
||||
fbs.TensorDataType.TensorDataType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
|
||||
fbs.TensorDataType.TensorDataType.FLOAT8E5M2: "Float8E5M2",
|
||||
fbs.TensorDataType.TensorDataType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def typeinfo_to_str(type: fbs.TypeInfo):
|
||||
value_type = type.ValueType()
|
||||
value = type.Value()
|
||||
type_str = "unknown"
|
||||
|
||||
if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
|
||||
tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
|
||||
tensor_type_and_shape.Init(value.Bytes, value.Pos)
|
||||
elem_type = tensor_type_and_shape.ElemType()
|
||||
type_str = FbsTypeInfo.tensordatatype_to_string[elem_type]
|
||||
|
||||
elif value_type == fbs.TypeInfoValue.TypeInfoValue.map_type:
|
||||
map_type = fbs.MapType.MapType()
|
||||
map_type.init(value.Bytes, value.Pos)
|
||||
key_type = map_type.KeyType() # TensorDataType
|
||||
key_type_str = FbsTypeInfo.tensordatatype_to_string[key_type]
|
||||
value_type = map_type.ValueType() # TypeInfo
|
||||
value_type_str = FbsTypeInfo.typeinfo_to_str(value_type)
|
||||
type_str = f"std::map<{key_type_str},{value_type_str}>"
|
||||
|
||||
elif value_type == fbs.TypeInfoValue.TypeInfoValue.sequence_type:
|
||||
sequence_type = fbs.SequenceType.SequenceType()
|
||||
sequence_type.Init(value.Bytes, value.Pos)
|
||||
elem_type = sequence_type.ElemType() # TypeInfo
|
||||
elem_type_str = FbsTypeInfo.typeinfo_to_str(elem_type)
|
||||
# TODO: Decide if we need to wrap the type in a std::vector. Issue is that the element type is internal
|
||||
# to the onnxruntime::Tensor class so we're really returning the type inside the Tensor not vector<Tensor>.
|
||||
# For now, return the element type (which will be the Tensor element type, or a map<A,B>) as
|
||||
# an operator input or output will either be a sequence or a not, so we don't need to disambiguate
|
||||
# between the two (i.e. we know if the returned value refers to the contents of a sequence, and can
|
||||
# handle whether it's the element type of a Tensor in the sequence, or the map type in a sequence of maps
|
||||
# due to this).
|
||||
type_str = elem_type_str
|
||||
else:
|
||||
raise ValueError(f"Unknown or missing value type of {value_type}")
|
||||
|
||||
return type_str
|
||||
|
||||
|
||||
def get_typeinfo(name: str, value_name_to_typeinfo: dict) -> fbs.TypeInfo:
|
||||
"Lookup a name in a dictionary mapping value name to TypeInfo."
|
||||
if name not in value_name_to_typeinfo:
|
||||
raise RuntimeError("Missing TypeInfo entry for " + name)
|
||||
|
||||
return value_name_to_typeinfo[name] # TypeInfo object
|
||||
|
||||
|
||||
def value_name_to_typestr(name: str, value_name_to_typeinfo: dict):
|
||||
"Lookup TypeInfo for value name and convert to a string representing the C++ type."
|
||||
type = get_typeinfo(name, value_name_to_typeinfo)
|
||||
type_str = FbsTypeInfo.typeinfo_to_str(type)
|
||||
return type_str
|
||||
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pathlib
|
||||
import typing
|
||||
|
||||
from ..logger import get_logger
|
||||
from .operator_type_usage_processors import OperatorTypeUsageManager
|
||||
from .ort_model_processor import OrtFormatModelProcessor
|
||||
|
||||
log = get_logger("ort_format_model.utils")
|
||||
|
||||
|
||||
def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib.Path], enable_type_reduction: bool):
|
||||
required_ops = {}
|
||||
op_type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None
|
||||
|
||||
for model_file in model_files:
|
||||
if not model_file.is_file():
|
||||
raise ValueError(f"Path is not a file: '{model_file}'")
|
||||
model_processor = OrtFormatModelProcessor(str(model_file), required_ops, op_type_usage_manager)
|
||||
model_processor.process() # this updates required_ops and op_type_processors
|
||||
|
||||
return required_ops, op_type_usage_manager
|
||||
|
||||
|
||||
def create_config_from_models(
|
||||
model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, enable_type_reduction: bool
|
||||
):
|
||||
"""
|
||||
Create a configuration file with required operators and optionally required types.
|
||||
:param model_files: Model files to use to generate the configuration file.
|
||||
:param output_file: File to write configuration to.
|
||||
:param enable_type_reduction: Include required type information for individual operators in the configuration.
|
||||
"""
|
||||
|
||||
required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction)
|
||||
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "w") as out:
|
||||
out.write("# Generated from model/s:\n")
|
||||
out.writelines(f"# - {model_file}\n" for model_file in sorted(model_files))
|
||||
|
||||
for domain in sorted(required_ops.keys()):
|
||||
for opset in sorted(required_ops[domain].keys()):
|
||||
ops = required_ops[domain][opset]
|
||||
if ops:
|
||||
out.write(f"{domain};{opset};")
|
||||
if enable_type_reduction:
|
||||
# type string is empty if op hasn't been seen
|
||||
entries = [
|
||||
"{}{}".format(op, op_type_processors.get_config_entry(domain, op) or "")
|
||||
for op in sorted(ops)
|
||||
]
|
||||
else:
|
||||
entries = sorted(ops)
|
||||
|
||||
out.write("{}\n".format(",".join(entries)))
|
||||
|
||||
log.info("Created config in %s", output_file)
|
||||
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Support for registering ONNX Runtime's built-in contrib ops with
|
||||
PyTorch-ONNX exporter (torch.onnx.export).
|
||||
"""
|
||||
|
||||
import typing
|
||||
|
||||
try:
|
||||
# TODO(justinchuby): Create a function to alert users when torch is not installed
|
||||
import torch
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError( # noqa: B904
|
||||
"This module is only useful in combination with PyTorch. To install PyTorch see https://pytorch.org/."
|
||||
)
|
||||
|
||||
from torch.onnx import symbolic_helper
|
||||
|
||||
_OPSET_VERSION = 1
|
||||
_registered_ops: typing.AbstractSet[str] = set()
|
||||
|
||||
|
||||
def _reg(symbolic_fn: typing.Callable, namespace: str = ""):
|
||||
name = f"{namespace}::{symbolic_fn.__name__}"
|
||||
torch.onnx.register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION)
|
||||
_registered_ops.add(name)
|
||||
|
||||
|
||||
def register():
|
||||
"""Register ONNX Runtime's built-in contrib ops.
|
||||
|
||||
Should be run before torch.onnx.export().
|
||||
"""
|
||||
|
||||
def grid_sampler(g, input, grid, mode, padding_mode, align_corners):
|
||||
# mode
|
||||
# 'bilinear' : onnx::Constant[value={0}]
|
||||
# 'nearest' : onnx::Constant[value={1}]
|
||||
# 'bicubic' : onnx::Constant[value={2}]
|
||||
# padding_mode
|
||||
# 'zeros' : onnx::Constant[value={0}]
|
||||
# 'border' : onnx::Constant[value={1}]
|
||||
# 'reflection' : onnx::Constant[value={2}]
|
||||
mode = symbolic_helper._maybe_get_const(mode, "i")
|
||||
padding_mode = symbolic_helper._maybe_get_const(padding_mode, "i")
|
||||
mode_str = ["bilinear", "nearest", "bicubic"][mode]
|
||||
padding_mode_str = ["zeros", "border", "reflection"][padding_mode]
|
||||
align_corners = int(symbolic_helper._maybe_get_const(align_corners, "b"))
|
||||
|
||||
# From opset v13 onward, the output shape can be specified with
|
||||
# (N, C, H, W) (N, H_out, W_out, 2) => (N, C, H_out, W_out)
|
||||
# input_shape = input.type().sizes()
|
||||
# gird_shape = grid.type().sizes()
|
||||
# output_shape = input_shape[:2] + gird_shape[1:3]
|
||||
# g.op(...).setType(input.type().with_sizes(output_shape))
|
||||
|
||||
return g.op(
|
||||
"com.microsoft::GridSample",
|
||||
input,
|
||||
grid,
|
||||
mode_s=mode_str,
|
||||
padding_mode_s=padding_mode_str,
|
||||
align_corners_i=align_corners,
|
||||
)
|
||||
|
||||
_reg(grid_sampler)
|
||||
|
||||
def inverse(g, self):
|
||||
return g.op("com.microsoft::Inverse", self).setType(self.type())
|
||||
|
||||
_reg(inverse)
|
||||
|
||||
@torch.onnx.symbolic_helper.parse_args("v", "s")
|
||||
def gelu(g, self: torch._C.Value, approximate: str = "none"):
|
||||
# Use microsoft::Gelu for performance if possible. It only supports approximate == "none"
|
||||
if approximate == "none":
|
||||
return g.op("com.microsoft::Gelu", self).setType(self.type())
|
||||
return torch.onnx.symbolic_opset9.gelu(g, self, approximate)
|
||||
|
||||
_reg(gelu)
|
||||
|
||||
def triu(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())
|
||||
|
||||
_reg(triu)
|
||||
|
||||
def tril(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())
|
||||
|
||||
_reg(tril)
|
||||
|
||||
@torch.onnx.symbolic_helper.parse_args("v")
|
||||
def DynamicTimeWarping(g, self): # noqa: N802
|
||||
return g.op("com.microsoft::DynamicTimeWarping", self)
|
||||
|
||||
_reg(DynamicTimeWarping, namespace="onnxruntime")
|
||||
|
||||
def UnfoldTensor(g, self, dim, size, step): # noqa: N802
|
||||
dim = int(symbolic_helper._maybe_get_const(dim, "i"))
|
||||
size = int(symbolic_helper._maybe_get_const(size, "i"))
|
||||
step = int(symbolic_helper._maybe_get_const(step, "i"))
|
||||
return g.op(
|
||||
"com.microsoft::UnfoldTensor",
|
||||
self,
|
||||
dim_i=dim,
|
||||
size_i=size,
|
||||
step_i=step,
|
||||
).setType(self.type().with_sizes([None, None, None, None, size]))
|
||||
|
||||
_reg(UnfoldTensor, namespace="onnxruntime")
|
||||
|
||||
|
||||
def unregister():
|
||||
"""Unregister ONNX Runtime's built-in contrib ops."""
|
||||
for name in _registered_ops:
|
||||
try:
|
||||
torch.onnx.unregister_custom_op_symbolic(name, _OPSET_VERSION)
|
||||
except AttributeError:
|
||||
# The symbolic_registry module was removed in PyTorch 1.13.
|
||||
# We are importing it here for backwards compatibility
|
||||
# because unregister_custom_op_symbolic is not available before PyTorch 1.12
|
||||
from torch.onnx import symbolic_registry # noqa: PLC0415
|
||||
|
||||
namespace, kind = name.split("::")
|
||||
for version in symbolic_helper._onnx_stable_opsets:
|
||||
if version >= _OPSET_VERSION and symbolic_registry.is_registered_op(kind, namespace, version):
|
||||
del symbolic_registry._registry[(namespace, version)][kind]
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
from collections import abc
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
|
||||
# extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433
|
||||
|
||||
def _add_input(name, input):
|
||||
"""Returns number of expanded inputs that _add_input processed"""
|
||||
|
||||
if input is None:
|
||||
# Drop all None inputs and return 0.
|
||||
return 0
|
||||
|
||||
num_expanded_non_none_inputs = 0
|
||||
if isinstance(input, abc.Sequence):
|
||||
# If the input is a sequence (like a list), expand the list so that
|
||||
# each element of the list is an input by itself.
|
||||
for i, val in enumerate(input):
|
||||
# Name each input with the index appended to the original name of the
|
||||
# argument.
|
||||
num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val)
|
||||
|
||||
# Return here since the list by itself is not a valid input.
|
||||
# All the elements of the list have already been added as inputs individually.
|
||||
return num_expanded_non_none_inputs
|
||||
elif isinstance(input, abc.Mapping):
|
||||
# If the input is a mapping (like a dict), expand the dict so that
|
||||
# each element of the dict is an input by itself.
|
||||
for key, val in input.items():
|
||||
num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val)
|
||||
|
||||
# Return here since the dict by itself is not a valid input.
|
||||
# All the elements of the dict have already been added as inputs individually.
|
||||
return num_expanded_non_none_inputs
|
||||
|
||||
# InputInfo should contain all the names irrespective of whether they are
|
||||
# a part of the onnx graph or not.
|
||||
input_names.append(name)
|
||||
|
||||
# A single input non none input was processed, return 1
|
||||
return 1
|
||||
|
||||
input_names = []
|
||||
var_positional_idx = 0
|
||||
num_expanded_non_none_positional_inputs = 0
|
||||
|
||||
for input_idx, input_parameter in enumerate(all_input_parameters):
|
||||
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
# VAR_POSITIONAL parameter carries all *args parameters from original forward method
|
||||
for args_i in range(input_idx, len(inputs)):
|
||||
name = f"{input_parameter.name}_{var_positional_idx}"
|
||||
var_positional_idx += 1
|
||||
inp = inputs[args_i]
|
||||
num_expanded_non_none_positional_inputs += _add_input(name, inp)
|
||||
elif (
|
||||
input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
|
||||
or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
):
|
||||
# All positional non-*args and non-**kwargs are processed here
|
||||
name = input_parameter.name
|
||||
inp = None
|
||||
input_idx += var_positional_idx # noqa: PLW2901
|
||||
is_positional = True
|
||||
if input_idx < len(inputs) and inputs[input_idx] is not None:
|
||||
inp = inputs[input_idx]
|
||||
elif name in kwargs and kwargs[name] is not None:
|
||||
inp = kwargs[name]
|
||||
is_positional = False
|
||||
num_expanded_non_none_inputs_local = _add_input(name, inp)
|
||||
if is_positional:
|
||||
num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
|
||||
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
# **kwargs is always the last argument of forward()
|
||||
for name, inp in kwargs.items():
|
||||
if name not in input_names:
|
||||
_add_input(name, inp)
|
||||
|
||||
return input_names
|
||||
|
||||
|
||||
def _flatten_module_input(names, args, kwargs):
|
||||
"""Flatten args and kwargs in a single tuple of tensors."""
|
||||
# extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110
|
||||
|
||||
def is_primitive_type(value):
|
||||
return type(value) in {int, bool, float}
|
||||
|
||||
def to_tensor(value):
|
||||
return torch.tensor(value)
|
||||
|
||||
ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args]
|
||||
ret += [
|
||||
to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs
|
||||
]
|
||||
|
||||
# if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
|
||||
# happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
|
||||
if not kwargs:
|
||||
ret.append({})
|
||||
|
||||
return tuple(ret)
|
||||
|
||||
|
||||
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
|
||||
"""
|
||||
Infer the input names and order from the arguments used to execute a PyTorch module for usage exporting
|
||||
the model via torch.onnx.export.
|
||||
Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.
|
||||
|
||||
Example usage:
|
||||
input_names, inputs_as_tuple = infer_input_info(module, ...)
|
||||
torch.onnx.export(module, inputs_as_type, 'model.onnx', input_names=input_names, output_names=[...], ...)
|
||||
|
||||
:param module: Module
|
||||
:param inputs: Positional inputs
|
||||
:param kwargs: Keyword argument inputs
|
||||
:return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
|
||||
`input_names` and `inputs` arguments.
|
||||
"""
|
||||
module_parameters = inspect.signature(module.forward).parameters.values()
|
||||
input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs)
|
||||
inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)
|
||||
|
||||
return input_names, inputs_as_tuple
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def optimize_qdq_model():
|
||||
parser = argparse.ArgumentParser(
|
||||
os.path.basename(__file__),
|
||||
description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
|
||||
)
|
||||
|
||||
parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
|
||||
parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model = onnx.load(str(args.input_model.resolve(strict=True)))
|
||||
|
||||
# run QDQ model optimizations here
|
||||
|
||||
# Originally, the fixing up of DQ nodes with multiple consumers was implemented as one such optimization.
|
||||
# That was moved to an ORT graph transformer.
|
||||
print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")
|
||||
|
||||
# There are no optimizations being run currently but we expect that there may be in the future.
|
||||
|
||||
onnx.save(model, str(args.output_model.resolve()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
optimize_qdq_model()
|
||||
@@ -0,0 +1,292 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
|
||||
def graph_topological_sort(graph):
|
||||
deps_count = [0] * len(graph.node) # dependency count of each node
|
||||
deps_to_nodes = {} # input to node indice
|
||||
sorted_nodes = [] # initialize sorted_nodes
|
||||
for node_idx, node in enumerate(graph.node):
|
||||
# CANNOT use len(node.input) directly because input can be optional
|
||||
deps_count[node_idx] = sum(1 for _ in node.input if _)
|
||||
if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
continue
|
||||
|
||||
for input_name in node.input:
|
||||
if input_name not in deps_to_nodes:
|
||||
deps_to_nodes[input_name] = [node_idx]
|
||||
else:
|
||||
deps_to_nodes[input_name].append(node_idx)
|
||||
|
||||
# Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
|
||||
initializer_names = [init.name for init in graph.initializer]
|
||||
graph_input_names = [input.name for input in graph.input]
|
||||
input_names = initializer_names + graph_input_names
|
||||
input_names.sort()
|
||||
prev_input_name = None
|
||||
for input_name in input_names:
|
||||
if prev_input_name == input_name:
|
||||
continue
|
||||
|
||||
prev_input_name = input_name
|
||||
if input_name in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[input_name]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
|
||||
start = 0
|
||||
end = len(sorted_nodes)
|
||||
|
||||
while start < end:
|
||||
for output in sorted_nodes[start].output:
|
||||
if output in deps_to_nodes:
|
||||
for node_idx in deps_to_nodes[output]:
|
||||
deps_count[node_idx] = deps_count[node_idx] - 1
|
||||
if deps_count[node_idx] == 0:
|
||||
sorted_nodes.append(graph.node[node_idx])
|
||||
end = end + 1
|
||||
start = start + 1
|
||||
|
||||
assert end == len(graph.node), "Graph is not a DAG"
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(sorted_nodes)
|
||||
|
||||
|
||||
class QnnTensorStruct:
|
||||
def __init__(self):
|
||||
self.name = ""
|
||||
self.onnx_data_type = TensorProto.FLOAT
|
||||
self.dim = []
|
||||
|
||||
|
||||
def qnn_data_type_to_onnx_data_type(qnn_data_type):
|
||||
# QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8
|
||||
if qnn_data_type == 0x0408 or qnn_data_type == 0x0108:
|
||||
return TensorProto.UINT8
|
||||
# QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16
|
||||
elif qnn_data_type == 0x0416 or qnn_data_type == 0x0116:
|
||||
return TensorProto.UINT16
|
||||
# QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32
|
||||
elif qnn_data_type == 0x0432 or qnn_data_type == 0x0132:
|
||||
return TensorProto.UINT32
|
||||
# QNN_DATATYPE_UINT_64
|
||||
elif qnn_data_type == 0x0164:
|
||||
return TensorProto.UINT64
|
||||
# QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8
|
||||
elif qnn_data_type == 0x0308 or qnn_data_type == 0x0008:
|
||||
return TensorProto.INT8
|
||||
# QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16
|
||||
elif qnn_data_type == 0x0316 or qnn_data_type == 0x0016:
|
||||
return TensorProto.INT16
|
||||
# QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32
|
||||
elif qnn_data_type == 0x0332 or qnn_data_type == 0x0032:
|
||||
return TensorProto.INT32
|
||||
# QNN_DATATYPE_INT_64
|
||||
elif qnn_data_type == 0x0064:
|
||||
return TensorProto.INT64
|
||||
# QNN_DATATYPE_FLOAT_16
|
||||
elif qnn_data_type == 0x0216:
|
||||
return TensorProto.FLOAT16
|
||||
# QNN_DATATYPE_FLOAT_32
|
||||
elif qnn_data_type == 0x0232:
|
||||
return TensorProto.FLOAT
|
||||
# QNN_DATATYPE_BOOL_8
|
||||
elif qnn_data_type == 0x0508:
|
||||
return TensorProto.BOOL
|
||||
else:
|
||||
return TensorProto.UNDEFINED
|
||||
|
||||
|
||||
def parse_qnn_json_file(qnn_json_file_path, qnn_input_output_tensor_dic):
|
||||
with open(qnn_json_file_path) as qnn_json_file:
|
||||
qnn_json = json.load(qnn_json_file)
|
||||
assert "graph" in qnn_json, "QNN converted json file not valid. Can't find graph."
|
||||
assert "tensors" in qnn_json["graph"], "QNN converted json file not valid. Can't find tensors."
|
||||
for qnn_tensor_name, qnn_tensor_attribute in qnn_json["graph"]["tensors"].items():
|
||||
# type:0 - QNN input tensor, type:1 - QNN output tensor
|
||||
assert (
|
||||
"type" in qnn_tensor_attribute
|
||||
and "data_type" in qnn_tensor_attribute
|
||||
and "dims" in qnn_tensor_attribute
|
||||
), "QNN converted json file not valid. Can't find some keys from tensors"
|
||||
if qnn_tensor_attribute["type"] == 0 or qnn_tensor_attribute["type"] == 1:
|
||||
qnn_tensor = QnnTensorStruct()
|
||||
qnn_tensor.name = qnn_tensor_name
|
||||
qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"])
|
||||
qnn_tensor.dim = qnn_tensor_attribute["dims"]
|
||||
qnn_input_output_tensor_dic[qnn_tensor_name] = qnn_tensor
|
||||
|
||||
assert len(qnn_input_output_tensor_dic) > 1, (
|
||||
"Converted QNN model not valid. It should have at least 1 input & 1 output."
|
||||
)
|
||||
|
||||
|
||||
def compare_onnx_shape_with_qnn_shape(onnx_dims, qnn_dims):
|
||||
assert len(onnx_dims) == len(qnn_dims), "Onnx shape and Qnn shape has different rank."
|
||||
return all(onnx_dims[i].dim_value == qnn_dims[i] for i in range(len(onnx_dims)))
|
||||
|
||||
|
||||
def gen_to_channel_first_perm(rank):
|
||||
assert rank > 2, "Shape rank should >2 for the Transpose node."
|
||||
perm = []
|
||||
perm.append(0)
|
||||
perm.append(rank - 1)
|
||||
for i in range(1, rank - 1):
|
||||
perm.append(i) # noqa: PERF402
|
||||
|
||||
return perm
|
||||
|
||||
|
||||
def gen_to_channel_last_perm(rank):
|
||||
assert rank > 2, "Shape rank should >2 for the Transpose node."
|
||||
perm = []
|
||||
perm.append(0)
|
||||
for i in range(2, rank):
|
||||
perm.append(i) # noqa: PERF402
|
||||
perm.append(1)
|
||||
|
||||
return perm
|
||||
|
||||
|
||||
# Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file
# uses channel last data layout and 8 bits or 16 bits for input and output.
# This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model
# and inserts Cast, Transpose nodes to Onnx model if required
def main():
    """Align an Onnx model's graph I/O with a QNN-generated context binary.

    Reads the QNN converted model_net.json to get the QNN graph's I/O data
    types and (channel-last) shapes, then rewrites the Onnx model: Cast nodes
    are inserted where the data types differ and Transpose nodes where the
    shapes differ, graph I/O declarations are changed to the QNN types/shapes,
    and the result is saved next to the input model with an "_add_trans" suffix.
    """
    parser = ArgumentParser(
        "Insert Cast, Transpose nodes into Onnx model to make it aligned with QNN generated context binary."
    )
    parser.add_argument("-m", "--onnx_model", help="Required. Path to Onnx model file.", required=True, type=str)
    parser.add_argument(
        "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
    )
    args = parser.parse_args()

    # Parse Qnn model_net.json file to get the graph input output information
    qnn_input_output_tensor_dic = {}
    parse_qnn_json_file(args.qnn_json, qnn_input_output_tensor_dic)

    model = onnx.load(args.onnx_model)

    nodes_to_add = []
    # Track the tensor name changes so consumer/producer nodes can be re-wired below
    graph_input_output_name_dic = {}
    for graph_input in model.graph.input:
        if graph_input.name in qnn_input_output_tensor_dic:
            input_name_fater_node_insert = graph_input.name
            qnn_input_tensor = qnn_input_output_tensor_dic[graph_input.name]
            # Insert Cast node if Onnx input and Qnn input has different data type
            if graph_input.type.tensor_type.elem_type != qnn_input_tensor.onnx_data_type:
                # Insert Cast node. The 'to' type is captured BEFORE elem_type is
                # reassigned: the cast converts the re-typed graph input back to
                # the type the original Onnx graph expects.
                cast_input_name = input_name_fater_node_insert
                cast_output_name = cast_input_name + "_qnn_cast"
                input_cast_node = helper.make_node(
                    "Cast",
                    name=cast_output_name,
                    inputs=[cast_input_name],
                    outputs=[cast_output_name],
                    to=graph_input.type.tensor_type.elem_type,
                )
                # Change input data type to Qnn input data type
                graph_input.type.tensor_type.elem_type = qnn_input_tensor.onnx_data_type
                nodes_to_add.extend([input_cast_node])
                input_name_fater_node_insert = cast_output_name
                graph_input_output_name_dic[graph_input.name] = cast_output_name

            if not compare_onnx_shape_with_qnn_shape(graph_input.type.tensor_type.shape.dim, qnn_input_tensor.dim):
                # Add Transpose node (channel last to channel first)
                transpose_perm = gen_to_channel_first_perm(len(graph_input.type.tensor_type.shape.dim))
                transpose_input_name = input_name_fater_node_insert
                transpose_output_name = transpose_input_name + "_qnn_trans"
                input_transpose_node = helper.make_node(
                    "Transpose",
                    name=transpose_output_name,
                    inputs=[transpose_input_name],
                    outputs=[transpose_output_name],
                    perm=transpose_perm,
                )
                nodes_to_add.extend([input_transpose_node])
                graph_input_output_name_dic[graph_input.name] = transpose_output_name

                # Change input shape to Qnn input shape
                for i in range(len(graph_input.type.tensor_type.shape.dim)):
                    graph_input.type.tensor_type.shape.dim[i].dim_value = qnn_input_tensor.dim[i]
        else:
            raise AssertionError("Error: Onnx model input: " + graph_input.name + " not exist from QNN model input.")

    for graph_output in model.graph.output:
        if graph_output.name in qnn_input_output_tensor_dic:
            output_name_after_node_insert = graph_output.name
            # Insert Cast node if Onnx output and Qnn output has different data type
            qnn_output_tensor = qnn_input_output_tensor_dic[graph_output.name]
            if graph_output.type.tensor_type.elem_type != qnn_output_tensor.onnx_data_type:
                # Insert Cast node. For outputs the cast sits BEFORE the graph
                # output, so the inserted node's output keeps the original name.
                cast_output_name = output_name_after_node_insert
                cast_input_name = cast_output_name + "_qnn_cast"
                output_cast_node = helper.make_node(
                    "Cast",
                    name=cast_input_name,
                    inputs=[cast_input_name],
                    outputs=[cast_output_name],
                    to=qnn_output_tensor.onnx_data_type,
                )
                # Change output data type to Qnn output data type
                graph_output.type.tensor_type.elem_type = qnn_output_tensor.onnx_data_type
                nodes_to_add.extend([output_cast_node])
                output_name_after_node_insert = cast_input_name
                graph_input_output_name_dic[graph_output.name] = cast_input_name

            if not compare_onnx_shape_with_qnn_shape(graph_output.type.tensor_type.shape.dim, qnn_output_tensor.dim):
                # Add Transpose node (channel first to channel last)
                transpose_perm = gen_to_channel_last_perm(len(graph_output.type.tensor_type.shape.dim))
                transpose_output_name = output_name_after_node_insert
                transpose_input_name = transpose_output_name + "_qnn_trans"
                output_transpose_node = helper.make_node(
                    "Transpose",
                    name=transpose_input_name,
                    inputs=[transpose_input_name],
                    outputs=[transpose_output_name],
                    perm=transpose_perm,
                )
                nodes_to_add.extend([output_transpose_node])
                graph_input_output_name_dic[graph_output.name] = transpose_input_name

                # Change output shape to Qnn output shape
                for i in range(len(graph_output.type.tensor_type.shape.dim)):
                    graph_output.type.tensor_type.shape.dim[i].dim_value = qnn_input_output_tensor_dic[
                        graph_output.name
                    ].dim[i]
        else:
            raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.")

    for node in model.graph.node:
        for node_input_index, node_input in enumerate(node.input):
            # update consumer node for graph inputs to connect to inserted node
            if node_input in graph_input_output_name_dic:
                node.input[node_input_index] = graph_input_output_name_dic[node_input]

        for node_output_index, node_output in enumerate(node.output):
            # update producer node for graph outputs to connect to inserted node
            if node_output in graph_input_output_name_dic:
                node.output[node_output_index] = graph_input_output_name_dic[node_output]

    model.graph.node.extend(nodes_to_add)
    # Inserted nodes were appended at the end; re-sort so the graph stays valid.
    graph_topological_sort(model.graph)

    # Add extra parameter all_tensors_to_one_file=False, size_threshold=5000 if the model exceeds protobuf 2GB limit e.g below
    # onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"), all_tensors_to_one_file=False, size_threshold=5000)
    onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"))


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,364 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
|
||||
class QnnTensorStruct:
    """Plain container describing one QNN graph input/output tensor.

    Attributes:
        name: Tensor name as it appears in the QNN json.
        onnx_data_type: Equivalent onnx TensorProto data type.
        is_quantized: Whether the QNN type is a (u)fixed-point type.
        scale: Quantization scale (meaningful only when is_quantized).
        offset: Quantization zero-point (the parsers store the negated QNN offset).
        dim: Tensor shape as a list of ints.
        id: Tensor id used to order graph inputs/outputs.
    """

    def __init__(
        self, name="", onnx_data_type=TensorProto.FLOAT, is_quantized=False, scale=0.0, offset=0, dim=None, id=None
    ):
        # 'id' shadows the builtin but is part of the public signature; kept as-is.
        self.name = name
        self.id = id
        self.onnx_data_type = onnx_data_type
        self.is_quantized = is_quantized
        self.scale = scale
        self.offset = offset
        # Avoid the shared-mutable-default trap: each instance gets its own list.
        self.dim = dim if dim is not None else []
|
||||
|
||||
|
||||
def is_quantized_data_type(qnn_data_type, is_converter_json):
    """Return True if the given QNN data type is a (u)fixed-point (quantized) type.

    Args:
        qnn_data_type: Numeric QNN type code (converter json) or a
            "QNN_DATATYPE_*" string (json extracted from a context binary).
        is_converter_json: True when the value comes from the QNN converter
            json (numeric codes), False for the context-binary json (strings).
    """
    if is_converter_json:
        # QNN_DATATYPE_UFIXED_POINT_8/16 and QNN_DATATYPE_FIXED_POINT_8/16.
        # Set membership replaces the original chained == comparisons.
        return qnn_data_type in {0x0408, 0x0416, 0x0308, 0x0316}
    return qnn_data_type in {
        "QNN_DATATYPE_UFIXED_POINT_8",
        "QNN_DATATYPE_UFIXED_POINT_16",
        "QNN_DATATYPE_FIXED_POINT_8",
        "QNN_DATATYPE_FIXED_POINT_16",
    }
|
||||
|
||||
|
||||
# The QNN converter json encodes tensor data types as numeric codes, while the
# json extracted from a QNN context binary uses "QNN_DATATYPE_*" string names.
# Both tables map each (u)fixed-point type and its plain integer counterpart to
# the same onnx data type.
_QNN_NUMERIC_TYPE_TO_ONNX = {
    0x0408: TensorProto.UINT8,  # QNN_DATATYPE_UFIXED_POINT_8
    0x0108: TensorProto.UINT8,  # QNN_DATATYPE_UINT_8
    0x0416: TensorProto.UINT16,  # QNN_DATATYPE_UFIXED_POINT_16
    0x0116: TensorProto.UINT16,  # QNN_DATATYPE_UINT_16
    0x0432: TensorProto.UINT32,  # QNN_DATATYPE_UFIXED_POINT_32
    0x0132: TensorProto.UINT32,  # QNN_DATATYPE_UINT_32
    0x0164: TensorProto.UINT64,  # QNN_DATATYPE_UINT_64
    0x0308: TensorProto.INT8,  # QNN_DATATYPE_FIXED_POINT_8
    0x0008: TensorProto.INT8,  # QNN_DATATYPE_INT_8
    0x0316: TensorProto.INT16,  # QNN_DATATYPE_FIXED_POINT_16
    0x0016: TensorProto.INT16,  # QNN_DATATYPE_INT_16
    0x0332: TensorProto.INT32,  # QNN_DATATYPE_FIXED_POINT_32
    0x0032: TensorProto.INT32,  # QNN_DATATYPE_INT_32
    0x0064: TensorProto.INT64,  # QNN_DATATYPE_INT_64
    0x0216: TensorProto.FLOAT16,  # QNN_DATATYPE_FLOAT_16
    0x0232: TensorProto.FLOAT,  # QNN_DATATYPE_FLOAT_32
    0x0508: TensorProto.BOOL,  # QNN_DATATYPE_BOOL_8
}

_QNN_STR_TYPE_TO_ONNX = {
    "QNN_DATATYPE_UFIXED_POINT_8": TensorProto.UINT8,
    "QNN_DATATYPE_UINT_8": TensorProto.UINT8,
    "QNN_DATATYPE_UFIXED_POINT_16": TensorProto.UINT16,
    "QNN_DATATYPE_UINT_16": TensorProto.UINT16,
    "QNN_DATATYPE_UFIXED_POINT_32": TensorProto.UINT32,
    "QNN_DATATYPE_UINT_32": TensorProto.UINT32,
    "QNN_DATATYPE_UINT_64": TensorProto.UINT64,
    "QNN_DATATYPE_FIXED_POINT_8": TensorProto.INT8,
    "QNN_DATATYPE_INT_8": TensorProto.INT8,
    "QNN_DATATYPE_FIXED_POINT_16": TensorProto.INT16,
    "QNN_DATATYPE_INT_16": TensorProto.INT16,
    "QNN_DATATYPE_FIXED_POINT_32": TensorProto.INT32,
    "QNN_DATATYPE_INT_32": TensorProto.INT32,
    "QNN_DATATYPE_INT_64": TensorProto.INT64,
    "QNN_DATATYPE_FLOAT_16": TensorProto.FLOAT16,
    "QNN_DATATYPE_FLOAT_32": TensorProto.FLOAT,
    "QNN_DATATYPE_BOOL_8": TensorProto.BOOL,
}


def qnn_data_type_to_onnx_data_type(qnn_data_type, is_converter_json):
    """Map a QNN data type to the equivalent onnx TensorProto data type.

    Replaces the original ~200-line duplicated if/elif ladder with two data
    tables; the mapping itself is unchanged.

    Args:
        qnn_data_type: Numeric QNN type code (converter json) or a
            "QNN_DATATYPE_*" string (context-binary json).
        is_converter_json: Selects which encoding `qnn_data_type` uses.

    Returns:
        The matching TensorProto data type, or TensorProto.UNDEFINED for
        unrecognized values.
    """
    table = _QNN_NUMERIC_TYPE_TO_ONNX if is_converter_json else _QNN_STR_TYPE_TO_ONNX
    return table.get(qnn_data_type, TensorProto.UNDEFINED)
|
||||
|
||||
|
||||
def parse_qnn_converter_json_file(qnn_convert_json, qnn_input_tensor_dic, qnn_output_tensor_dic):
    """Extract graph input/output tensors from a QNN converter model_net.json.

    Populates qnn_input_tensor_dic / qnn_output_tensor_dic in place
    (tensor name -> QnnTensorStruct). Tensors that are not graph I/O are
    skipped. Raises AssertionError if the json is missing expected keys or the
    graph has no input or no output.
    """
    # This json flavor encodes data types as numeric codes (see
    # qnn_data_type_to_onnx_data_type).
    is_qnn_converter_json = True
    for qnn_tensor_name, qnn_tensor_attribute in qnn_convert_json["graph"]["tensors"].items():
        # type:0 - QNN input tensor, type:1 - QNN output tensor
        assert (
            "type" in qnn_tensor_attribute
            and "data_type" in qnn_tensor_attribute
            and "dims" in qnn_tensor_attribute
            and "id" in qnn_tensor_attribute
            and "quant_params" in qnn_tensor_attribute
        ), "QNN converted json file not valid. Can't find some keys from tensors"

        # If tensor is not IO, ignore it
        if qnn_tensor_attribute["type"] not in [0, 1]:
            continue

        # Get all graph inputs & output
        qnn_tensor = QnnTensorStruct(
            name=qnn_tensor_name,
            onnx_data_type=qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json),
            is_quantized=is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json),
            dim=qnn_tensor_attribute["dims"],
            id=qnn_tensor_attribute["id"],
        )

        # NOTE(review): definition == 1 / encoding == 0 presumably correspond to
        # QNN_DEFINITION_DEFINED with per-tensor scale-offset encoding (the
        # numeric form of the string constants handled in parse_qnn_graph) —
        # TODO confirm against QNN documentation.
        if (
            qnn_tensor_attribute["quant_params"]["definition"] == 1
            and qnn_tensor_attribute["quant_params"]["encoding"] == 0
        ):
            qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"]
            # Sign of the QNN offset is flipped here, matching parse_qnn_graph.
            qnn_tensor.offset = -qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"]

        if qnn_tensor_attribute["type"] == 0:
            qnn_input_tensor_dic[qnn_tensor_name] = qnn_tensor
        else:
            qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor

    assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
        "Converted QNN model not valid. It should have at least 1 input & 1 output."
    )
|
||||
|
||||
|
||||
def generate_wrapper_onnx_file(
    grap_name,
    model_file_name,
    qnn_input_tensor_dic,
    qnn_output_tensor_dic,
    disable_embed_mode,
    qnn_ctx_file,
    quantized_IO,
    qnn_sdk_version="unknown",
):
    """Generate an Onnx model that wraps a QNN context binary in one EPContext node.

    Args:
        grap_name: Name for the EPContext node (the QNN graph name).
            NOTE(review): 'grap_name' looks like a typo of 'graph_name'; kept
            unchanged so keyword callers keep working.
        model_file_name: Path where the generated Onnx model is saved.
        qnn_input_tensor_dic: name -> QnnTensorStruct for QNN graph inputs.
        qnn_output_tensor_dic: name -> QnnTensorStruct for QNN graph outputs.
        disable_embed_mode: If True, store the context binary file PATH in the
            model (embed_mode=0); otherwise embed the binary CONTENT (embed_mode=1).
        qnn_ctx_file: Path to the QNN context binary file.
        quantized_IO: If True, expose the QNN quantized types directly as graph
            I/O; otherwise wrap quantized I/O with QuantizeLinear/DequantizeLinear
            so the model exposes float32 I/O.
        qnn_sdk_version: QNN SDK version string recorded on the EPContext node.
    """
    graph_nodes = []
    ini_list = []
    value_infos = []

    # Graph inputs, ordered by QNN tensor id. A quantized input (when
    # quantized_IO is False) is replaced by a float32 surrogate input feeding a
    # QuantizeLinear whose output keeps the original QNN tensor name.
    model_inputs = []
    for qnn_input in sorted(qnn_input_tensor_dic.values(), key=lambda inp: inp.id):
        if qnn_input.is_quantized and not quantized_IO:
            q_scale_input_name = qnn_input.name + "_scale"
            q_offset_input_name = qnn_input.name + "_zp"
            q_scale = helper.make_tensor(q_scale_input_name, TensorProto.FLOAT, [], [qnn_input.scale])
            ini_list.append(q_scale)
            q_offset = helper.make_tensor(q_offset_input_name, qnn_input.onnx_data_type, [], [qnn_input.offset])
            ini_list.append(q_offset)
            input_name = qnn_input.name + "_dq"

            q_node = helper.make_node(
                "QuantizeLinear",
                name=qnn_input.name,
                inputs=[input_name, q_scale_input_name, q_offset_input_name],
                outputs=[qnn_input.name],
            )

            graph_nodes.append(q_node)
            model_inputs.append(helper.make_tensor_value_info(input_name, TensorProto.FLOAT, qnn_input.dim))
            value_infos.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))
        else:
            model_inputs.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))

    # embed_mode=0 stores only the file path; embed_mode=1 embeds the raw bytes.
    if disable_embed_mode:
        ep_cache_context_content = qnn_ctx_file
        ctx_embed_mode = 0
    else:
        with open(qnn_ctx_file, "rb") as file:
            ep_cache_context_content = file.read()
        ctx_embed_mode = 1

    qnn_ep_context_node = helper.make_node(
        "EPContext",
        name=grap_name,
        inputs=qnn_input_tensor_dic.keys(),
        outputs=qnn_output_tensor_dic.keys(),
        ep_cache_context=ep_cache_context_content,
        embed_mode=ctx_embed_mode,
        ep_sdk_version=qnn_sdk_version,
        source="Qnn",
        domain="com.microsoft",
    )
    graph_nodes.append(qnn_ep_context_node)

    # Graph outputs, ordered by QNN tensor id. A quantized output (when
    # quantized_IO is False) is routed through a DequantizeLinear so the model
    # output is float32.
    model_outputs = []
    for qnn_output in sorted(qnn_output_tensor_dic.values(), key=lambda out: out.id):
        if qnn_output.is_quantized and not quantized_IO:
            dq_scale_input_name = qnn_output.name + "_scale"
            dq_offset_input_name = qnn_output.name + "_zp"
            dq_scale = helper.make_tensor(dq_scale_input_name, TensorProto.FLOAT, [], [qnn_output.scale])
            ini_list.append(dq_scale)
            dq_offset = helper.make_tensor(dq_offset_input_name, qnn_output.onnx_data_type, [], [qnn_output.offset])
            ini_list.append(dq_offset)
            output_name = qnn_output.name + "_dq"

            dq_node = helper.make_node(
                "DequantizeLinear",
                name=output_name,
                inputs=[qnn_output.name, dq_scale_input_name, dq_offset_input_name],
                outputs=[output_name],
            )

            graph_nodes.append(dq_node)
            model_outputs.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, qnn_output.dim))
            value_infos.append(
                helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
            )
        else:
            model_outputs.append(
                helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
            )

    graph_def = helper.make_graph(graph_nodes, "qnn-onnx-model", model_inputs, model_outputs, ini_list, "", value_infos)

    model_def = helper.make_model(graph_def, producer_name="MS")

    onnx.save(model_def, model_file_name)
|
||||
|
||||
|
||||
def _parse_ctx_tensor_info(tensor_info):
    """Build a QnnTensorStruct from one tensor 'info' entry of a context-binary json graph."""
    is_qnn_converter_json = False
    qnn_tensor = QnnTensorStruct()
    qnn_tensor.name = tensor_info["name"]
    qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(tensor_info["dataType"], is_qnn_converter_json)
    qnn_tensor.is_quantized = is_quantized_data_type(tensor_info["dataType"], is_qnn_converter_json)
    qnn_tensor.dim = tensor_info["dimensions"]
    if (
        tensor_info["quantizeParams"]["definition"] == "QNN_DEFINITION_DEFINED"
        and tensor_info["quantizeParams"]["quantizationEncoding"] == "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET"
    ):
        qnn_tensor.scale = tensor_info["quantizeParams"]["scaleOffset"]["scale"]
        # Sign of the QNN offset is flipped, matching parse_qnn_converter_json_file.
        qnn_tensor.offset = 0 - tensor_info["quantizeParams"]["scaleOffset"]["offset"]
    return qnn_tensor


# parse Qnn graph from the json file that extracted from context binary file
def parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic):
    """Parse one graph entry of the json extracted from a QNN context binary.

    Populates qnn_input_tensor_dic / qnn_output_tensor_dic in place
    (tensor name -> QnnTensorStruct) and returns the graph name.
    The previously duplicated input/output parsing now shares
    _parse_ctx_tensor_info.
    """
    graph_name = qnn_graph["info"]["graphName"]

    for tensor_id, raw_input in enumerate(qnn_graph["info"]["graphInputs"]):
        qnn_tensor = _parse_ctx_tensor_info(raw_input["info"])
        # The context-binary json carries no explicit tensor id; use json order
        # so generate_wrapper_onnx_file's sort-by-id is well defined. The
        # original left id=None, and sorting None ids raises TypeError for
        # graphs with more than one input/output.
        qnn_tensor.id = tensor_id
        qnn_input_tensor_dic[qnn_tensor.name] = qnn_tensor

    for tensor_id, raw_output in enumerate(qnn_graph["info"]["graphOutputs"]):
        qnn_tensor = _parse_ctx_tensor_info(raw_output["info"])
        qnn_tensor.id = tensor_id
        qnn_output_tensor_dic[qnn_tensor.name] = qnn_tensor

    assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
        "Converted QNN model not valid. It should have at least 1 input & 1 output."
    )

    return graph_name
|
||||
|
||||
|
||||
# Onnxruntime QNN EP can consume a context binary file generated by the QNN tool
# chain. This script reads the QNN graph input & output information from either
# the QNN converter model_net.json or the json extracted from a QNN context
# binary, and generates a wrapper Onnx model holding an EPContext node that
# embeds (or points to) that context binary.
def main():
    """Generate wrapper Onnx model(s) around a QNN context binary.

    Two json flavors are supported: the QNN converter model_net.json
    (keys graph/tensors) and the json extracted from a context binary
    (keys info/graphs, possibly holding several graphs — one wrapper model
    is generated per graph).
    """
    parser = ArgumentParser("Generate Onnx model which includes the QNN context binary.")
    parser.add_argument("-b", "--qnn_bin", help="Required. Path to Qnn context binary file.", required=True, type=str)
    parser.add_argument(
        "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
    )
    # Original help text described embed_mode=1 for this flag, but the flag
    # actually selects embed_mode=0 (path-only) — fixed to match the code.
    parser.add_argument(
        "--disable_embed_mode",
        action="store_true",
        default=False,
        help="Set embed_mode=0, which means the Qnn context binary file path (instead of its content) is stored in the onnx model. Otherwise the context binary content is embedded into the onnx model (embed_mode=1).",
    )
    parser.add_argument(
        "--quantized_IO",
        action="store_true",
        default=False,
        help="QNN converted context binary use quantized data as graph inputs and outputs. Will keep it if quantized_IO=True, otherwise, will insert Q and DQ nodes accordingly to make the graph inputs & outputs as float32 data type.",
    )
    args = parser.parse_args()

    # Parse Qnn model_net.json file to get the graph input output information

    with open(args.qnn_json) as qnn_json_file:
        qnn_json_obj = json.load(qnn_json_file)
        if "graph" in qnn_json_obj and "tensors" in qnn_json_obj["graph"]:
            print("This json file is from Qnn converter")
            qnn_input_tensor_dic = {}
            qnn_output_tensor_dic = {}
            parse_qnn_converter_json_file(qnn_json_obj, qnn_input_tensor_dic, qnn_output_tensor_dic)

            generate_wrapper_onnx_file(
                "QnnContext",
                args.qnn_json.replace(".json", "_qnn_ctx.onnx"),
                qnn_input_tensor_dic,
                qnn_output_tensor_dic,
                args.disable_embed_mode,
                args.qnn_bin,
                args.quantized_IO,
            )
        elif "info" in qnn_json_obj and "graphs" in qnn_json_obj["info"]:
            print("This json file is extracted from QNN context binary file")
            qnn_version = qnn_json_obj["info"]["buildId"]
            # A context binary may contain several graphs; emit one wrapper
            # model per graph.
            for qnn_graph in qnn_json_obj["info"]["graphs"]:
                qnn_input_tensor_dic = {}
                qnn_output_tensor_dic = {}
                graph_name = parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic)

                ctx_file_name = graph_name + "_qnn_ctx.onnx"
                if not args.quantized_IO:
                    ctx_file_name = ctx_file_name.replace(".onnx", "_fp32_io.onnx")

                generate_wrapper_onnx_file(
                    graph_name,
                    ctx_file_name,
                    qnn_input_tensor_dic,
                    qnn_output_tensor_dic,
                    args.disable_embed_mode,
                    args.qnn_bin,
                    args.quantized_IO,
                    qnn_version,
                )
        else:
            # Typo fix: was "unrecoginized".
            print("json file unrecognized.")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,165 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
"""Provide entry point to preprocess ONNX model especially for QNN."""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import onnx
|
||||
|
||||
from onnxruntime.quantization.execution_providers import qnn
|
||||
|
||||
|
||||
def _parse_arguments():
    """Parse cmdline arguments for QNN model preprocessing.

    Returns:
        argparse.Namespace holding all preprocessing options.
    """
    parser = argparse.ArgumentParser(description="Arguments for QNN model preprocess.")

    parser.add_argument("--input_model_path", "-i", required=True, help="Path to the input ONNX model.")
    parser.add_argument("--output_model_path", "-o", required=True, help="Path to the output ONNX model.")

    # Save preprocessed model with external data.
    parser.add_argument(
        "--save_as_external_data",
        action="store_true",
        help="Whether the output model would be saved with external data.",
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        action="store_true",
        help="Whether to save all external data in one file or save each tensor to a file named with the tensor name.",
    )
    parser.add_argument(
        "--external_data_location",
        help="Filename of the external file where all tensors are saved. The path is relative to the model path.",
    )
    parser.add_argument(
        "--external_data_size_threshold",
        default=1024,
        type=int,
        help="Tensors with data size larger than this threshold are converted to external data.",
    )
    parser.add_argument(
        "--external_data_convert_attribute",
        action="store_true",
        help="Whether to save all tensors, including attribute tensors, to external data.",
    )

    # Preprocess options.
    parser.add_argument(
        "--fuse_layernorm",
        action="store_true",
        help="Whether to fuse matched sequences into LayerNormalization nodes if possible.",
    )

    # I/O layouts.
    parser.add_argument(
        "--inputs_to_make_channel_last",
        nargs="+",
        default=None,
        help="List of graph input names to be transposed into channel-last.",
    )

    parser.add_argument(
        "--outputs_to_make_channel_last",
        nargs="+",
        default=None,
        help="List of graph output names to be transposed into channel-last.",
    )

    # Fix dynamic input shapes.
    parser.add_argument(
        "--dynamic_input_shapes",
        nargs=2,
        action="append",
        type=str,
        default=None,
        # Typo fix in user-facing help: was "seprated".
        help="Model input name and desired static shape in comma separated format, for example: 'input' 1,3,256,256",
    )

    # Exclude initializer from input
    parser.add_argument(
        "--exclude_initializer_from_input",
        action="store_true",
        help="Whether to exclude initializer from input if model.ir_version >= 4",
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def qnn_preprocess_model(
    model_input: str | pathlib.Path | onnx.ModelProto,
    model_output: str | pathlib.Path,
    fuse_layernorm: bool = False,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: str | None = None,
    external_data_size_threshold: int = 1024,
    external_data_convert_attribute: bool = False,
    inputs_to_make_channel_last: list[str] | None = None,
    outputs_to_make_channel_last: list[str] | None = None,
    dynamic_input_shapes: list[tuple[str, str]] | None = None,
    exclude_initializer_from_input: bool = False,
) -> bool:
    """Preprocess ONNX model for QNN.

    Thin wrapper delegating to
    onnxruntime.quantization.execution_providers.qnn.qnn_preprocess_model.

    Args:
        model_input: A path or ONNX ModelProto specifying the model to be preprocessed.
        model_output: A path specifying where the preprocessed model to be saved.
        fuse_layernorm: A bool specifying whether to fuse the matched sequence into a single LayerNormalization node.
            Defaults to False.
        save_as_external_data: A bool specifying whether to save model with external data. Defaults to False.
        all_tensors_to_one_file: A bool specifying whether to save all external data in one file or save each tensor to
            a file named with the tensor name. This argument is effective only when `save_as_external_data` is True.
            Defaults to False.
        external_data_location: A str specifying where to save the external data. The path is relative to the model
            path. This argument is effective only when `save_as_external_data` is True. Defaults to the model name.
        external_data_size_threshold: An int specifying the threshold of data size for tensors be saved as external
            data. This argument is effective only when `save_as_external_data` is True. Defaults to 1024.
        external_data_convert_attribute: A bool specifying whether to save all tensors including attributes as external
            data. This argument is effective only when `save_as_external_data` is True. Defaults to False.
        inputs_to_make_channel_last: A list of strs specifying graph input names to be transposed into channel-last.
            Defaults to None.
        outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last.
            Defaults to None.
        dynamic_input_shapes: A list of tuples specifying model input name and its static shape in comma separated
            format, for example: [('input', '1,3,256,256')]. Defaults to None.
        exclude_initializer_from_input: A bool specifying whether to exclude initializer from input. Defaults to False.

    Returns:
        A bool indicating whether the model is modified.
    """
    return qnn.qnn_preprocess_model(
        model_input,
        model_output,
        fuse_layernorm=fuse_layernorm,
        save_as_external_data=save_as_external_data,
        all_tensors_to_one_file=all_tensors_to_one_file,
        external_data_location=external_data_location,
        external_data_size_threshold=external_data_size_threshold,
        external_data_convert_attribute=external_data_convert_attribute,
        inputs_to_make_channel_last=inputs_to_make_channel_last,
        outputs_to_make_channel_last=outputs_to_make_channel_last,
        dynamic_input_shapes=dynamic_input_shapes,
        exclude_initializer_from_input=exclude_initializer_from_input,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = _parse_arguments()
    # Forward all CLI options straight through to the preprocessing wrapper.
    qnn_preprocess_model(
        args.input_model_path,
        args.output_model_path,
        fuse_layernorm=args.fuse_layernorm,
        save_as_external_data=args.save_as_external_data,
        all_tensors_to_one_file=args.all_tensors_to_one_file,
        external_data_location=args.external_data_location,
        external_data_size_threshold=args.external_data_size_threshold,
        external_data_convert_attribute=args.external_data_convert_attribute,
        inputs_to_make_channel_last=args.inputs_to_make_channel_last,
        outputs_to_make_channel_last=args.outputs_to_make_channel_last,
        dynamic_input_shapes=args.dynamic_input_shapes,
        exclude_initializer_from_input=args.exclude_initializer_from_input,
    )
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# Check if the flatbuffers module is available. If not we cannot handle type reduction information in the config.
|
||||
try:
|
||||
import flatbuffers # noqa: F401
|
||||
|
||||
have_flatbuffers = True
|
||||
from .ort_format_model import GloballyAllowedTypesOpTypeImplFilter, OperatorTypeUsageManager
|
||||
except ImportError:
|
||||
have_flatbuffers = False
|
||||
|
||||
|
||||
def parse_config(config_file: str, enable_type_reduction: bool = False):
    """
    Parse the configuration file and return the required operators dictionary and an
    OpTypeImplFilterInterface instance.

    Configuration file lines can do the following:
    1. specify required operators
    2. specify globally allowed types for all operators
    3. specify what it means for no required operators to be specified

    1. Specifying required operators

    The basic format for specifying required operators is `domain;opset1,opset2;op1,op2...`
    e.g. `ai.onnx;11;Add,Cast,Clip,... for a single opset
         `ai.onnx;11,12;Add,Cast,Clip,... for multiple opsets

    note: Configuration information is accrued as the file is parsed. If an operator requires support from multiple
    opsets that can be done with one entry for each opset, or one entry with multiple opsets in it.

    If the configuration file is generated from ORT format models it may optionally contain JSON for per-operator
    type reduction. The required types are generally listed per input and/or output of the operator.
    The type information is in a map, with 'inputs' and 'outputs' keys. The value for 'inputs' or 'outputs' is a map
    between the index number of the input/output and the required list of types.

    For example, both the input and output types are relevant to ai.onnx:Cast.
    Type information for input 0 and output 0 could look like this:
    `{"inputs": {"0": ["float", "int32_t"]}, "outputs": {"0": ["float", "int64_t"]}}`

    which is added directly after the operator name in the configuration file.
    e.g.
    `ai.onnx;12;Add,Cast{"inputs": {"0": ["float", "int32_t"]}, "outputs": {"0": ["float", "int64_t"]}},Concat`

    If for example the types of inputs 0 and 1 were important, the entry may look like this (e.g. ai.onnx:Gather):
    `{"inputs": {"0": ["float", "int32_t"], "1": ["int32_t"]}}`

    Finally some operators do non-standard things and store their type information under a 'custom' key.
    ai.onnx.OneHot is an example of this, where the three input types are combined into a triple.
    `{"custom": [["float", "int64_t", "int64_t"], ["int64_t", "std::string", "int64_t"]]}`

    2. Specifying globally allowed types for all operators

    The format for specifying globally allowed types for all operators is:
    `!globally_allowed_types;T0,T1,...`

    Ti should be a C++ scalar type supported by ONNX and ORT.
    At most one globally allowed types specification is allowed.

    Specifying per-operator type information and specifying globally allowed types are mutually exclusive - it is an
    error to specify both.

    3. Specify what it means for no required operators to be specified

    By default, if no required operators are specified, NO operators are required.

    With the following line, if no required operators are specified, ALL operators are required:
    `!no_ops_specified_means_all_ops_are_required`

    :param config_file: Configuration file to parse
    :param enable_type_reduction: Set to True to use the type information in the config.
                                  If False the type information will be ignored.
                                  If the flatbuffers module is unavailable type information will be ignored as the
                                  type-based filtering has a dependency on the ORT flatbuffers schema.
    :return: required_ops: Dictionary of domain:opset:[ops] for required operators. If None, all operators are
                           required.
             op_type_impl_filter: OpTypeImplFilterInterface instance if type reduction is enabled, the flatbuffers
                                  module is available, and type reduction information is present. None otherwise.
    """

    if not os.path.isfile(config_file):
        raise ValueError(f"Configuration file {config_file} does not exist")

    # only enable type reduction when flatbuffers is available
    enable_type_reduction = enable_type_reduction and have_flatbuffers

    # accrued parse state; configuration lines add to these as the file is read top-to-bottom
    required_ops = {}
    no_ops_specified_means_all_ops_are_required = False
    op_type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None
    has_op_type_reduction_info = False  # set if any op entry carries inline JSON type info
    globally_allowed_types = None  # set of type names from a `!globally_allowed_types;...` line

    def process_non_op_line(line):
        # Handle lines that are NOT `domain;opsets;ops` entries.
        # Returns True if the line was consumed here (blank/comment/directive), False otherwise.
        if not line or line.startswith("#"):  # skip empty lines and comments
            return True

        if line.startswith("!globally_allowed_types;"):  # handle globally allowed types
            if enable_type_reduction:
                nonlocal globally_allowed_types
                # at most one such directive is permitted per config file
                if globally_allowed_types is not None:
                    raise RuntimeError("Globally allowed types were already specified.")
                globally_allowed_types = {segment.strip() for segment in line.split(";")[1].split(",")}
            # note: consumed even when type reduction is disabled, so the directive is simply ignored then
            return True

        if line == "!no_ops_specified_means_all_ops_are_required":  # handle all ops required line
            nonlocal no_ops_specified_means_all_ops_are_required
            no_ops_specified_means_all_ops_are_required = True
            return True

        return False

    with open(config_file) as config:
        for line in [orig_line.strip() for orig_line in config]:
            if process_non_op_line(line):
                continue

            # operator entry line: `domain;opset1,opset2,...;op1,op2,...`
            domain, opset_str, operators_str = (segment.strip() for segment in line.split(";"))
            opsets = [int(s) for s in opset_str.split(",")]

            # any type reduction information is serialized json that starts/ends with { and }.
            # type info is optional for each operator.
            if "{" in operators_str:
                has_op_type_reduction_info = True

                # parse the entries in the json dictionary with type info.
                # we cannot simply split on commas because the JSON itself contains commas,
                # so scan the string manually, tracking brace nesting.
                operators = set()
                cur = 0
                end = len(operators_str)
                while cur < end:
                    next_comma = operators_str.find(",", cur)
                    next_open_brace = operators_str.find("{", cur)

                    if next_comma == -1:
                        next_comma = end

                    # the json string starts with '{', so if that is found (next_open_brace != -1)
                    # before the next comma (which would be the start of the next operator if there is no type info
                    # for the current operator), we have type info to parse.
                    # e.g. need to handle extracting the operator name and type info for OpB and OpD,
                    # and just the operator names for OpA and OpC from this example string
                    # OpA,OpB{"inputs": {"0": ["float", "int32_t"]}},OpC,OpD{"outputs": {"0": ["int32_t"]}}
                    if 0 < next_open_brace < next_comma:
                        operator = operators_str[cur:next_open_brace].strip()
                        operators.add(operator)

                        # parse out the json dictionary with the type info by finding the closing brace that matches
                        # the opening brace
                        i = next_open_brace + 1
                        num_open_braces = 1
                        while num_open_braces > 0 and i < end:
                            if operators_str[i] == "{":
                                num_open_braces += 1
                            elif operators_str[i] == "}":
                                num_open_braces -= 1
                            i += 1

                        if num_open_braces != 0:
                            raise RuntimeError("Mismatched { and } in type string: " + operators_str[next_open_brace:])

                        # only record the type info when type reduction is active
                        if op_type_usage_manager:
                            type_str = operators_str[next_open_brace:i]
                            op_type_usage_manager.restore_from_config_entry(domain, operator, type_str)

                        cur = i + 1  # skip past the comma following the JSON (or past end-of-string)
                    else:
                        # comma or end of line is next
                        end_str = next_comma if next_comma != -1 else end
                        operators.add(operators_str[cur:end_str].strip())
                        cur = end_str + 1

            else:
                # no type info anywhere on the line - a plain comma split suffices
                operators = {op.strip() for op in operators_str.split(",")}

            # merge the operators into required_ops for each listed opset;
            # entries for the same domain+opset across multiple lines accumulate
            for opset in opsets:
                if domain not in required_ops:
                    required_ops[domain] = {opset: operators}
                elif opset not in required_ops[domain]:
                    required_ops[domain][opset] = operators
                else:
                    required_ops[domain][opset].update(operators)

    # None is the sentinel for "all operators are required"
    if len(required_ops) == 0 and no_ops_specified_means_all_ops_are_required:
        required_ops = None

    op_type_impl_filter = None
    if enable_type_reduction:
        # discard the manager if the config contained no per-op type info at all
        if not has_op_type_reduction_info:
            op_type_usage_manager = None
        if globally_allowed_types is not None and op_type_usage_manager is not None:
            raise RuntimeError(
                "Specifying globally allowed types and per-op type reduction info together is unsupported."
            )

        if globally_allowed_types is not None:
            op_type_impl_filter = GloballyAllowedTypesOpTypeImplFilter(globally_allowed_types)
        elif op_type_usage_manager is not None:
            op_type_impl_filter = op_type_usage_manager.make_op_type_impl_filter()

    return required_ops, op_type_impl_filter
|
||||
@@ -0,0 +1,37 @@
|
||||
import argparse
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def get_args():
    """Parse command-line arguments: required --input and --output model paths."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input", required=True, help="input model")
    arg_parser.add_argument("--output", required=True, help="output model")
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def remove_initializer_from_input(model: onnx.ModelProto) -> bool:
    """Remove graph inputs that duplicate initializer names.

    Models with ir_version >= 4 may omit initializers from the graph inputs;
    doing so lets runtimes treat them as constants. For older models the
    initializers must stay listed as inputs, so nothing is changed.

    :param model: ONNX model to modify in place.
    :return: True if any graph input was removed, False otherwise.
    """
    if model.ir_version < 4:
        print("Model with ir_version below 4 requires to include initializer in graph input")
        return False

    graph_inputs = model.graph.input
    input_by_name = {graph_input.name: graph_input for graph_input in graph_inputs}

    removed_any = False
    for initializer in model.graph.initializer:
        matching_input = input_by_name.get(initializer.name)
        if matching_input is not None:
            removed_any = True
            graph_inputs.remove(matching_input)

    return removed_any
|
||||
|
||||
|
||||
# Script entry point: load the model, strip initializer duplicates from the
# graph inputs, and save the result to the requested output path.
if __name__ == "__main__":
    cli_args = get_args()
    onnx_model = onnx.load(cli_args.input)
    remove_initializer_from_input(onnx_model)
    onnx.save(onnx_model, cli_args.output)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from .onnx_model_utils import update_onnx_opset
|
||||
|
||||
|
||||
def update_onnx_opset_helper():
    """Command-line helper: parse args and update the ONNX opset of a model file."""
    prog_name = f"{os.path.basename(__file__)}:{update_onnx_opset_helper.__name__}"
    arg_parser = argparse.ArgumentParser(
        prog_name,
        description="""
        Update the ONNX opset of the model.
        New opset must be later than the existing one.
        If not specified will update to opset 15.
        """,
    )

    arg_parser.add_argument("--opset", type=int, required=False, default=15, help="ONNX opset to update to.")
    arg_parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
    arg_parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")

    parsed = arg_parser.parse_args()
    update_onnx_opset(parsed.input_model, parsed.opset, parsed.output_model)
|
||||
|
||||
|
||||
# Allow invoking this helper module directly as a script.
if __name__ == "__main__":
    update_onnx_opset_helper()
|
||||
Reference in New Issue
Block a user