Commit b1c3dac

Add support for torch.export ExportedProgram models (#1498)
Add functionality to load and execute PyTorch models exported via torch.export (.pt2 files) using AOTInductor compilation, enabling .NET applications to run ExportedProgram models.

Native layer:
- THSExport.h/.cpp C++ wrappers using the AOTIModelPackageLoader API
- ExportedProgramModule typedef in Utils.h
- CMakeLists.txt updated to include the THSExport sources

Managed layer:
- LibTorchSharp.THSExport.cs PInvoke declarations
- ExportedProgram and ExportedProgram&lt;TResult&gt; classes in the Export namespace
- torch.export.load() API following PyTorch conventions

Capabilities:
- Load .pt2 files compiled with torch._inductor.aoti_compile_and_package()
- Inference-only forward pass with type-safe generics
- Single-tensor, array, and tuple output support
- IDisposable resource cleanup

Tests:
- 7 unit tests covering load, execute, multi-input, and tuple/list outputs
- 6 test .pt2 models regenerated with PyTorch 2.10
- generate_export_models.py for model regeneration

Fixes #1498
1 parent 38988a2 commit b1c3dac

17 files changed

Lines changed: 673 additions & 0 deletions
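For orientation before the per-file diffs, here is a minimal usage sketch of the managed API this commit adds. The file name, input shape, and printed output are illustrative assumptions, not part of the commit:

using System;
using TorchSharp;
using static TorchSharp.torch;

// Load an AOTInductor-compiled package (path and shapes are placeholders).
// The .pt2 must come from torch._inductor.aoti_compile_and_package(), not torch.export.save().
using var model = torch.export.load("model.pt2");

// Inference-only forward pass; inputs must match the shapes and device
// the model was compiled for.
var input = randn(1, 3, 224, 224);
Tensor[] outputs = model.run(input);

Console.WriteLine(string.Join("x", outputs[0].shape));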

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ __Bug Fixes__:

 __API Changes__:

+#1498 Add support for torch.export ExportedProgram models (.pt2 files)<br/>
 #1503 Add ReadOnlySpan overloads to many methods.<br/>
 #1478 Fix `torch.jit.ScriptModule.zero_grad`.<br/>
 #1495 Make `torchvision.io.read_image` and `torchvision.io.read_image_async` allow subsequent opening of the file for reading.<br/>

src/Native/LibTorchSharp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ set(SOURCES
     crc32c.h
     THSAutograd.h
     THSData.h
+    THSExport.h
     THSJIT.h
     THSNN.h
     THSStorage.h
@@ -23,6 +24,7 @@ set(SOURCES
     THSActivation.cpp
     THSAutograd.cpp
     THSData.cpp
+    THSExport.cpp
     THSFFT.cpp
     THSJIT.cpp
     THSLinearAlgebra.cpp
src/Native/LibTorchSharp/THSExport.cpp

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+// Copyright (c) .NET Foundation and Contributors.  All Rights Reserved.  See LICENSE in the project root for license information.
+#include "THSExport.h"
+
+// torch.export support via AOTInductor.
+// This uses torch::inductor::AOTIModelPackageLoader, which is INFERENCE-ONLY.
+// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
+
+ExportedProgramModule THSExport_load(const char* filename)
+{
+    CATCH(
+        // Load the .pt2 file using AOTIModelPackageLoader.
+        // This requires models to be compiled with aoti_compile_and_package().
+        auto* loader = new torch::inductor::AOTIModelPackageLoader(filename);
+        return loader;
+    );
+
+    return nullptr;
+}
+
+void THSExport_Module_dispose(const ExportedProgramModule module)
+{
+    delete module;
+}
+
+void THSExport_Module_run(
+    const ExportedProgramModule module,
+    const Tensor* input_tensors,
+    const int input_length,
+    Tensor** result_tensors,
+    int* result_length)
+{
+    CATCH(
+        // Convert the input tensor pointers to std::vector<torch::Tensor>.
+        std::vector<torch::Tensor> inputs;
+        inputs.reserve(input_length);
+        for (int i = 0; i < input_length; i++) {
+            inputs.push_back(*input_tensors[i]);
+        }
+
+        // Run inference.
+        std::vector<torch::Tensor> outputs = module->run(inputs);
+
+        // Allocate the output array and copy the results; the caller releases it via THSExport_Module_free_results.
+        *result_length = (int)outputs.size();
+        *result_tensors = new Tensor[outputs.size()];
+
+        for (size_t i = 0; i < outputs.size(); i++) {
+            (*result_tensors)[i] = new torch::Tensor(outputs[i]);
+        }
+    );
+}
+
+void THSExport_Module_free_results(Tensor* result_tensors)
+{
+    // Free the array allocated by THSExport_Module_run with new[]; the tensor handles it holds are owned by the managed wrappers.
+    delete[] result_tensors;
+}
src/Native/LibTorchSharp/THSExport.h

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+// Copyright (c) .NET Foundation and Contributors.  All Rights Reserved.  See LICENSE in the project root for license information.
+#pragma once
+
+#include "../Stdafx.h"
+
+#include "torch/torch.h"
+#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
+
+#include "Utils.h"
+
+// torch.export support via AOTInductor - load and execute PyTorch ExportedProgram models (.pt2 files).
+// ExportedProgram is PyTorch 2.x's recommended way to export models for production deployment.
+//
+// IMPORTANT: This implementation uses torch::inductor::AOTIModelPackageLoader, which is
+// INFERENCE-ONLY. Training, parameter updates, and device movement are not supported.
+// Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
+
+// Load an AOTInductor-compiled model package from a .pt2 file
+EXPORT_API(ExportedProgramModule) THSExport_load(const char* filename);
+
+// Dispose of an ExportedProgram module
+EXPORT_API(void) THSExport_Module_dispose(const ExportedProgramModule module);
+
+// Execute the ExportedProgram's forward method (inference only)
+// Input: Array of tensors
+// Output: Array of result tensors (caller must free via THSExport_Module_free_results)
+EXPORT_API(void) THSExport_Module_run(
+    const ExportedProgramModule module,
+    const Tensor* input_tensors,
+    const int input_length,
+    Tensor** result_tensors,
+    int* result_length);
+
+// Free the result array allocated by THSExport_Module_run
+EXPORT_API(void) THSExport_Module_free_results(Tensor* result_tensors);

src/Native/LibTorchSharp/THSJIT.h

Lines changed: 4 additions & 0 deletions
@@ -99,3 +99,7 @@ EXPORT_API(TensorOrScalar*) THSJIT_AllocateTensorOrScalarArray(int32_t size);
 EXPORT_API(void) THSJIT_FreeTensorOrScalarArray(TensorOrScalar* ptr);
 EXPORT_API(void) THSJIT_SetTensorOrScalar(TensorOrScalar* array, int32_t index, int64_t type_code, int64_t array_index, ptrdiff_t handle);
 EXPORT_API(TensorOrScalar*) THSJIT_GetTensorOrScalar(TensorOrScalar* array, int32_t index);
+
+// Helper functions (shared with THSExport)
+std::vector<c10::IValue> toIValue(const TensorOrScalar* tensorPtrs, const int length);
+TensorOrScalar* ReturnHelper(c10::IValue result, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t* idx);

src/Native/LibTorchSharp/Utils.h

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include <string>

 #include "torch/torch.h"
+#include "torch/csrc/inductor/aoti_package/model_package_loader.h"

 extern thread_local char *torch_last_err;

@@ -24,6 +25,10 @@ typedef std::shared_ptr<torch::jit::Function> * JITFunction;
 typedef std::shared_ptr<c10::Type> * JITType;
 typedef std::shared_ptr<c10::TensorType>* JITTensorType;

+// torch.export ExportedProgram module via AOTInductor
+// Note: Uses torch::inductor::AOTIModelPackageLoader for inference-only execution
+typedef torch::inductor::AOTIModelPackageLoader* ExportedProgramModule;
+
 struct TensorArray {
     Tensor *array;
     int64_t size;
src/TorchSharp/Export/ExportedProgram.cs

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
+// Copyright (c) .NET Foundation and Contributors.  All Rights Reserved.  See LICENSE in the project root for license information.
+
+using System;
+using System.Runtime.InteropServices;
+using TorchSharp.PInvoke;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp
+{
+    public static partial class torch
+    {
+        public static partial class export
+        {
+            /// <summary>
+            /// Load a PyTorch ExportedProgram from a .pt2 file compiled with AOTInductor.
+            /// </summary>
+            /// <param name="filename">Path to the .pt2 file</param>
+            /// <returns>ExportedProgram model for inference</returns>
+            /// <remarks>
+            /// IMPORTANT: The .pt2 file must be compiled with torch._inductor.aoti_compile_and_package() in Python.
+            /// Models saved with torch.export.save() alone will NOT work - they require AOTInductor compilation.
+            ///
+            /// This implementation is INFERENCE-ONLY. Training, parameter updates, and device movement
+            /// are not supported. The model is compiled for a specific device (CPU/CUDA) at compile time.
+            ///
+            /// Example Python code to create compatible .pt2 files:
+            /// <code>
+            /// import torch
+            /// import torch._inductor
+            ///
+            /// # Export the model
+            /// exported = torch.export.export(model, example_inputs)
+            ///
+            /// # Compile with AOTInductor (required for C++ loading)
+            /// torch._inductor.aoti_compile_and_package(
+            ///     exported,
+            ///     package_path="model.pt2"
+            /// )
+            /// </code>
+            /// </remarks>
+            public static ExportedProgram load(string filename)
+            {
+                return new ExportedProgram(filename);
+            }
+
+            /// <summary>
+            /// Load a PyTorch ExportedProgram with typed output.
+            /// </summary>
+            public static ExportedProgram<TResult> load<TResult>(string filename)
+            {
+                return new ExportedProgram<TResult>(filename);
+            }
+        }
+    }
+
+    /// <summary>
+    /// Represents a PyTorch ExportedProgram loaded from an AOTInductor-compiled .pt2 file.
+    /// This is an INFERENCE-ONLY implementation - training and parameter updates are not supported.
+    /// </summary>
+    /// <remarks>
+    /// Unlike TorchScript models, ExportedProgram models are ahead-of-time (AOT) compiled for
+    /// a specific device and are optimized for inference performance. They can provide 30-40% better
+    /// latency than TorchScript in many cases.
+    ///
+    /// Key limitations:
+    /// - Inference only (no training, no gradients)
+    /// - No parameter access or updates
+    /// - No device movement (compiled for a specific device)
+    /// - No dynamic model structure changes
+    ///
+    /// Use torch.jit for models that require training or dynamic behavior.
+    /// </remarks>
+    public class ExportedProgram : IDisposable
+    {
+        private IntPtr handle;
+        private bool _disposed = false;
+
+        internal ExportedProgram(string filename)
+        {
+            handle = THSExport_load(filename);
+            if (handle == IntPtr.Zero)
+                torch.CheckForErrors();
+        }
+
+        /// <summary>
+        /// Run inference on the model with the given input tensors.
+        /// </summary>
+        /// <param name="inputs">Input tensors for the model</param>
+        /// <returns>Array of output tensors</returns>
+        /// <remarks>
+        /// The number and shapes of the inputs must match what the model was exported with.
+        /// All inputs must be on the same device that the model was compiled for.
+        /// </remarks>
+        public torch.Tensor[] run(params torch.Tensor[] inputs)
+        {
+            if (_disposed)
+                throw new ObjectDisposedException(nameof(ExportedProgram));
+
+            // Convert the managed tensors to an IntPtr array.
+            IntPtr[] input_handles = new IntPtr[inputs.Length];
+            for (int i = 0; i < inputs.Length; i++)
+            {
+                input_handles[i] = inputs[i].Handle;
+            }
+
+            // Call the native run method.
+            THSExport_Module_run(handle, input_handles, inputs.Length, out IntPtr result_ptr, out int result_length);
+            torch.CheckForErrors();
+
+            // Marshal the result array.
+            torch.Tensor[] results = new torch.Tensor[result_length];
+            IntPtr[] result_handles = new IntPtr[result_length];
+            Marshal.Copy(result_ptr, result_handles, 0, result_length);
+
+            for (int i = 0; i < result_length; i++)
+            {
+                results[i] = new torch.Tensor(result_handles[i]);
+            }
+
+            // Free the native array on the native side, where it was allocated with new[]; the tensors themselves are now owned by the managed Tensor objects.
+            THSExport_Module_free_results(result_ptr);
+
+            return results;
+        }
+
+        /// <summary>
+        /// Synonym for run() - executes the forward pass.
+        /// </summary>
+        public torch.Tensor[] forward(params torch.Tensor[] inputs) => run(inputs);
+
+        /// <summary>
+        /// Synonym for run() - executes the model.
+        /// </summary>
+        public torch.Tensor[] call(params torch.Tensor[] inputs) => run(inputs);
+
+        public void Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        protected virtual void Dispose(bool disposing)
+        {
+            if (!_disposed)
+            {
+                if (handle != IntPtr.Zero)
+                {
+                    THSExport_Module_dispose(handle);
+                    handle = IntPtr.Zero;
+                }
+                _disposed = true;
+            }
+        }
+
+        ~ExportedProgram()
+        {
+            Dispose(false);
+        }
+    }
+
+    /// <summary>
+    /// Generic version of ExportedProgram with typed output.
+    /// </summary>
+    /// <typeparam name="TResult">The return type (Tensor, Tensor[], or a tuple of Tensors)</typeparam>
+    public class ExportedProgram<TResult> : ExportedProgram
+    {
+        internal ExportedProgram(string filename) : base(filename)
+        {
+        }
+
+        /// <summary>
+        /// Run inference with a typed return value.
+        /// </summary>
+        public new TResult run(params torch.Tensor[] inputs)
+        {
+            var results = base.run(inputs);
+
+            // Handle the different return types.
+            if (typeof(TResult) == typeof(torch.Tensor))
+            {
+                if (results.Length != 1)
+                    throw new InvalidOperationException($"Expected 1 output tensor, got {results.Length}");
+                return (TResult)(object)results[0];
+            }
+
+            if (typeof(TResult) == typeof(torch.Tensor[]))
+            {
+                return (TResult)(object)results;
+            }
+
+            // Handle tuple types.
+            if (typeof(TResult).IsGenericType)
+            {
+                var genericType = typeof(TResult).GetGenericTypeDefinition();
+                if (genericType == typeof(ValueTuple<,>))
+                {
+                    if (results.Length != 2)
+                        throw new InvalidOperationException($"Expected 2 output tensors, got {results.Length}");
+                    return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1]);
+                }
+                if (genericType == typeof(ValueTuple<,,>))
+                {
+                    if (results.Length != 3)
+                        throw new InvalidOperationException($"Expected 3 output tensors, got {results.Length}");
+                    return (TResult)Activator.CreateInstance(typeof(TResult), results[0], results[1], results[2]);
+                }
+            }
+
+            throw new NotSupportedException($"Return type {typeof(TResult)} is not supported");
+        }
+
+        public new TResult forward(params torch.Tensor[] inputs) => run(inputs);
+        public new TResult call(params torch.Tensor[] inputs) => run(inputs);
+    }
+}
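A short sketch of the typed overload above, assuming a hypothetical model packaged as "two_outputs.pt2" whose forward pass returns two tensors (the file name and shapes are placeholders):

using TorchSharp;
using static TorchSharp.torch;

// TResult selects how the outputs are surfaced: Tensor, Tensor[], or a ValueTuple of tensors.
using var model = torch.export.load<(Tensor, Tensor)>("two_outputs.pt2");

var x = randn(4, 8);
(Tensor first, Tensor second) = model.run(x);

// load<Tensor>(...) works for single-output models and throws
// InvalidOperationException if the model returns a different count.

Note the design choice in the generic run: ValueTuple results are materialized via Activator.CreateInstance, so only 2- and 3-element tuples are handled; other arities fall through to NotSupportedException.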
src/TorchSharp/PInvoke/LibTorchSharp.THSExport.cs

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+// Copyright (c) .NET Foundation and Contributors.  All Rights Reserved.  See LICENSE in the project root for license information.
+#nullable enable
+using System;
+using System.Runtime.InteropServices;
+
+namespace TorchSharp.PInvoke
+{
+#pragma warning disable CA2101
+    internal static partial class NativeMethods
+    {
+        // torch.export support via AOTInductor (INFERENCE-ONLY).
+        // Models must be compiled with torch._inductor.aoti_compile_and_package() in Python.
+
+        // Load an ExportedProgram from a .pt2 file
+        [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+        internal static extern IntPtr THSExport_load(string filename);
+
+        // Dispose of an ExportedProgram module
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSExport_Module_dispose(IntPtr handle);
+
+        // Execute the forward pass (inference only)
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSExport_Module_run(
+            IntPtr module,
+            IntPtr[] input_tensors,
+            int input_length,
+            out IntPtr result_tensors,
+            out int result_length);
+
+        // Free the native result array allocated by THSExport_Module_run
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSExport_Module_free_results(IntPtr result_tensors);
+    }
+#pragma warning restore CA2101
+}
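Taken together, the imports above define the ownership contract: managed code passes raw tensor handles in, the native side allocates the result array, and the caller copies out the handles before releasing the array. A sketch of that round trip against the raw entry points, assuming code compiled inside the TorchSharp assembly (NativeMethods is internal) and a placeholder "model.pt2"; real callers should go through ExportedProgram instead:

using System;
using System.Runtime.InteropServices;
using TorchSharp;
using static TorchSharp.PInvoke.NativeMethods;

IntPtr module = THSExport_load("model.pt2");

var input = torch.randn(2, 3);
THSExport_Module_run(module, new[] { input.Handle }, 1, out IntPtr results, out int count);

var handles = new IntPtr[count];
Marshal.Copy(results, handles, 0, count);   // copy the tensor handles out
THSExport_Module_free_results(results);     // free the array, not the tensors

THSExport_Module_dispose(module);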
