Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -75,6 +76,7 @@ cc_library(
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBackend.h",
"runtime.h",
],
linkopts = [
Expand Down Expand Up @@ -107,6 +109,7 @@ filegroup(
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBackend.h",
"runtime.h",
],
visibility = ["//visibility:public"],
Expand All @@ -121,6 +124,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
1 change: 1 addition & 0 deletions core/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(CXX_SRCS

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TensorRTBackend.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
Expand Down
64 changes: 64 additions & 0 deletions core/runtime/TensorRTBackend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#pragma once

#include <executorch/runtime/backend/interface.h>
#include <vector>
#include "ATen/core/TensorBody.h"
#include "core/runtime/TRTEngine.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

/**
 * Thin backend that holds a TRT engine and runs inference by calling
 * execute_engine. Implements executorch::runtime::BackendInterface
 * (init sets up the engine, execute runs the engine).
 *
 * Two usage paths are visible here:
 *  - direct: construct with (or set_engine) a TRTEngine, then call the
 *    non-virtual execute(std::vector<at::Tensor>) overload;
 *  - ExecuTorch: the runtime drives the overridden init/execute/destroy
 *    virtuals (defined out of line elsewhere in the project).
 */
class TensorRTBackend final : public executorch::runtime::BackendInterface {
 public:
  // Default-constructed backend holds no engine; engine_ is null until
  // set_engine() is called (the direct execute() overload checks for this).
  TensorRTBackend() = default;

  // Takes shared ownership of an already-built TRT engine.
  explicit TensorRTBackend(c10::intrusive_ptr<TRTEngine> engine) : engine_(std::move(engine)) {}

  // Replaces the held engine (shared ownership via intrusive_ptr).
  void set_engine(c10::intrusive_ptr<TRTEngine> engine) {
    engine_ = std::move(engine);
  }

  // Returns the held engine (may be null when not initialized).
  c10::intrusive_ptr<TRTEngine> get_engine() const {
    return engine_;
  }

  // True once an engine has been attached via ctor or set_engine().
  bool is_initialized() const {
    return engine_ != nullptr;
  }

  /**
   * Run inference: forwards to execute_engine(inputs, engine_).
   * Returns output tensors from the TRT engine.
   *
   * @param inputs input tensors, consumed by move.
   * @throws via TORCHTRT_CHECK when no engine has been set.
   */
  std::vector<at::Tensor> execute(std::vector<at::Tensor> inputs) {
    TORCHTRT_CHECK(engine_ != nullptr, "TensorRTBackend: engine is null");
    return execute_engine(std::move(inputs), engine_);
  }

  // executorch::runtime::BackendInterface
  // NOTE(review): implementations are out of line (not visible in this
  // header); confirm init() deserializes `processed` into a TRTEngine.
  bool is_available() const override;
  executorch::runtime::Result<executorch::runtime::DelegateHandle*> init(
      executorch::runtime::BackendInitContext& context,
      executorch::runtime::FreeableBuffer* processed,
      executorch::runtime::ArrayRef<executorch::runtime::CompileSpec> compile_specs) const override;
  executorch::runtime::Error execute(
      executorch::runtime::BackendExecutionContext& context,
      executorch::runtime::DelegateHandle* handle,
      executorch::runtime::Span<executorch::runtime::EValue*> args) const override;
  void destroy(executorch::runtime::DelegateHandle* handle) const override;

 private:
  // Engine used by the direct execute() path; null until set.
  c10::intrusive_ptr<TRTEngine> engine_;
};

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
26 changes: 26 additions & 0 deletions examples/executorch_example/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
    """Toy module for the ExecuTorch export example: adds 1 to its input."""

    def forward(self, x):
        """Return ``x`` shifted up by one, element-wise."""
        return torch.add(x, 1)


with torch.no_grad():
    # Build the toy model and a matching CUDA example input.
    module = MyModel().eval().cuda()
    sample_inputs = (torch.randn((2, 3, 4, 4)).cuda(),)

    ep = torch.export.export(module, sample_inputs)

    # Compile the exported program into a TensorRT-backed GraphModule.
    trt_module = torch_tensorrt.dynamo.compile(
        ep,
        arg_inputs=[
            torch_tensorrt.Input(shape=(2, 3, 4, 4), dtype=torch.float32),
        ],
        min_block_size=1,
    )

    # Save as ExecuTorch .pte format (loadable by ExecuTorch runtime)
    torch_tensorrt.save(
        trt_module, "model.pte", output_format="executorch", arg_inputs=sample_inputs
    )
33 changes: 33 additions & 0 deletions examples/executorch_example/model_dynamic_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
    """Minimal example module: shifts every element of the input up by one."""

    def forward(self, x):
        shift = 1
        return x + shift


with torch.no_grad():
    # Model and example input on the GPU; batch dimension is made dynamic.
    module = MyModel().eval().cuda()
    sample_inputs = (torch.randn((2, 3, 4, 4)).cuda(),)
    dyn_batch = torch.export.Dim("batch", min=2, max=5)

    ep = torch.export.export(
        module, sample_inputs, dynamic_shapes={"x": {0: dyn_batch}}
    )

    # Compile with a TRT optimization profile spanning the dynamic batch range.
    trt_module = torch_tensorrt.dynamo.compile(
        ep,
        arg_inputs=[
            torch_tensorrt.Input(
                min_shape=(2, 3, 4, 4),
                opt_shape=(3, 3, 4, 4),
                max_shape=(5, 3, 4, 4),
                dtype=torch.float32,
            )
        ],
        min_block_size=1,
    )

    # Save as ExecuTorch .pte format (loadable by ExecuTorch runtime)
    torch_tensorrt.save(
        trt_module, "model.pte", output_format="executorch", arg_inputs=sample_inputs
    )
75 changes: 60 additions & 15 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
import platform
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)

import torch
from torch_tensorrt._enums import dtype
Expand Down Expand Up @@ -545,12 +556,15 @@ def convert_method_to_trt_engine(
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
)

return dynamo_convert_exported_program_to_serialized_trt_engine(
exp_program,
arg_inputs=tuple(arg_inputs),
kwarg_inputs=torchtrt_kwarg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
return cast(
bytes,
dynamo_convert_exported_program_to_serialized_trt_engine(
exp_program,
arg_inputs=tuple(arg_inputs),
kwarg_inputs=torchtrt_kwarg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
),
)
elif target_ir == _IRType.torch_compile:
raise RuntimeError(
Expand Down Expand Up @@ -653,7 +667,8 @@ def save(
inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications
arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor | executorch.
Use executorch to save a .pte file loadable by the ExecuTorch runtime (requires executorch). For GraphModule, arg_inputs is required when output_format is executorch.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.

For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported:
Expand Down Expand Up @@ -726,7 +741,7 @@ def save(
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
accepted_formats = {"exported_program", "torchscript", "aot_inductor", "executorch"}
if arg_inputs is not None and not all(
isinstance(input, (torch.Tensor, Input)) for input in arg_inputs
):
Expand Down Expand Up @@ -805,8 +820,8 @@ def _all_are_input_objects(obj: Any) -> bool:
f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}"
)

arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore
arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device()))
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device())

else:
# Mixed case: some inputs are Tensors, some are Input objects
Expand Down Expand Up @@ -847,7 +862,19 @@ def _extract_tensor(obj: Any) -> Any:

if output_format not in accepted_formats:
raise ValueError(
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
f"Provided output_format {output_format} is not supported. "
"Supported options are exported_program | torchscript | aot_inductor | executorch"
)
if (
output_format == "executorch"
and module_type == _ModuleType.fx
and (
arg_inputs is None
or (isinstance(arg_inputs, (list, tuple)) and len(arg_inputs) == 0)
)
):
raise ValueError(
"output_format='executorch' with a GraphModule requires arg_inputs (example inputs) to export the module."
)
if output_format == "aot_inductor" and platform.system() != "Linux":
raise ValueError(
Expand Down Expand Up @@ -906,9 +933,15 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
from torch_tensorrt.dynamo._executorch_export import (
export_to_executorch,
)

export_to_executorch(module, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program, aot_inductor, and executorch"
)
elif module_type == _ModuleType.fx:
# The module type is torch.fx.GraphModule
Expand Down Expand Up @@ -963,9 +996,15 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
from torch_tensorrt.dynamo._executorch_export import (
export_to_executorch,
)

export_to_executorch(exp_program, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program, aot_inductor, and executorch"
)
else:
# When retrace=True with a TRT-compiled GraphModule that has dynamic shapes,
Expand Down Expand Up @@ -1042,9 +1081,15 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
from torch_tensorrt.dynamo._executorch_export import (
export_to_executorch,
)

export_to_executorch(exp_program, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program, aot_inductor, and executorch"
)


Expand Down
Loading
Loading