Skip to content

Commit 5dd2a26

Browse files
committed
Issue/846 - Ensure embedding tensors are on the same device.
1 parent bd25dc2 commit 5dd2a26

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

src/infinicore/nn/embedding.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,16 @@ Embedding::Embedding(size_t num_embeddings,
4343
}
4444

4545
Tensor Embedding::forward(const Tensor &indices) const {
46+
// Ensure indices are on the same device as weight
47+
// This avoids synchronous memcpy in ops layer which would hurt performance
48+
Tensor indices_on_device = indices;
49+
if (indices->device() != device_) {
50+
indices_on_device = indices->to(device_);
51+
}
52+
4653
// Ensure indices are contiguous for efficient access
4754
// op::embedding now supports device-side input for graph recording
48-
Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous();
55+
Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous();
4956

5057
// Use op::embedding which now supports device-side input and batch dimension
5158
// This enables full graph recording support without synchronization

src/infinicore/ops/embedding/embedding.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
1212
}
1313

1414
void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
15-
// Check that output and weight are on the same device
16-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, weight);
15+
// Check that all tensors are on the same device
16+
// This is critical: if input is on CPU while out/weight are on GPU,
17+
// passing CPU pointer to CUDA kernel will cause memory access errors
18+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
1719

1820
// Set device context
1921
infinicore::context::setDevice(out->device());

src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ infiniStatus_t Descriptor::create(
2323
infiniopTensorDescriptor_t input_desc,
2424
infiniopTensorDescriptor_t weight_desc) {
2525

26-
auto handle_nvidia = reinterpret_cast<device::nvidia::Handle *>(handle);
2726
auto input_shape = input_desc->shape();
2827
auto weight_shape = weight_desc->shape();
2928

@@ -63,7 +62,7 @@ infiniStatus_t Descriptor::create(
6362
vocab_size,
6463
input_dtype,
6564
weight_dtype,
66-
new Opaque{handle_nvidia->internal()},
65+
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
6766
handle->device,
6867
handle->device_id);
6968

0 commit comments

Comments
 (0)