Skip to content

Commit 37aa87e

Browse files
[slimtensor] Enable CUDA tensor copy (pytorch#16800)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: pytorch#16771 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/111/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/111/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/110/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/111/orig Differential Revision: [D91202900](https://our.internmc.facebook.com/intern/diff/D91202900/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <gasoonjia@icloud.com> Co-authored-by: Gasoonjia <gasoonjia@meta.com>
1 parent 8ab593b commit 37aa87e

5 files changed

Lines changed: 427 additions & 66 deletions

File tree

backends/aoti/slim/core/SlimTensor.h

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cstdint>
1212
#include <cstring>
1313
#include <utility>
14+
#include <vector>
1415

1516
#include <executorch/backends/aoti/slim/c10/core/Contiguity.h>
1617
#include <executorch/backends/aoti/slim/c10/core/Device.h>
@@ -277,69 +278,67 @@ class SlimTensor {
277278
* Copy data from another tensor to this tensor.
278279
*
279280
* Both tensors must have the same numel and dtype.
280-
* Currently only supports CPU-to-CPU copy (contiguous tensors only).
281+
* Supports CPU-to-CPU and cross-device copies (CPU↔CUDA, CUDA↔CUDA).
281282
*
282283
* @param other The source tensor to copy from
283284
* @return Reference to this tensor
284285
*/
285286
SlimTensor& copy_(const SlimTensor& other) {
286287
ET_CHECK_MSG(
287-
this->numel() == other.numel(),
288-
"copy_: numel mismatch (dst=%zu, src=%zu)",
289-
this->numel(),
290-
other.numel());
291-
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype mismatch");
288+
this->numel() == other.numel(), "copy_: numel of tensors must match");
289+
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");
292290

293291
if (this->numel() == 0) {
294292
return *this;
295293
}
296294

297-
// Current we only support CPU-only tensors
298-
// TODO(gasoonjia): support other device types.
299-
ET_CHECK_MSG(
300-
this->is_cpu() && other.is_cpu(), "copy_: only CPU tensors supported");
301-
295+
// Case 1: Both tensors are contiguous. We can do a fast bulk copy.
302296
if (this->is_contiguous() && other.is_contiguous()) {
303-
// Fast path: both tensors are contiguous, use memcpy
304-
std::memcpy(this->data_ptr(), other.data_ptr(), other.nbytes());
305-
} else {
306-
// Slow path: element-wise copy for non-contiguous tensors
307-
copy_strided_(other);
297+
storage_->copy_(
298+
this->data_ptr(), other.data_ptr(), other.nbytes(), other.device());
299+
return *this;
308300
}
309301

310-
return *this;
311-
}
312-
313-
private:
314-
/**
315-
* Element-wise copy for non-contiguous tensors.
316-
*/
317-
void copy_strided_(const SlimTensor& other) {
302+
// Case 2: At least one tensor is non-contiguous, perform element-wise copy
303+
// that respects both source and destination strides.
318304
const size_t elem_size = c10::elementSize(dtype_);
319305
char* dst_data = static_cast<char*>(this->data_ptr());
320306
const char* src_data = static_cast<const char*>(other.data_ptr());
321307

322308
std::vector<int64_t> counter(this->dim(), 0);
323309
for (size_t i = 0; i < this->numel(); i++) {
324-
// Compute source offset
310+
// Compute src offset in elements
325311
int64_t src_offset = 0;
326312
for (size_t d = 0; d < other.dim(); d++) {
327-
src_offset += counter[d] * other.stride(static_cast<int64_t>(d));
313+
src_offset += counter[d] * other.stride(d);
328314
}
329315

330-
// Compute destination offset
316+
// Compute dst offset in elements
331317
int64_t dst_offset = 0;
332318
for (size_t d = 0; d < this->dim(); d++) {
333-
dst_offset += counter[d] * this->stride(static_cast<int64_t>(d));
319+
dst_offset += counter[d] * this->stride(d);
334320
}
335321

336-
// Copy single element
337-
std::memcpy(
338-
dst_data + dst_offset * static_cast<int64_t>(elem_size),
339-
src_data + src_offset * static_cast<int64_t>(elem_size),
340-
elem_size);
341-
342-
// Increment multi-dimensional counter
322+
// Copy elem_size bytes from src to dst
323+
if (this->device().is_cpu() && other.device().is_cpu()) {
324+
std::memcpy(
325+
dst_data + dst_offset * elem_size,
326+
src_data + src_offset * elem_size,
327+
elem_size);
328+
} else if (this->device().is_cuda() || other.device().is_cuda()) {
329+
#if defined(CUDA_AVAILABLE)
330+
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
331+
dst_data + dst_offset * elem_size,
332+
src_data + src_offset * elem_size,
333+
elem_size,
334+
device(), // dst device
335+
other.device() // src device
336+
);
337+
#else
338+
ET_CHECK_MSG(false, "Failed on copy_ cuda tensors: no CUDA support");
339+
#endif
340+
}
341+
// Increment the multi-dimensional counter
343342
for (int64_t d = static_cast<int64_t>(this->dim()) - 1; d >= 0; --d) {
344343
counter[d]++;
345344
if (counter[d] < this->size(d)) {
@@ -348,8 +347,10 @@ class SlimTensor {
348347
counter[d] = 0;
349348
}
350349
}
350+
return *this;
351351
}
352352

353+
private:
353354
void refresh_numel() {
354355
numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
355356
}

backends/aoti/slim/core/Storage.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,15 @@ class MaybeOwningStorage {
296296
return;
297297
}
298298

299-
ET_CHECK_MSG(
300-
device_.is_cpu() && src_device.is_cpu(),
301-
"Only CPU-to-CPU copy is currently supported");
302-
303-
DeviceTraits<c10::DeviceType::CPU>::memcpy(
304-
dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
299+
if (device_.is_cpu() && src_device.is_cpu()) {
300+
// CPU to CPU copy
301+
DeviceTraits<c10::DeviceType::CPU>::memcpy(
302+
dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
303+
} else {
304+
// At least one of the devices is CUDA
305+
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
306+
dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
307+
}
305308
}
306309

307310
/// Creates a clone of this storage on the specified device.

backends/aoti/slim/core/targets.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def define_common_targets():
2222
],
2323
)
2424

25-
# Header-only library for SlimTensor (CPU-only for now)
2625
runtime.cxx_library(
2726
name = "slimtensor",
2827
headers = [

backends/aoti/slim/core/test/targets.bzl

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,49 +32,31 @@ def define_common_targets():
3232
**backend_kwargs
3333
)
3434

35-
backend_kwargs = {
36-
"external_deps": [("cuda", None, "cuda-lazy")],
37-
"preprocessor_flags": ["-DCUDA_AVAILABLE=1"],
38-
"keep_gpu_sections": True,
39-
"remote_execution": re_test_utils.remote_execution(
40-
platform = "gpu-remote-execution",
41-
),
42-
} if backend_mode == "cuda" else {}
43-
4435
runtime.cxx_test(
45-
name = "test_storage" + backend_suffix,
36+
name = "test_slimtensor_basic" + backend_suffix,
4637
srcs = [
47-
"test_storage.cpp",
38+
"test_slimtensor_basic.cpp",
4839
],
4940
deps = [
41+
"//executorch/backends/aoti/slim/core:slimtensor",
5042
"//executorch/backends/aoti/slim/core:storage",
5143
],
5244
**backend_kwargs
5345
)
5446

5547
runtime.cxx_test(
56-
name = "test_slimtensor_basic" + backend_suffix,
48+
name = "test_slimtensor_copy" + backend_suffix,
5749
srcs = [
58-
"test_slimtensor_basic.cpp",
50+
"test_slimtensor_copy.cpp",
5951
],
6052
deps = [
6153
"//executorch/backends/aoti/slim/core:slimtensor",
6254
"//executorch/backends/aoti/slim/core:storage",
55+
"//executorch/backends/aoti/slim/factory:empty",
6356
],
6457
**backend_kwargs
6558
)
6659

67-
runtime.cxx_test(
68-
name = "test_slimtensor_copy",
69-
srcs = [
70-
"test_slimtensor_copy.cpp",
71-
],
72-
deps = [
73-
"//executorch/backends/aoti/slim/core:slimtensor",
74-
"//executorch/backends/aoti/slim/core:storage",
75-
],
76-
)
77-
7860
runtime.cxx_test(
7961
name = "test_slimtensor_dtypes",
8062
srcs = [

0 commit comments

Comments (0)