Skip to content

Commit 1251fa7

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Modernize embedding
Modernize embedding to support ANY_STORAGE. Add buffer and texture shader variants using BufferMetadata/TextureMetadata with indexing.glslh. Unify new dispatch path with add_storage_type_suffix and graph.meta_ubo(). Legacy channels-packed texture path retained for backward compatibility. Pull Request resolved: pytorch#18057 ghstack-source-id: 353546689 @exported-using-ghexport Differential Revision: [D95970161](https://our.internmc.facebook.com/intern/diff/D95970161/)
1 parent 413c62e commit 1251fa7

4 files changed

Lines changed: 16 additions & 13 deletions

File tree

backends/vulkan/op_registry.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1390,11 +1390,20 @@ def register_repeat():
13901390

13911391
@update_features(exir_ops.edge.aten.embedding.default)
13921392
def register_embedding():
1393+
def check_embedding_weight_size(node: torch.fx.Node) -> bool:
1394+
weight = node.args[0]
1395+
if isinstance(weight, torch.fx.Node) and utils.is_tensor_node(weight):
1396+
numel = weight.meta["val"].numel()
1397+
if numel > utils.DEFAULT_BUFFER_LIMIT:
1398+
return False
1399+
return True
1400+
13931401
return OpFeatures(
1394-
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
1402+
inputs_storage=utils.ANY_STORAGE,
13951403
inputs_dtypes=[utils.FP_T, utils.INT_T],
13961404
supports_prepacking=True,
13971405
supports_resize=True,
1406+
are_node_inputs_supported_fn=check_embedding_weight_size,
13981407
)
13991408

14001409

backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ layout(std430) buffer;
1616

1717
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
1818
${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
19-
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
19+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture2d")}
2020
${layout_declare_ubo(B, "ivec4", "sizes")}
2121

2222
#include "indexing_utils.h"
@@ -30,9 +30,6 @@ const lowp int packed_dim = unhash_packed_dim(out_layout);
3030
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
3131
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
3232

33-
${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")}
34-
const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout);
35-
3633
void main() {
3734
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
3835
const ivec4 out_tidx = lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim);
@@ -48,8 +45,8 @@ void main() {
4845
const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];
4946

5047
// Read weight tensor for embedding, it is height-packed.
51-
const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0);
52-
out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4];
48+
const ivec2 weight_pos = ivec2(out_tidx.x, in_texel_elem / 4);
49+
out_texel[i] = texelFetch(t_weight, weight_pos, 0)[in_texel_elem % 4];
5350
}
5451

5552
write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,7 @@ void add_embedding_legacy_node(
111111
// Push Constants
112112
{},
113113
// Specialization Constants
114-
{graph.hashed_layout_of(out),
115-
graph.hashed_layout_of(in),
116-
graph.hashed_layout_of(weight)},
114+
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
117115
// Resize Args
118116
{},
119117
// Resizing Logic

backends/vulkan/test/op_tests/cases.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,14 +1167,13 @@ def get_embedding_inputs():
11671167
Test(weight=[10, 9], indices=[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]),
11681168
]
11691169

1170-
# Channels packed test cases currently fail on Mac, so they are not included.
1171-
# However the test case definition is kept for later debugging.
11721170
test_suite_cpack = VkTestSuite(
11731171
[tuple(tc) + (-1, "false", "false") for tc in test_cases]
11741172
)
11751173

11761174
test_suite_cpack.dtypes = ["at::kFloat"]
11771175
test_suite_cpack.layouts = ["utils::kChannelsPacked"]
1176+
test_suite_cpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
11781177
test_suite_cpack.test_name_suffix = "cpacked"
11791178

11801179
test_suite_wpack = VkTestSuite(
@@ -1186,7 +1185,7 @@ def get_embedding_inputs():
11861185
test_suite_wpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
11871186
test_suite_wpack.test_name_suffix = "wpacked"
11881187

1189-
return test_suite_wpack
1188+
return [test_suite_cpack, test_suite_wpack]
11901189

11911190

11921191
@register_test_suite("aten.gather.default")

0 commit comments

Comments
 (0)