File tree Expand file tree Collapse file tree 3 files changed +13
-5
lines changed
infiniop/ops/embedding/nvidia Expand file tree Collapse file tree 3 files changed +13
-5
lines changed Original file line number Diff line number Diff line change @@ -43,9 +43,16 @@ Embedding::Embedding(size_t num_embeddings,
4343}
4444
4545Tensor Embedding::forward (const Tensor &indices) const {
46+ // Ensure indices are on the same device as weight
47+ // This avoids synchronous memcpy in ops layer which would hurt performance
48+ Tensor indices_on_device = indices;
49+ if (indices->device () != device_) {
50+ indices_on_device = indices->to (device_);
51+ }
52+
4653 // Ensure indices are contiguous for efficient access
4754 // op::embedding now supports device-side input for graph recording
48- Tensor indices_contiguous = indices ->is_contiguous () ? indices : indices ->contiguous ();
55+ Tensor indices_contiguous = indices_on_device ->is_contiguous () ? indices_on_device : indices_on_device ->contiguous ();
4956
5057 // Use op::embedding which now supports device-side input and batch dimension
5158 // This enables full graph recording support without synchronization
Original file line number Diff line number Diff line change @@ -12,8 +12,10 @@ common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
1212}
1313
1414void Embedding::execute (Tensor out, Tensor input, Tensor weight) {
15- // Check that output and weight are on the same device
16- INFINICORE_ASSERT_TENSORS_SAME_DEVICE (out, weight);
15+ // Check that all tensors are on the same device
16+ // This is critical: if input is on CPU while out/weight are on GPU,
17+ // passing CPU pointer to CUDA kernel will cause memory access errors
18+ INFINICORE_ASSERT_TENSORS_SAME_DEVICE (out, input, weight);
1719
1820 // Set device context
1921 infinicore::context::setDevice (out->device ());
Original file line number Diff line number Diff line change @@ -23,7 +23,6 @@ infiniStatus_t Descriptor::create(
2323 infiniopTensorDescriptor_t input_desc,
2424 infiniopTensorDescriptor_t weight_desc) {
2525
26- auto handle_nvidia = reinterpret_cast <device::nvidia::Handle *>(handle);
2726 auto input_shape = input_desc->shape ();
2827 auto weight_shape = weight_desc->shape ();
2928
@@ -63,7 +62,7 @@ infiniStatus_t Descriptor::create(
6362 vocab_size,
6463 input_dtype,
6564 weight_dtype,
66- new Opaque{handle_nvidia ->internal ()},
65+ new Opaque{reinterpret_cast <device::nvidia::Handle *>(handle) ->internal ()},
6766 handle->device ,
6867 handle->device_id );
6968
You can’t perform that action at this time.
0 commit comments