Skip to content

Commit 09a7cbe

Browse files
Add Span<const Kernel> overload to get_op_function_from_registry (pytorch#19519)
Differential Revision: D98079809 Pull Request resolved: pytorch#19519
1 parent 174d3ad commit 09a7cbe

3 files changed

Lines changed: 77 additions & 7 deletions

File tree

runtime/kernel/operator_registry.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ bool registry_has_op_function(
249249

250250
Result<OpFunction> get_op_function_from_registry(
251251
const char* name,
252-
Span<const TensorMeta> meta_list) {
252+
Span<const TensorMeta> meta_list,
253+
Span<const Kernel> kernel_list) {
253254
std::array<char, internal::kKernelKeyBufSize> key_string;
254255
Error err = internal::make_kernel_key_string(
255256
meta_list, key_string.data(), key_string.size());
@@ -260,24 +261,31 @@ Result<OpFunction> get_op_function_from_registry(
260261
KernelKey kernel_key = KernelKey(key_string.data());
261262

262263
int32_t fallback_idx = -1;
263-
for (size_t idx = 0; idx < num_registered_kernels; idx++) {
264-
if (strcmp(registered_kernels[idx].name_, name) == 0) {
265-
if (registered_kernels[idx].kernel_key_ == kernel_key) {
266-
return registered_kernels[idx].op_;
264+
for (size_t idx = 0; idx < kernel_list.size(); idx++) {
265+
if (strcmp(kernel_list[idx].name_, name) == 0) {
266+
if (kernel_list[idx].kernel_key_ == kernel_key) {
267+
return kernel_list[idx].op_;
267268
}
268-
if (registered_kernels[idx].kernel_key_.is_fallback()) {
269+
if (kernel_list[idx].kernel_key_.is_fallback()) {
269270
fallback_idx = idx;
270271
}
271272
}
272273
}
273274
if (fallback_idx != -1) {
274-
return registered_kernels[fallback_idx].op_;
275+
return kernel_list[fallback_idx].op_;
275276
}
276277
ET_LOG(Error, "kernel '%s' not found.", name);
277278
ET_LOG_TENSOR_META(meta_list);
278279
return Error::OperatorMissing;
279280
}
280281

282+
Result<OpFunction> get_op_function_from_registry(
283+
const char* name,
284+
Span<const TensorMeta> meta_list) {
285+
return get_op_function_from_registry(
286+
name, meta_list, get_registered_kernels());
287+
}
288+
281289
Span<const Kernel> get_registered_kernels() {
282290
return {registered_kernels, num_registered_kernels};
283291
}

runtime/kernel/operator_registry.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ ::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
233233
const char* name,
234234
Span<const TensorMeta> meta_list = {});
235235

236+
/**
237+
* Returns the operator with a given name and TensorMeta list from the provided
238+
* kernel list instead of the global registry.
239+
*/
240+
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
241+
const char* name,
242+
Span<const TensorMeta> meta_list,
243+
Span<const Kernel> kernel_list);
244+
236245
/**
237246
* Returns all registered kernels.
238247
*/

runtime/kernel/test/operator_registry_test.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,59 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
387387
ASSERT_EQ(val_2, 50);
388388
}
389389

390+
TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
391+
std::array<char, kKernelKeyBufSize> buf{};
392+
Error err = make_kernel_key(
393+
{{ScalarType::Long, {0, 1, 2, 3}}}, buf.data(), buf.size());
394+
ASSERT_EQ(err, Error::Ok);
395+
KernelKey long_key = KernelKey(buf.data());
396+
397+
std::array<Kernel, 2> kernels = {
398+
Kernel(
399+
"test::provided_kernel_list",
400+
KernelKey{},
401+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
402+
(void)context;
403+
*(stack[0]) = Scalar(50);
404+
}),
405+
Kernel(
406+
"test::provided_kernel_list",
407+
long_key,
408+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
409+
(void)context;
410+
*(stack[0]) = Scalar(100);
411+
}),
412+
};
413+
Span<const Kernel> kernels_span(kernels.data(), kernels.size());
414+
415+
std::array<Tensor::DimOrderType, 4> dims = {0, 1, 2, 3};
416+
auto dim_order_type = Span<Tensor::DimOrderType>(dims.data(), dims.size());
417+
std::array<TensorMeta, 1> long_meta = {
418+
TensorMeta(ScalarType::Long, dim_order_type)};
419+
Span<const TensorMeta> long_kernel_key(long_meta.data(), long_meta.size());
420+
421+
auto run_kernel = [](OpFunction func) {
422+
EValue value = Scalar(0);
423+
std::array<EValue*, 1> stack = {&value};
424+
KernelRuntimeContext context{};
425+
func(context, Span<EValue*>(stack.data(), stack.size()));
426+
return value.toScalar().to<int64_t>();
427+
};
428+
429+
Result<OpFunction> specialized_func = get_op_function_from_registry(
430+
"test::provided_kernel_list", long_kernel_key, kernels_span);
431+
ASSERT_EQ(specialized_func.error(), Error::Ok);
432+
EXPECT_EQ(run_kernel(*specialized_func), 100);
433+
434+
std::array<TensorMeta, 1> float_meta = {
435+
TensorMeta(ScalarType::Float, dim_order_type)};
436+
Span<const TensorMeta> float_kernel_key(float_meta.data(), float_meta.size());
437+
Result<OpFunction> fallback_func = get_op_function_from_registry(
438+
"test::provided_kernel_list", float_kernel_key, kernels_span);
439+
ASSERT_EQ(fallback_func.error(), Error::Ok);
440+
EXPECT_EQ(run_kernel(*fallback_func), 50);
441+
}
442+
390443
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
391444
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
392445
Error err = make_kernel_key(

0 commit comments

Comments
 (0)