 #include <cstdint>
 #include <cstring>
 #include <utility>
+#include <vector>
 
 #include <executorch/backends/aoti/slim/c10/core/Contiguity.h>
 #include <executorch/backends/aoti/slim/c10/core/Device.h>
@@ -277,69 +278,67 @@ class SlimTensor {
    * Copy data from another tensor to this tensor.
    *
    * Both tensors must have the same numel and dtype.
-   * Currently only supports CPU-to-CPU copy (contiguous tensors only).
+   * Supports CPU-to-CPU and cross-device copies (CPU↔CUDA, CUDA↔CUDA).
    *
    * @param other The source tensor to copy from
    * @return Reference to this tensor
    */
   SlimTensor& copy_(const SlimTensor& other) {
     ET_CHECK_MSG(
-        this->numel() == other.numel(),
-        "copy_: numel mismatch (dst=%zu, src=%zu)",
-        this->numel(),
-        other.numel());
-    ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype mismatch");
+        this->numel() == other.numel(), "copy_: numel of tensors must match");
+    ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");
 
     if (this->numel() == 0) {
       return *this;
     }
 
-    // Current we only support CPU-only tensors
-    // TODO(gasoonjia): support other device types.
-    ET_CHECK_MSG(
-        this->is_cpu() && other.is_cpu(), "copy_: only CPU tensors supported");
-
+    // Case 1: Both tensors are contiguous. We can do a fast bulk copy.
     if (this->is_contiguous() && other.is_contiguous()) {
-      // Fast path: both tensors are contiguous, use memcpy
-      std::memcpy(this->data_ptr(), other.data_ptr(), other.nbytes());
-    } else {
-      // Slow path: element-wise copy for non-contiguous tensors
-      copy_strided_(other);
+      storage_->copy_(
+          this->data_ptr(), other.data_ptr(), other.nbytes(), other.device());
+      return *this;
     }
 
-    return *this;
-  }
-
- private:
-  /**
-   * Element-wise copy for non-contiguous tensors.
-   */
-  void copy_strided_(const SlimTensor& other) {
+    // Case 2: At least one tensor is non-contiguous, perform element-wise copy
+    // that respects both source and destination strides.
     const size_t elem_size = c10::elementSize(dtype_);
     char* dst_data = static_cast<char*>(this->data_ptr());
     const char* src_data = static_cast<const char*>(other.data_ptr());
 
     std::vector<int64_t> counter(this->dim(), 0);
     for (size_t i = 0; i < this->numel(); i++) {
-      // Compute source offset
+      // Compute src offset in elements
       int64_t src_offset = 0;
       for (size_t d = 0; d < other.dim(); d++) {
-        src_offset += counter[d] * other.stride(static_cast<int64_t>(d));
+        src_offset += counter[d] * other.stride(d);
       }
 
-      // Compute destination offset
+      // Compute dst offset in elements
       int64_t dst_offset = 0;
       for (size_t d = 0; d < this->dim(); d++) {
-        dst_offset += counter[d] * this->stride(static_cast<int64_t>(d));
+        dst_offset += counter[d] * this->stride(d);
       }
 
-      // Copy single element
-      std::memcpy(
-          dst_data + dst_offset * static_cast<int64_t>(elem_size),
-          src_data + src_offset * static_cast<int64_t>(elem_size),
-          elem_size);
-
-      // Increment multi-dimensional counter
+      // Copy elem_size bytes from src to dst
+      if (this->device().is_cpu() && other.device().is_cpu()) {
+        std::memcpy(
+            dst_data + dst_offset * elem_size,
+            src_data + src_offset * elem_size,
+            elem_size);
+      } else if (this->device().is_cuda() || other.device().is_cuda()) {
+#if defined(CUDA_AVAILABLE)
+        DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+            dst_data + dst_offset * elem_size,
+            src_data + src_offset * elem_size,
+            elem_size,
+            device(), // dst device
+            other.device() // src device
+        );
+#else
+        ET_CHECK_MSG(false, "Failed on copy_ cuda tensors: no CUDA support");
+#endif
+      }
+      // Increment the multi-dimensional counter
       for (int64_t d = static_cast<int64_t>(this->dim()) - 1; d >= 0; --d) {
         counter[d]++;
         if (counter[d] < this->size(d)) {
@@ -348,8 +347,10 @@ class SlimTensor {
           counter[d] = 0;
         }
       }
+    return *this;
     }
 
+ private:
   void refresh_numel() {
     numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
   }
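
For context on the new fast path: `storage_->copy_` is not part of this diff, so its exact signature and behavior are not shown here. As a rough, hypothetical sketch of what a device-aware bulk copy of this kind typically looks like (the helper name `bulk_copy`, the `dst_is_cuda`/`src_is_cuda` flags, and the use of `cudaMemcpyDefault` are my assumptions, not the actual Storage API):

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#ifdef CUDA_AVAILABLE
#include <cuda_runtime.h>
#endif

// Hypothetical helper: pick memcpy vs. cudaMemcpy based on where the two
// buffers live. This only illustrates the dispatch pattern, not the real
// storage_->copy_ implementation.
inline void bulk_copy(
    void* dst, const void* src, size_t nbytes, bool dst_is_cuda, bool src_is_cuda) {
  if (!dst_is_cuda && !src_is_cuda) {
    std::memcpy(dst, src, nbytes); // host-to-host
    return;
  }
#ifdef CUDA_AVAILABLE
  // cudaMemcpyDefault lets the runtime infer the direction from the pointers
  // (requires unified virtual addressing, available on current GPUs).
  cudaMemcpy(dst, src, nbytes, cudaMemcpyDefault);
#else
  // Mirrors the patch's behavior: a CUDA copy without CUDA support is fatal.
  assert(false && "CUDA copy requested but CUDA is not available");
#endif
}
```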
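The slow path's offset arithmetic can be tried in isolation. Below is a minimal, self-contained sketch (not part of this patch; all names and the example shapes are illustrative) that uses the same multi-dimensional counter and per-dimension strides to copy a transposed 3×2 view of a row-major 2×3 buffer into a contiguous destination:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // Source: a 2x3 row-major buffer viewed as its 3x2 transpose,
  // i.e. sizes = {3, 2}, strides = {1, 3} (in elements, not bytes).
  const float src[6] = {0, 1, 2, 3, 4, 5};
  const std::vector<int64_t> sizes = {3, 2};
  const std::vector<int64_t> src_strides = {1, 3};
  // Destination: contiguous 3x2, strides = {2, 1}.
  float dst[6] = {};
  const std::vector<int64_t> dst_strides = {2, 1};

  // Walk every logical index with a multi-dimensional counter, as copy_ does
  // for non-contiguous tensors, and copy one element at a time.
  std::vector<int64_t> counter(sizes.size(), 0);
  for (int64_t i = 0; i < 6; ++i) {
    int64_t src_offset = 0, dst_offset = 0;
    for (size_t d = 0; d < sizes.size(); ++d) {
      src_offset += counter[d] * src_strides[d];
      dst_offset += counter[d] * dst_strides[d];
    }
    std::memcpy(&dst[dst_offset], &src[src_offset], sizeof(float));

    // Increment the counter from the innermost dimension outward.
    for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; --d) {
      if (++counter[d] < sizes[d]) {
        break;
      }
      counter[d] = 0;
    }
  }

  // dst now holds the transpose materialized contiguously: 0 3 1 4 2 5
  for (float v : dst) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}
```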