Skip to content

Commit a8cfe2b

Browse files
Thread method-scoped kernel registry through Program and Method (pytorch#19561)
Differential Revision: D98080033 Pull Request resolved: pytorch#19561
1 parent 125d651 commit a8cfe2b

5 files changed

Lines changed: 96 additions & 9 deletions

File tree

runtime/executor/method.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -802,8 +802,19 @@ Error Method::resolve_operator(
802802
}
803803

804804
// Find a kernel with the matching name and tensor meta.
805-
Result<OpFunction> op_function =
806-
get_op_function_from_registry(operator_name, {meta, count});
805+
// Try method-scoped registry first (if provided), then fall back to global.
806+
auto resolve_op_function = [&]() -> Result<OpFunction> {
807+
if (!kernel_registry_.empty()) {
808+
Result<OpFunction> method_scoped_op_function =
809+
get_op_function_from_registry(
810+
operator_name, {meta, count}, kernel_registry_);
811+
if (method_scoped_op_function.ok()) {
812+
return method_scoped_op_function;
813+
}
814+
}
815+
return get_op_function_from_registry(operator_name, {meta, count});
816+
};
817+
Result<OpFunction> op_function = resolve_op_function();
807818
if (!op_function.ok()) {
808819
ET_LOG(
809820
Error,
@@ -831,7 +842,8 @@ Result<Method> Method::load(
831842
MemoryManager* memory_manager,
832843
EventTracer* event_tracer,
833844
const NamedDataMap* external_data_map,
834-
const LoadBackendOptionsMap* backend_options) {
845+
const LoadBackendOptionsMap* backend_options,
846+
Span<const Kernel> kernel_registry) {
835847
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
836848
if (temp_allocator == nullptr) {
837849
PlatformMemoryAllocator* platform_allocator =
@@ -844,7 +856,8 @@ Result<Method> Method::load(
844856
new (platform_allocator) PlatformMemoryAllocator();
845857
temp_allocator = platform_allocator;
846858
}
847-
Method method(program, memory_manager, event_tracer, temp_allocator);
859+
Method method(
860+
program, memory_manager, event_tracer, temp_allocator, kernel_registry);
848861
ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str());
849862
Error err = method.init(s_plan, external_data_map, backend_options);
850863
if (err != Error::Ok) {

runtime/executor/method.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <executorch/runtime/executor/memory_manager.h>
2424
#include <executorch/runtime/executor/merged_data_map.h>
2525
#include <executorch/runtime/executor/method_meta.h>
26+
#include <executorch/runtime/kernel/operator_registry.h>
2627
#include <executorch/runtime/platform/compiler.h>
2728

2829
// Forward declare flatbuffer types. This is a public header and must not
@@ -82,6 +83,7 @@ class Method final {
8283
merged_data_map_(std::move(rhs.merged_data_map_)),
8384
external_constants_(rhs.external_constants_),
8485
n_external_constants_(rhs.n_external_constants_),
86+
kernel_registry_(rhs.kernel_registry_),
8587
init_state_(rhs.init_state_) {
8688
// Required: clear out fields that the dtor looks at, so that we don't free
8789
// anything twice.
@@ -331,7 +333,8 @@ class Method final {
331333
const Program* program,
332334
MemoryManager* memory_manager,
333335
EventTracer* event_tracer,
334-
MemoryAllocator* temp_allocator)
336+
MemoryAllocator* temp_allocator,
337+
Span<const Kernel> kernel_registry = {})
335338
: step_state_(),
336339
program_(program),
337340
memory_manager_(memory_manager),
@@ -348,6 +351,7 @@ class Method final {
348351
merged_data_map_(nullptr),
349352
external_constants_(nullptr),
350353
n_external_constants_(0),
354+
kernel_registry_(kernel_registry),
351355
init_state_(InitializationState::Uninitialized) {}
352356

353357
/// Static factory used by Program.
@@ -357,7 +361,8 @@ class Method final {
357361
MemoryManager* memory_manager,
358362
EventTracer* event_tracer,
359363
const NamedDataMap* named_data_map,
360-
const LoadBackendOptionsMap* backend_options = nullptr);
364+
const LoadBackendOptionsMap* backend_options = nullptr,
365+
Span<const Kernel> kernel_registry = {});
361366

362367
/**
363368
* Initialize the method from its serialized representation.
@@ -403,6 +408,8 @@ class Method final {
403408
NamedData* external_constants_;
404409
size_t n_external_constants_ = 0;
405410

411+
Span<const Kernel> kernel_registry_;
412+
406413
InitializationState init_state_;
407414

408415
/**

runtime/executor/program.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ Result<Method> Program::load_method(
371371
MemoryManager* memory_manager,
372372
EventTracer* event_tracer,
373373
const NamedDataMap* named_data_map,
374-
const LoadBackendOptionsMap* backend_options) const {
374+
const LoadBackendOptionsMap* backend_options,
375+
Span<const Kernel> kernel_registry) const {
375376
EXECUTORCH_SCOPE_PROF("Program::load_method");
376377
internal::event_tracer_create_event_block(event_tracer, "Default");
377378
internal::EventTracerProfileMethodScope event_tracer_scope =
@@ -394,7 +395,8 @@ Result<Method> Program::load_method(
394395
memory_manager,
395396
event_tracer,
396397
named_data_map,
397-
backend_options);
398+
backend_options,
399+
kernel_registry);
398400
}
399401

400402
Result<MethodMeta> Program::method_meta(const char* method_name) const {

runtime/executor/program.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/runtime/executor/method.h>
2222
#include <executorch/runtime/executor/method_meta.h>
2323
#include <executorch/runtime/executor/pte_data_map.h>
24+
#include <executorch/runtime/kernel/operator_registry.h>
2425
#include <executorch/runtime/platform/compiler.h>
2526

2627
// Forward declare flatbuffer types. This is a public header and must not
@@ -151,7 +152,8 @@ class Program final {
151152
MemoryManager* memory_manager,
152153
EventTracer* event_tracer = nullptr,
153154
const NamedDataMap* named_data_map = nullptr,
154-
const LoadBackendOptionsMap* backend_options = nullptr) const;
155+
const LoadBackendOptionsMap* backend_options = nullptr,
156+
Span<const Kernel> kernel_registry = {}) const;
155157

156158
/**
157159
* Gathers metadata for the named method.

runtime/kernel/test/operator_registry_test.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,69 @@ TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
440440
EXPECT_EQ(run_kernel(*fallback_func), 50);
441441
}
442442

443+
TEST_F(OperatorRegistryTest, ProvidedKernelListMissCanFallBackToGlobal) {
444+
std::array<char, kKernelKeyBufSize> buf{};
445+
Error err = make_kernel_key(
446+
{{ScalarType::Long, {0, 1, 2, 3}}}, buf.data(), buf.size());
447+
ASSERT_EQ(err, Error::Ok);
448+
KernelKey long_key = KernelKey(buf.data());
449+
450+
Kernel global_kernel = Kernel(
451+
"test::provided_kernel_list_global_fallback",
452+
KernelKey{},
453+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
454+
(void)context;
455+
*(stack[0]) = Scalar(50);
456+
});
457+
err = register_kernels({&global_kernel, 1});
458+
ASSERT_EQ(err, Error::Ok);
459+
460+
Kernel scoped_kernel = Kernel(
461+
"test::provided_kernel_list_global_fallback",
462+
long_key,
463+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
464+
(void)context;
465+
*(stack[0]) = Scalar(100);
466+
});
467+
Span<const Kernel> scoped_registry(&scoped_kernel, 1);
468+
469+
std::array<Tensor::DimOrderType, 4> dims = {0, 1, 2, 3};
470+
auto dim_order_type = Span<Tensor::DimOrderType>(dims.data(), dims.size());
471+
std::array<TensorMeta, 1> long_meta = {
472+
TensorMeta(ScalarType::Long, dim_order_type)};
473+
Span<const TensorMeta> long_kernel_key(long_meta.data(), long_meta.size());
474+
475+
std::array<TensorMeta, 1> float_meta = {
476+
TensorMeta(ScalarType::Float, dim_order_type)};
477+
Span<const TensorMeta> float_kernel_key(float_meta.data(), float_meta.size());
478+
479+
auto run_kernel = [](OpFunction func) {
480+
EValue value = Scalar(0);
481+
std::array<EValue*, 1> stack = {&value};
482+
KernelRuntimeContext context{};
483+
func(context, Span<EValue*>(stack.data(), stack.size()));
484+
return value.toScalar().to<int64_t>();
485+
};
486+
487+
Result<OpFunction> scoped_func = get_op_function_from_registry(
488+
"test::provided_kernel_list_global_fallback",
489+
long_kernel_key,
490+
scoped_registry);
491+
ASSERT_EQ(scoped_func.error(), Error::Ok);
492+
EXPECT_EQ(run_kernel(*scoped_func), 100);
493+
494+
Result<OpFunction> scoped_miss = get_op_function_from_registry(
495+
"test::provided_kernel_list_global_fallback",
496+
float_kernel_key,
497+
scoped_registry);
498+
ASSERT_EQ(scoped_miss.error(), Error::OperatorMissing);
499+
500+
Result<OpFunction> global_func = get_op_function_from_registry(
501+
"test::provided_kernel_list_global_fallback", float_kernel_key);
502+
ASSERT_EQ(global_func.error(), Error::Ok);
503+
EXPECT_EQ(run_kernel(*global_func), 50);
504+
}
505+
443506
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
444507
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
445508
Error err = make_kernel_key(

0 commit comments

Comments
 (0)