Skip to content

Commit 28f38d6

Browse files
jgibson2claude
andauthored
Add Tensor.copyDataInto to Java API (pytorch#19171)
## Summary Adds `Tensor.copyDataInto(... dst)` to the Android Java API for the float32 and float16 dtypes. It copies the tensor's data into a caller-provided destination buffer instead of allocating a fresh `float[]` per call (as `getDataAsFloatArray()` does today). The same pattern is repeated for other types. ## Motivation While profiling depth inference on Android with Perfetto, output extraction was a meaningful contributor to ART GC pressure. Each call to `output.toTensor().dataAsFloatArray` allocates a new Java `float[]` sized to the tensor's element count and bulk-copies from the underlying off-heap buffer into it. The native side already exposes the underlying `FloatBuffer` directly (zero-copy view of the C++ tensor's `data_ptr()`), so the only thing missing was a public way for callers to drain it into a destination buffer they already own and reuse across calls. ## API ```java public void copyDataInto(FloatBuffer dst) ``` - Implemented on all datatypes ## Caller-side usage example ```java // One-time setup FloatBuffer depthBuf = Tensor.allocateFloatBuffer(numelDepth); // Per inference EValue[] outputs = module.forward(...); depthBuf.rewind(); outputs[0].toTensor().copyDataInto(depthBuf); // no allocation // ... read from depthBuf ... ``` ## Test plan - [x] Added unit tests in `TensorTest.kt`: - `testCopyDataIntoFloat32` — round-trip with reuse across two calls - `testCopyDataIntoFloat32_writesAtDstPosition` — verifies the call writes at `dst.position()` and advances it (does not overwrite from index 0) - `testCopyDataIntoFloat32_overflow` — `BufferOverflowException` on undersized destination - `testCopyDataIntoFloat16` — verifies fp16→fp32 widening matches `getDataAsFloatArray` - `testCopyDataIntoFloat_unsupportedDtype` — `IllegalStateException` from base default for non-float dtypes This PR was authored with Claude. cc @kirklandsign @cbilgin --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 20415bf commit 28f38d6

2 files changed

Lines changed: 520 additions & 0 deletions

File tree

  • extension/android/executorch_android/src

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,122 @@ public float[] getDataAsFloatArray() {
554554
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
555555
}
556556

557+
/**
558+
* Copies the tensor's data into a caller-provided {@link FloatBuffer}, avoiding the per-call
559+
* {@code float[]} allocation that {@link #getDataAsFloatArray()} performs. The destination
560+
* buffer's position is advanced by the number of elements written; its content from the starting
561+
* position must have at least {@link #numel()} elements of remaining capacity.
562+
*
563+
* <p>Useful in steady-state inference loops where the same output tensor shape is read every
564+
* frame: pre-allocate a {@code FloatBuffer} once (e.g. via {@link #allocateFloatBuffer(int)}) and
565+
* reuse it across calls.
566+
*
567+
* <p>Supported by float32 (zero-copy bulk put) and float16 (per-element half→float widening,
568+
* matching {@link #getDataAsFloatArray()} on that subclass). For raw fp16 bits without widening,
569+
* use {@link #copyDataInto(ShortBuffer)}.
570+
*
571+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
572+
* @throws IllegalStateException if it is called for a tensor type that does not support a float
573+
* view.
574+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
575+
* capacity.
576+
*/
577+
public void copyDataInto(FloatBuffer dst) {
578+
throw new IllegalStateException(
579+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into FloatBuffer.");
580+
}
581+
582+
/**
583+
* Copies the tensor's data into a caller-provided {@link ByteBuffer}, avoiding the per-call
584+
* {@code byte[]} allocation that {@link #getDataAsByteArray()} performs.
585+
*
586+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
587+
* @throws IllegalStateException if it is called for a non-int8 tensor.
588+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
589+
* capacity.
590+
*/
591+
public void copyDataInto(ByteBuffer dst) {
592+
throw new IllegalStateException(
593+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into ByteBuffer.");
594+
}
595+
596+
/**
597+
* Copies the tensor's data into a caller-provided {@link ByteBuffer}, avoiding the per-call
598+
* {@code byte[]} allocation that {@link #getDataAsUnsignedByteArray()} performs. The bytes carry
599+
* the raw uint8 bits — Java's signed {@code byte} representation, with values {@code >127}
600+
* appearing negative; reinterpret with {@code & 0xFF} when reading.
601+
*
602+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
603+
* @throws IllegalStateException if it is called for a non-uint8 tensor.
604+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
605+
* capacity.
606+
*/
607+
public void copyDataIntoUnsigned(ByteBuffer dst) {
608+
throw new IllegalStateException(
609+
"Tensor of type "
610+
+ getClass().getSimpleName()
611+
+ " cannot copy data into ByteBuffer (unsigned).");
612+
}
613+
614+
/**
615+
* Copies the tensor's data into a caller-provided {@link IntBuffer}, avoiding the per-call {@code
616+
* int[]} allocation that {@link #getDataAsIntArray()} performs.
617+
*
618+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
619+
* @throws IllegalStateException if it is called for a non-int32 tensor.
620+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
621+
* capacity.
622+
*/
623+
public void copyDataInto(IntBuffer dst) {
624+
throw new IllegalStateException(
625+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into IntBuffer.");
626+
}
627+
628+
/**
629+
* Copies the tensor's data into a caller-provided {@link LongBuffer}, avoiding the per-call
630+
* {@code long[]} allocation that {@link #getDataAsLongArray()} performs.
631+
*
632+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
633+
* @throws IllegalStateException if it is called for a non-int64 tensor.
634+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
635+
* capacity.
636+
*/
637+
public void copyDataInto(LongBuffer dst) {
638+
throw new IllegalStateException(
639+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into LongBuffer.");
640+
}
641+
642+
/**
643+
* Copies the tensor's data into a caller-provided {@link DoubleBuffer}, avoiding the per-call
644+
* {@code double[]} allocation that {@link #getDataAsDoubleArray()} performs.
645+
*
646+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
647+
* @throws IllegalStateException if it is called for a non-float64 tensor.
648+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
649+
* capacity.
650+
*/
651+
public void copyDataInto(DoubleBuffer dst) {
652+
throw new IllegalStateException(
653+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into DoubleBuffer.");
654+
}
655+
656+
/**
657+
* Copies the tensor's data into a caller-provided {@link ShortBuffer}, avoiding the per-call
658+
* {@code short[]} allocation that {@link #getDataAsShortArray()} performs. For float16 tensors
659+
* this writes the raw 16-bit half-precision bits with no widening; use {@link
660+
* #copyDataInto(FloatBuffer)} if you want the values widened to fp32.
661+
*
662+
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
663+
* @throws IllegalStateException if it is called for a tensor type whose backing storage is not a
664+
* {@code ShortBuffer}.
665+
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
666+
* capacity.
667+
*/
668+
public void copyDataInto(ShortBuffer dst) {
669+
throw new IllegalStateException(
670+
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into ShortBuffer.");
671+
}
672+
557673
/**
558674
* @return a Java long array that contains the tensor data. This may be a copy or reference.
559675
* @throws IllegalStateException if it is called for a non-int64 tensor.
@@ -604,6 +720,12 @@ public byte[] getDataAsUnsignedByteArray() {
604720
return arr;
605721
}
606722

723+
@Override
724+
public void copyDataIntoUnsigned(ByteBuffer dst) {
725+
data.rewind();
726+
dst.put(data);
727+
}
728+
607729
@Override
608730
public String toString() {
609731
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
@@ -636,6 +758,12 @@ public byte[] getDataAsByteArray() {
636758
return arr;
637759
}
638760

761+
@Override
762+
public void copyDataInto(ByteBuffer dst) {
763+
data.rewind();
764+
dst.put(data);
765+
}
766+
639767
@Override
640768
public String toString() {
641769
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
@@ -668,6 +796,12 @@ public int[] getDataAsIntArray() {
668796
return arr;
669797
}
670798

799+
@Override
800+
public void copyDataInto(IntBuffer dst) {
801+
data.rewind();
802+
dst.put(data);
803+
}
804+
671805
@Override
672806
public String toString() {
673807
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
@@ -690,6 +824,12 @@ public float[] getDataAsFloatArray() {
690824
return arr;
691825
}
692826

827+
@Override
828+
public void copyDataInto(FloatBuffer dst) {
829+
data.rewind();
830+
dst.put(data);
831+
}
832+
693833
@Override
694834
public DType dtype() {
695835
return DType.FLOAT;
@@ -732,6 +872,12 @@ public short[] getDataAsShortArray() {
732872
return arr;
733873
}
734874

875+
@Override
876+
public void copyDataInto(ShortBuffer dst) {
877+
data.rewind();
878+
dst.put(data);
879+
}
880+
735881
@Override
736882
public float[] getDataAsFloatArray() {
737883
data.rewind();
@@ -743,6 +889,21 @@ public float[] getDataAsFloatArray() {
743889
return arr;
744890
}
745891

892+
@Override
893+
public void copyDataInto(FloatBuffer dst) {
894+
data.rewind();
895+
int remaining = data.remaining();
896+
// Match the all-or-nothing semantics of bulk FloatBuffer.put(FloatBuffer):
897+
// verify capacity up front so an undersized destination throws before any
898+
// partial widening is observed in dst.
899+
if (dst.remaining() < remaining) {
900+
throw new java.nio.BufferOverflowException();
901+
}
902+
for (int i = 0; i < remaining; i++) {
903+
dst.put(halfBitsToFloat(data.get()));
904+
}
905+
}
906+
746907
@Override
747908
public String toString() {
748909
return String.format("Tensor(%s, dtype=torch.float16)", Arrays.toString(shape));
@@ -800,6 +961,12 @@ public long[] getDataAsLongArray() {
800961
return arr;
801962
}
802963

964+
@Override
965+
public void copyDataInto(LongBuffer dst) {
966+
data.rewind();
967+
dst.put(data);
968+
}
969+
803970
@Override
804971
public String toString() {
805972
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
@@ -832,6 +999,12 @@ public double[] getDataAsDoubleArray() {
832999
return arr;
8331000
}
8341001

1002+
@Override
1003+
public void copyDataInto(DoubleBuffer dst) {
1004+
data.rewind();
1005+
dst.put(data);
1006+
}
1007+
8351008
@Override
8361009
public String toString() {
8371010
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));

0 commit comments

Comments
 (0)