1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
@@ -3,6 +3,7 @@
#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/rearrange.hpp"
7 changes: 7 additions & 0 deletions include/infinicore/ops/embedding.hpp
@@ -4,6 +4,13 @@

namespace infinicore::op {

class Embedding {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor out, Tensor input, Tensor weight);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor embedding(Tensor input, Tensor weight);
void embedding_(Tensor out, Tensor input, Tensor weight);
} // namespace infinicore::op
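A minimal caller sketch for the out-of-place entry point declared above, assuming token_ids is an integer index tensor and weight is a (num_embeddings, embedding_dim) matrix already on the target device; the helper name is illustrative only:

    #include "infinicore/ops/embedding.hpp"

    using namespace infinicore;

    // Sketch only: returns a tensor of shape token_ids.shape() + [embedding_dim].
    Tensor lookup_rows(Tensor token_ids, Tensor weight) {
        return op::embedding(token_ids, weight);
    }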
1 change: 1 addition & 0 deletions include/infiniop.h
@@ -8,6 +8,7 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
26 changes: 26 additions & 0 deletions include/infiniop/ops/embedding.h
@@ -0,0 +1,26 @@
#ifndef __INFINIOP_EMBEDDING_API_H__
#define __INFINIOP_EMBEDDING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;

__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
infiniopHandle_t handle,
infiniopEmbeddingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc);

__C __export infiniStatus_t infiniopEmbedding(
infiniopEmbeddingDescriptor_t desc,
void *output,
const void *input,
const void *weight,
void *stream);

__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
infiniopEmbeddingDescriptor_t desc);

#endif
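A hedged lifecycle sketch for this C API, following the create/run/destroy pattern of other infiniop operators; the handle, tensor descriptors, device buffers, and stream are assumed to come from the surrounding runtime, and status checking is omitted:

    // Sketch only: handle, the *_desc descriptors, the d_* device buffers
    // and stream are assumptions created elsewhere.
    infiniopEmbeddingDescriptor_t desc = nullptr;
    infiniopCreateEmbeddingDescriptor(handle, &desc,
                                      output_desc, input_desc, weight_desc);
    infiniopEmbedding(desc, d_output, d_input, d_weight, stream); // row gather, stream-ordered
    infiniopDestroyEmbeddingDescriptor(desc);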

5 changes: 2 additions & 3 deletions python/infinicore/nn/functional/embedding.py
@@ -22,9 +22,8 @@ def embedding(
and (sparse is False)
), "Unsupported parameters."

assert "cpu" == input.device.type, (
"The device of 'input' variable must be on the CPU."
)
# Note: embedding now supports device-side input for graph recording
# The C++ implementation handles both CPU and device-side inputs

if out is None:
return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
82 changes: 11 additions & 71 deletions src/infinicore/nn/embedding.cc
@@ -43,80 +43,20 @@ Embedding::Embedding(size_t num_embeddings,
}

Tensor Embedding::forward(const Tensor &indices) const {
// Get the shape of indices
auto indices_shape = indices->shape();

// Output shape: indices_shape + [embedding_dim]
std::vector<size_t> output_shape = indices_shape;
output_shape.push_back(embedding_dim_);

// Create output tensor on the same device as weight
auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());

// Flatten indices for sequential row copies
auto cpu_device = Device(Device::Type::CPU, 0);
auto indices_cpu = indices->to(cpu_device)->contiguous();

// Calculate total number of lookups
size_t num_lookups = 1;
for (auto dim : indices_shape) {
num_lookups *= dim;
// Ensure indices are on the same device as weight
// This avoids a synchronous memcpy in the ops layer, which would hurt performance
Tensor indices_on_device = indices;
if (indices->device() != device_) {
indices_on_device = indices->to(device_);
}

const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());

// Source and destination base pointers
auto *weight_base = weight_->data();
auto *out_base = out->data();

// Helper lambda to read index based on dtype with bounds checking
auto read_index = [&](size_t i) -> int64_t {
auto dtype = indices_cpu->dtype();
if (dtype == DataType::I32) {
const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
return static_cast<int64_t>(data[i]);
} else if (dtype == DataType::I64) {
const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
return data[i];
} else if (dtype == DataType::U32) {
const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
return static_cast<int64_t>(data[i]);
} else if (dtype == DataType::U64) {
const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
uint64_t val = data[i];
// Check if value can fit in int64_t
if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
}
return static_cast<int64_t>(val);
} else {
throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(dtype)));
}
};

if (weight_->device().getType() == Device::Type::CPU) {
// CPU path: memcpy row by row
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = read_index(i);
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
} else {
// Device path: use stream-ordered D2D copies
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = read_index(i);
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
}
// Ensure indices are contiguous for efficient access
// op::embedding now supports device-side input for graph recording
Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous();

return out;
// Use op::embedding which now supports device-side input and batch dimension
// This enables full graph recording support without synchronization
return op::embedding(indices_contiguous, weight_);
}
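
For context, a rough sketch of how the reworked forward path is used: indices may now live on the accelerator, and the whole lookup stays stream-ordered, so it can be captured during graph recording. The helper name below is an assumption for illustration:

    // Sketch only: Embedding is the module defined in this file; token_ids
    // is an integer tensor that may live on the accelerator.
    Tensor embed_tokens(const Embedding &embed, Tensor token_ids) {
        // forward() moves indices to the weight's device if needed, makes them
        // contiguous, and dispatches to op::embedding; no host-side loop, no sync.
        return embed.forward(token_ids);
    }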

std::string Embedding::extra_repr() const {
84 changes: 21 additions & 63 deletions src/infinicore/ops/embedding/embedding.cc
@@ -1,15 +1,34 @@
#include "infinicore/ops/embedding.hpp"
#include "../../utils.hpp"
#include "infinicore/context/context.hpp"
#include <cstring>
#include <stdexcept>

namespace infinicore::op {

common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
static common::OpDispatcher<Embedding::schema> dispatcher_;
return dispatcher_;
}

void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
// Check that all tensors are on the same device
// This is critical: if input is on CPU while out/weight are on GPU,
// passing a CPU pointer to a CUDA kernel will cause memory access errors
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);

// Set device context
infinicore::context::setDevice(out->device());

// Use dispatcher to lookup kernel (infiniop implementation)
dispatcher().lookup(out->device().getType())(out, input, weight);
}

Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
) {
auto input_shape = input->shape();
auto weight_shape = weight->shape();
// auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1];

// Allocate memory for the output tensor
@@ -22,68 +41,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
}

void embedding_(Tensor out, Tensor input, Tensor weight) {
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
assert(infinicore::Device::Type::CPU == input->device().getType());

auto input_shape = input->shape();
auto weight_shape = weight->shape();
auto embedding_dim = weight_shape[1];

// Calculate the number of tokens
Size counts = 1;
for (auto &v : input_shape) {
counts *= v;
}

// the bytes of one token
const Size bytes = dsize(weight->dtype()) * embedding_dim;
auto *weight_ptr = weight->data();
auto *out_ptr = out->data();

// copies
if (weight->device().getType() == Device::Type::CPU) {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());

for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}

} else {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}
}
Embedding::execute(out, input, weight);
}

} // namespace infinicore::op
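
A short usage sketch for the in-place entry point, which now routes through Embedding::execute and the per-device dispatcher; the function name and shape handling below are assumptions that mirror what op::embedding does internally:

    // Sketch only: token_ids and weight are assumed to share a device,
    // and the infinicore types are assumed to be in scope.
    Tensor embed_into_preallocated(Tensor token_ids, Tensor weight) {
        auto out_shape = token_ids->shape();        // indices shape...
        out_shape.push_back(weight->shape()[1]);    // ...plus embedding_dim
        auto out = Tensor::empty(out_shape, weight->dtype(), weight->device());
        op::embedding_(out, token_ids, weight);     // dispatches via Embedding::execute
        return out;
    }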
49 changes: 49 additions & 0 deletions src/infinicore/ops/embedding/embedding_infiniop.cc
@@ -0,0 +1,49 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/embedding.hpp"
#include <infiniop.h>

namespace infinicore::op::embedding_impl::infiniop {

thread_local common::OpCache<size_t, infiniopEmbeddingDescriptor_t> caches(
100, // capacity
[](infiniopEmbeddingDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor out, Tensor input, Tensor weight) {
size_t seed = hash_combine(out, input, weight);

auto device = context::getDevice();
auto &cache = caches.getCache(device);

auto desc_opt = cache.get(seed);
infiniopEmbeddingDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), input->desc(), weight->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

INFINICORE_CHECK_ERROR(infiniopEmbedding(
desc,
out->data(),
input->data(),
weight->data(),
context::getStream()));
}

static bool registered = []() {
Embedding::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::embedding_impl::infiniop