Skip to content

Commit 2349049

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Modernize expand_copy
Modernize expand_copy to support ANY_STORAGE. Add buffer shader variant using BufferMetadata with indexing.glslh. Unify dispatch with add_storage_type_suffix and DynamicDispatchNode. Add resize function and symint support for dynamic target sizes. Pull Request resolved: pytorch#18053 ghstack-source-id: 353546690 @exported-using-ghexport Differential Revision: [D95970162](https://our.internmc.facebook.com/intern/diff/D95970162/)
1 parent 2bb94b1 commit 2349049

5 files changed

Lines changed: 121 additions & 11 deletions

File tree

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ def register_gather():
11201120
@update_features(exir_ops.edge.aten.expand_copy.default)
11211121
def register_expand_copy():
11221122
return OpFeatures(
1123-
inputs_storage=utils.ANY_BUFFER,
1123+
inputs_storage=utils.ANY_STORAGE,
11241124
inputs_dtypes=utils.FP_INT_BOOL_T,
11251125
supports_resize=False,
11261126
supports_highdim=True,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
${define_required_extensions("texture3d", DTYPE)}
12+
13+
#define PRECISION ${PRECISION}
14+
15+
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
16+
#define T ${texel_load_component_type(DTYPE, "texture3d")}
17+
18+
${define_active_storage_type("texture3d")}
19+
20+
layout(std430) buffer;
21+
22+
#include "indexing.glslh"
23+
24+
${layout_declare_tensor(B, "w", "t_outp", DTYPE, "texture3d")}
25+
${layout_declare_tensor(B, "r", "t_inp", DTYPE, "texture3d")}
26+
27+
${layout_declare_ubo(B, "TextureMetadata", "outp")}
28+
${layout_declare_ubo(B, "TextureMetadata", "inp")}
29+
30+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
31+
32+
${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
33+
const int packed_dim = get_packed_dim(out_layout);
34+
35+
void main() {
36+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
37+
38+
if (out_of_bounds(out_pos, outp)) {
39+
return;
40+
}
41+
42+
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
43+
44+
VEC4_T out_texel = VEC4_T(0);
45+
46+
int limit = min(
47+
4, outp.sizes[packed_dim] - out_tidx.data[packed_dim]);
48+
for (int comp = 0; comp < 4; comp++) {
49+
if (comp >= limit) {
50+
break;
51+
}
52+
53+
// Map output tensor index to input tensor index using modulo
54+
TensorIndex4D inp_tidx;
55+
inp_tidx.data.x = out_tidx.data.x % inp.sizes.x;
56+
inp_tidx.data.y = out_tidx.data.y % inp.sizes.y;
57+
inp_tidx.data.z = out_tidx.data.z % inp.sizes.z;
58+
inp_tidx.data.w = out_tidx.data.w % inp.sizes.w;
59+
60+
TextureElementIndex inp_elem =
61+
tensor4d_idx_to_texture_element_idx_simple(inp, inp_tidx);
62+
63+
VEC4_T inp_texel = texelFetch(t_inp, inp_elem.pos, 0);
64+
out_texel[comp] = inp_texel[inp_elem.comp];
65+
66+
out_tidx.data[packed_dim]++;
67+
}
68+
69+
imageStore(t_outp, out_pos, out_texel);
70+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
expand_texture:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
generate_variant_forall:
5+
DTYPE:
6+
- VALUE: half
7+
- VALUE: float
8+
- VALUE: int32
9+
- VALUE: uint8
10+
shader_variants:
11+
- NAME: expand_texture3d

backends/vulkan/runtime/graph/ops/impl/Expand.cpp

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,39 @@
1616

1717
namespace vkcompute {
1818

19-
void add_expand_buffer_node(
19+
void resize_expand_node(
20+
ComputeGraph* graph,
21+
const std::vector<ArgGroup>& args,
22+
const std::vector<ValueRef>& extra_args) {
23+
const ValueRef in = args.at(1).refs.at(0);
24+
const ValueRef out = args.at(0).refs.at(0);
25+
const ValueRef size_ref = extra_args.at(0);
26+
27+
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
28+
const std::vector<int64_t> target_sizes =
29+
graph->extract_int_or_symint_list(size_ref);
30+
31+
VK_CHECK_COND(
32+
target_sizes.size() >= in_sizes.size(),
33+
"expand: target sizes must have at least as many dims as input");
34+
VK_CHECK_COND(
35+
!target_sizes.empty(), "expand: target sizes must not be empty");
36+
37+
const size_t dim_offset = target_sizes.size() - in_sizes.size();
38+
std::vector<int64_t> out_sizes(target_sizes.size());
39+
for (size_t i = 0; i < target_sizes.size(); i++) {
40+
if (target_sizes[i] == -1 && i >= dim_offset) {
41+
out_sizes[i] = in_sizes[i - dim_offset];
42+
} else if (target_sizes[i] == -1) {
43+
out_sizes[i] = 1;
44+
} else {
45+
out_sizes[i] = target_sizes[i];
46+
}
47+
}
48+
graph->virtual_resize(out, out_sizes);
49+
}
50+
51+
void add_expand_node(
2052
ComputeGraph& graph,
2153
const ValueRef in,
2254
const ValueRef size,
@@ -27,8 +59,8 @@ void add_expand_buffer_node(
2759
add_dtype_suffix(kernel_name, graph.dtype_of(out));
2860

2961
vkapi::ParamsBindList param_buffers = {
30-
graph.buffer_meta_ubo(out),
31-
graph.buffer_meta_ubo(in),
62+
graph.meta_ubo(out),
63+
graph.meta_ubo(in),
3264
};
3365

3466
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -42,11 +74,11 @@ void add_expand_buffer_node(
4274
// Push Constants
4375
{},
4476
// Specialization Constants
45-
{},
77+
{graph.hashed_layout_of(out)},
4678
// Resize Args
4779
{size},
4880
// Resizing Logic
49-
nullptr));
81+
resize_expand_node));
5082
}
5183

5284
void expand(ComputeGraph& graph, const std::vector<ValueRef>& args) {
@@ -57,11 +89,7 @@ void expand(ComputeGraph& graph, const std::vector<ValueRef>& args) {
5789
(void)implicit;
5890
const ValueRef out = args.at(idx++);
5991

60-
if (graph.is_buffer_storage(out)) {
61-
return add_expand_buffer_node(graph, in, size, out);
62-
}
63-
64-
VK_THROW("Expand operator only supports buffer storage");
92+
add_expand_node(graph, in, size, out);
6593
}
6694

6795
REGISTER_OPERATORS {

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,7 @@ def get_expand_inputs():
20312031
)
20322032
test_suite.storage_types = [
20332033
"utils::kBuffer",
2034+
"utils::kTexture3D",
20342035
]
20352036
test_suite.layouts = [
20362037
"utils::kWidthPacked",

0 commit comments

Comments
 (0)