@@ -440,6 +440,69 @@ TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
440440 EXPECT_EQ (run_kernel (*fallback_func), 50 );
441441}
442442
443+ TEST_F (OperatorRegistryTest, ProvidedKernelListMissCanFallBackToGlobal) {
444+ std::array<char , kKernelKeyBufSize > buf{};
445+ Error err = make_kernel_key (
446+ {{ScalarType::Long, {0 , 1 , 2 , 3 }}}, buf.data (), buf.size ());
447+ ASSERT_EQ (err, Error::Ok);
448+ KernelKey long_key = KernelKey (buf.data ());
449+
450+ Kernel global_kernel = Kernel (
451+ " test::provided_kernel_list_global_fallback" ,
452+ KernelKey{},
453+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
454+ (void )context;
455+ *(stack[0 ]) = Scalar (50 );
456+ });
457+ err = register_kernels ({&global_kernel, 1 });
458+ ASSERT_EQ (err, Error::Ok);
459+
460+ Kernel scoped_kernel = Kernel (
461+ " test::provided_kernel_list_global_fallback" ,
462+ long_key,
463+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
464+ (void )context;
465+ *(stack[0 ]) = Scalar (100 );
466+ });
467+ Span<const Kernel> scoped_registry (&scoped_kernel, 1 );
468+
469+ std::array<Tensor::DimOrderType, 4 > dims = {0 , 1 , 2 , 3 };
470+ auto dim_order_type = Span<Tensor::DimOrderType>(dims.data (), dims.size ());
471+ std::array<TensorMeta, 1 > long_meta = {
472+ TensorMeta (ScalarType::Long, dim_order_type)};
473+ Span<const TensorMeta> long_kernel_key (long_meta.data (), long_meta.size ());
474+
475+ std::array<TensorMeta, 1 > float_meta = {
476+ TensorMeta (ScalarType::Float, dim_order_type)};
477+ Span<const TensorMeta> float_kernel_key (float_meta.data (), float_meta.size ());
478+
479+ auto run_kernel = [](OpFunction func) {
480+ EValue value = Scalar (0 );
481+ std::array<EValue*, 1 > stack = {&value};
482+ KernelRuntimeContext context{};
483+ func (context, Span<EValue*>(stack.data (), stack.size ()));
484+ return value.toScalar ().to <int64_t >();
485+ };
486+
487+ Result<OpFunction> scoped_func = get_op_function_from_registry (
488+ " test::provided_kernel_list_global_fallback" ,
489+ long_kernel_key,
490+ scoped_registry);
491+ ASSERT_EQ (scoped_func.error (), Error::Ok);
492+ EXPECT_EQ (run_kernel (*scoped_func), 100 );
493+
494+ Result<OpFunction> scoped_miss = get_op_function_from_registry (
495+ " test::provided_kernel_list_global_fallback" ,
496+ float_kernel_key,
497+ scoped_registry);
498+ ASSERT_EQ (scoped_miss.error (), Error::OperatorMissing);
499+
500+ Result<OpFunction> global_func = get_op_function_from_registry (
501+ " test::provided_kernel_list_global_fallback" , float_kernel_key);
502+ ASSERT_EQ (global_func.error (), Error::Ok);
503+ EXPECT_EQ (run_kernel (*global_func), 50 );
504+ }
505+
443506TEST_F (OperatorRegistryTest, DoubleRegisterKernelsDies) {
444507 std::array<char , kKernelKeyBufSize > buf_long_contiguous;
445508 Error err = make_kernel_key (
0 commit comments