Fix tensor lifetime issue #4228
Conversation
```cpp
auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
```
By the way, a separate cleanup should be done here; this line should instead be

```cpp
outputs[pyt_idx] = at::empty(dims, at::TensorOptions().device(at::kCUDA).dtype(type));
```

This would go from two allocations plus a dtype-conversion kernel down to a single allocation.
I think this is the same line Shane identified as well.
narendasan
left a comment
This looks good to me
Force-pushed from 3b6cdb3 to 66c5a42
Force-pushed from 66c5a42 to 61b3003
```cpp
// recycled by the caching allocator for output tensors, aliasing inputs
// onto outputs and corrupting reads after the first output write.
std::list<at::Tensor> setup_input_tensors(
    std::vector<at::Tensor> inputs,
```
What if we mutated `inputs` in place? Something like

```cpp
void setup_input_tensors(
    std::vector<at::Tensor>& inputs,
    ...
```

Similar to what is happening for the inputShapeTensors.
Yeah, it kinda seems like the same thing but wrapped in a different package
narendasan
left a comment
The only alternative design I can think of is mutating the vector of input tensors in place with the new contiguous versions.
Description
This change fixes a correctness issue that I and others were seeing when running the FLUX2 diffusion model. The model, when compiled with either TensorRT or TensorRT-RTX was producing garbage images.
The issue was that the input tensor's lifetime was incorrect. The input tensor's ref count dropped to 0 before the engine ran with
enqueueV3(). In this specific case, it was a bit of a perfect storm with an output having the same size and shape and also there being a fp32->bf16 cast. Another tensor was being allocated (the output tensor) and that was given the address of the input tensor.Type of change
Checklist: