Skip to content

Commit 6ba868e

Browse files
authored
Thread kernel_registry through Module::load_method (pytorch#19641)
Differential Revision: D104433196 Pull Request resolved: pytorch#19641
1 parent a76d9cd commit 6ba868e

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

extension/module/module.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace extension {
2020
namespace ET_MODULE_NAMESPACE {
2121

2222
using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap;
23+
using ET_RUNTIME_NAMESPACE::Kernel;
2324
using ET_RUNTIME_NAMESPACE::MethodMeta;
2425
using ET_RUNTIME_NAMESPACE::Program;
2526

@@ -406,7 +407,8 @@ runtime::Error Module::load_method(
406407
const std::string& method_name,
407408
runtime::HierarchicalAllocator* planned_memory,
408409
torch::executor::EventTracer* event_tracer,
409-
const LoadBackendOptionsMap* backend_options) {
410+
const LoadBackendOptionsMap* backend_options,
411+
std::vector<Kernel> kernel_registry) {
410412
if (!is_method_loaded(method_name)) {
411413
ET_CHECK_OK_OR_RETURN_ERROR(load());
412414

@@ -446,12 +448,16 @@ runtime::Error Module::load_method(
446448

447449
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
448450
memory_allocator_.get(), planned_memory, temp_allocator_.get());
451+
method_holder.kernel_registry = std::move(kernel_registry);
449452
auto res_method = program_->load_method(
450453
method_name.c_str(),
451454
method_holder.memory_manager.get(),
452455
event_tracer ? event_tracer : this->event_tracer(),
453456
merged_data_map_.get(),
454-
effective_backend_options);
457+
effective_backend_options,
458+
runtime::Span<const Kernel>(
459+
method_holder.kernel_registry.data(),
460+
method_holder.kernel_registry.size()));
455461
if (!res_method.ok()) {
456462
return res_method.error();
457463
}

extension/module/module.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
namespace executorch {
2828
namespace extension {
2929

30+
using ET_RUNTIME_NAMESPACE::Kernel;
3031
using ET_RUNTIME_NAMESPACE::Method;
3132
using ET_RUNTIME_NAMESPACE::MethodMeta;
3233
using ET_RUNTIME_NAMESPACE::NamedDataMap;
@@ -281,7 +282,8 @@ class Module {
281282
const std::string& method_name,
282283
runtime::HierarchicalAllocator* planned_memory = nullptr,
283284
torch::executor::EventTracer* event_tracer = nullptr,
284-
const LoadBackendOptionsMap* backend_options = nullptr);
285+
const LoadBackendOptionsMap* backend_options = nullptr,
286+
std::vector<Kernel> kernel_registry = {});
285287

286288
ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
287289
const std::string& method_name,
@@ -329,9 +331,14 @@ class Module {
329331
ET_NODISCARD inline runtime::Error load_forward(
330332
runtime::HierarchicalAllocator* planned_memory = nullptr,
331333
torch::executor::EventTracer* event_tracer = nullptr,
332-
const LoadBackendOptionsMap* backend_options = nullptr) {
334+
const LoadBackendOptionsMap* backend_options = nullptr,
335+
std::vector<Kernel> kernel_registry = {}) {
333336
return load_method(
334-
"forward", planned_memory, event_tracer, backend_options);
337+
"forward",
338+
planned_memory,
339+
event_tracer,
340+
backend_options,
341+
std::move(kernel_registry));
335342
}
336343

337344
ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
@@ -724,6 +731,7 @@ class Module {
724731
std::unique_ptr<PlannedMemory> planned_memory;
725732
std::unique_ptr<runtime::MemoryManager> memory_manager;
726733
std::unique_ptr<Method> method;
734+
std::vector<Kernel> kernel_registry;
727735
};
728736

729737
std::string file_path_;

0 commit comments

Comments
 (0)