Skip to content

Commit 9aeceee

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
Implement grid_priors op (#4440)
Summary: Pull Request resolved: #4440 Modify the spec of the customized op `grid_priors` to take a tensor as input. Compared to the previous definition, the `height` and `width` arguments are now determined from the input tensor as `height, width = self.shape[-2:]`. The reason for changing the spec is that supporting dynamic shapes requires the input to be a tensor. Implement the customized op `grid_priors`. This op is used to generate mapped x,y points from feature maps at different levels to the original image. Op spec: ``` (Tensor self, int stride, float offset) -> Tensor ``` Example: ``` input_tensor = torch.rand(size = [1, 5, 2, 3]) stride = 8 offset = 0.5 output.shape = [3x2, 2] output = tensor([[ 4., 4.], [12., 4.], [20., 4.], [ 4., 12.], [12., 12.], [20., 12.]]) ``` Only a smoke test is added for now, due to an issue with lowering customized ops to the Vulkan backend. Unit tests and an nn.Module test will be added once customized ops can be lowered from PyTorch to the Vulkan backend. bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: copyrightly Differential Revision: D60203196 fbshipit-source-id: 93e5180e80e07cc0b9acb50890a1187ce0f82951
1 parent 69f3f1c commit 9aeceee

File tree

6 files changed

+210
-8
lines changed

6 files changed

+210
-8
lines changed

backends/vulkan/passes/custom_ops_defs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def conv_with_clamp_impl(
4949

5050

5151
def grid_priors_impl(
52-
height,
53-
width,
52+
x,
5453
stride,
5554
offset,
5655
):
56+
height, width = x.shape[-2:]
5757
shift_x = (torch.arange(0, width) + offset) * stride
5858
shift_y = (torch.arange(0, height) + offset) * stride
5959
shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x)
@@ -64,6 +64,6 @@ def grid_priors_impl(
6464

6565

6666
name = "grid_priors"
67-
lib.define(f"{name}(int height, int width, int stride, float offset) -> Tensor")
68-
lib.impl(name, grid_priors_impl)
67+
lib.define(f"{name}(Tensor self, int stride, float offset) -> Tensor")
68+
lib.impl(name, grid_priors_impl, "CompositeExplicitAutograd")
6969
grid_priors_op = getattr(getattr(torch.ops, namespace), name)

backends/vulkan/passes/test_custom_ops.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,15 @@ class GridPriors(torch.nn.Module):
9797
def __init__(self):
9898
super().__init__()
9999

100-
def forward(self, height, width, stride, offset):
101-
return torch.ops.et_vk.grid_priors(height, width, stride, offset)
100+
def forward(self, x, stride, offset):
101+
return torch.ops.et_vk.grid_priors(x, stride, offset)
102102

103103
model = GridPriors()
104-
sample_input = (2, 3, 4, 0.5)
104+
sample_input = (torch.rand(2, 5, 2, 3), 4, 0.5)
105105
custom_out = model(*sample_input)
106106

107-
def calculate_expected_output(height, width, stride, offset):
107+
def calculate_expected_output(x, stride, offset):
108+
height, width = x.shape[-2:]
108109
shift_x = (torch.arange(0, width) + offset) * stride
109110
shift_y = (torch.arange(0, height) + offset) * stride
110111
shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#version 450 core
2+
3+
#define PRECISION ${PRECISION}
4+
5+
#define VEC4_T ${texel_type(DTYPE)}
6+
7+
layout(std430) buffer;
8+
9+
#include "indexing_utils.h"
10+
11+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
12+
${layout_declare_ubo(1, "ivec4", "in_sizes")}
13+
${layout_declare_ubo(2, "ivec4", "out_sizes")}
14+
${layout_declare_ubo(3, "int", "stride", "float", "offset")}
15+
16+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
17+
18+
layout(constant_id = 3) const int packed_dim = C_DIM;
19+
20+
/*
 * Writes one texel per in-bounds invocation of the output image.
 *
 * pos.y is treated as a flattened index over the input grid, using
 * in_sizes.x as the grid width:
 *   - texels at pos.x == 0 store ((pos.y % width) + offset) * stride
 *   - texels at pos.x == 1 store ((pos.y / width) + offset) * stride
 * Only the first component of the texel carries the value; the remaining
 * components are written as zero.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Bail out before doing any other work for invocations outside the output.
  if (pos_out_of_bounds(pos, out_sizes, packed_dim)) {
    return;
  }

  const int width = in_sizes.x;

  // Start from a fully defined texel so that no invocation can ever store
  // uninitialized data, whatever its x position is.
  VEC4_T outtex = VEC4_T(0, 0, 0, 0);
  if (pos.x == 0) {
    // Position within a row of the input grid.
    const float value = (pos.y % width + offset) * stride;
    outtex = VEC4_T(value, 0, 0, 0);
  } else if (pos.x == 1) {
    // Row of the input grid.
    const float value = (pos.y / width + offset) * stride;
    outtex = VEC4_T(value, 0, 0, 0);
  }

  imageStore(t_out, pos, outtex);
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
grid_priors:
2+
parameter_names_with_default_values:
3+
NDIM: 3
4+
DTYPE: float
5+
PACKING: C_packed
6+
STORAGE: texture3d
7+
generate_variant_forall:
8+
DTYPE:
9+
- VALUE: half
10+
- VALUE: float
11+
shader_variants:
12+
- NAME: grid_priors
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
14+
15+
namespace vkcompute {
16+
17+
// Parameters handed to the grid_priors shader through a params buffer.
// The field order and types follow the shader's uniform declaration: an int
// `stride` followed by a float `offset`.
struct GridPriorsParam final {
  // Scale applied to each (index + offset) term.
  int32_t stride;
  // Shift added to each index before it is scaled by `stride`.
  float offset;
};

// Guard the layout: the struct must remain exactly one int32 followed by one
// float with no padding, otherwise it no longer matches what the shader reads.
static_assert(
    sizeof(GridPriorsParam) == sizeof(int32_t) + sizeof(float),
    "GridPriorsParam must be tightly packed as {int32 stride, float offset}");
21+
22+
// Resize callback for the grid_priors node.
//
// Recomputes the output shape from the current input shape: the output gets
// (height * width) rows and 2 columns, where height and width are the last
// two dimensions of the input tensor referenced by extra_args[0]. The output
// tensor is the first ref of the first argument group.
void resize_grid_priors_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out_tensor = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in_tensor = graph->get_tensor(extra_args[0]);

  const std::vector<int64_t> in_dims = in_tensor->sizes();
  const auto ndim = in_dims.size();

  // at() keeps both lookups bounds-checked.
  const int64_t n_rows = in_dims.at(ndim - 2);
  const int64_t n_cols = in_dims.at(ndim - 1);

  std::vector<int64_t> new_sizes = {n_rows * n_cols, 2};
  out_tensor->virtual_resize(new_sizes);
}
34+
35+
// Appends an execute node that runs the grid_priors shader.
//
// graph:      graph receiving the new node.
// in:         input tensor; only its sizes are passed to the shader.
// stride_ref: scalar value read as int32.
// offset_ref: scalar value read as float.
// out:        output tensor, bound with write access.
//
// The same input ref is also forwarded to resize_grid_priors_node so the
// output can be resized from the input shape.
void add_grid_priors_node(
    ComputeGraph& graph,
    const ValueRef& in,
    const ValueRef& stride_ref,
    const ValueRef& offset_ref,
    const ValueRef& out) {
  vTensorPtr out_tensor = graph.get_tensor(out);
  vTensorPtr in_tensor = graph.get_tensor(in);

  // Braced initialization evaluates left to right: stride first, then offset.
  GridPriorsParam shader_params = {
      graph.extract_scalar<int32_t>(stride_ref),
      graph.extract_scalar<float>(offset_ref)};

  // Shader variant name is the base name plus a suffix for the output dtype.
  std::string shader_name("grid_priors");
  shader_name.reserve(kShaderNameReserve);
  add_dtype_suffix(shader_name, *out_tensor);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs: the shader only writes the output tensor.
      {
          {out, vkapi::MemoryAccessType::WRITE},
      },
      // Shader params buffers: input sizes, output sizes, {stride, offset}.
      {
          in_tensor->sizes_ubo(),
          out_tensor->sizes_ubo(),
          graph.create_params_buffer(shader_params),
      },
      // Specialization Constants: none supplied.
      {},
      // Resize function and the extra args it receives.
      resize_grid_priors_node,
      {in}));
}
71+
72+
// Operator entry point for grid_priors.
// Expects args in the order: input tensor, stride, offset, output tensor.
void grid_priors(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef& in = args[0];
  const ValueRef& stride_ref = args[1];
  const ValueRef& offset_ref = args[2];
  const ValueRef& out = args[3];

  add_grid_priors_node(graph, in, stride_ref, offset_ref, out);
}
75+
76+
// Registers the grid_priors function under the name "grid_priors.default".
REGISTER_OPERATORS {
  VK_REGISTER_OP(grid_priors.default, grid_priors);
}
79+
} // namespace vkcompute

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,3 +2203,75 @@ TEST(VulkanComputeGraphOpsTest, conv2d_prepack_test) {
22032203
0, 3, 9, 0, 0, 6, 12, 0, 0, 5, 11,
22042204
0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
22052205
}
2206+
2207+
void test_grid_priors(
2208+
std::vector<int64_t> input_sizes,
2209+
std::vector<int64_t> output_sizes,
2210+
int stride,
2211+
double offset,
2212+
const std::vector<float>& data_out_expected) {
2213+
GraphConfig config;
2214+
ComputeGraph graph(config);
2215+
2216+
// Build graph
2217+
IOValueRef in = graph.add_input_tensor(
2218+
input_sizes,
2219+
vkapi::kFloat,
2220+
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
2221+
IOValueRef out;
2222+
out.value = graph.add_tensor(
2223+
output_sizes,
2224+
vkapi::kFloat,
2225+
utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
2226+
2227+
VK_GET_OP_FN("grid_priors.default")
2228+
(graph,
2229+
{in.value,
2230+
graph.add_scalar<int64_t>(stride),
2231+
graph.add_scalar<double>(offset),
2232+
out.value});
2233+
2234+
out.staging = graph.set_output_tensor(out.value);
2235+
2236+
graph.prepare();
2237+
graph.encode_prepack();
2238+
graph.prepack();
2239+
graph.encode_execute();
2240+
2241+
vTensorPtr t_in = graph.get_tensor(in.value);
2242+
vTensorPtr t_out = graph.get_tensor(out.value);
2243+
// Resize input
2244+
graph.propagate_resize();
2245+
2246+
// run graph
2247+
graph.execute();
2248+
2249+
std::vector<float> output_data(t_out->gpu_numel());
2250+
graph.copy_from_staging(out.staging, output_data.data(), output_data.size());
2251+
2252+
// check results
2253+
int h_out = utils::val_at(-2, t_out->sizes());
2254+
int w_out = utils::val_at(-1, t_out->sizes());
2255+
for (size_t i = 0; i < h_out; ++i) {
2256+
for (size_t j = 0; j < w_out; ++j) {
2257+
size_t idx_out = i * w_out + j;
2258+
CHECK_VALUE(output_data, idx_out, data_out_expected[idx_out]);
2259+
}
2260+
}
2261+
}
2262+
2263+
// Exercises grid_priors on a {1, 5, 2, 3} input, whose last two dimensions
// (2 and 3) give a {6, 2} output.
TEST(VulkanComputeGraphOpsTest, grid_priors_test) {
  const std::vector<int64_t> input_sizes = {1, 5, 2, 3};
  const std::vector<int64_t> output_sizes = {6, 2};

  // stride 1 with offset 0: the expected values are the raw grid indices.
  const std::vector<float> expected_unit_stride = {
      0, 0, 1, 0, 2, 0, 0, 1, 1, 1, 2, 1};
  test_grid_priors(
      input_sizes,
      output_sizes,
      /*stride = */ 1,
      /*offset = */ 0.0,
      expected_unit_stride);

  // stride 8 with offset 0.5: each value is (index + 0.5) * 8.
  const std::vector<float> expected_stride_8 = {
      4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12};
  test_grid_priors(
      input_sizes,
      output_sizes,
      /*stride = */ 8,
      /*offset = */ 0.5,
      expected_stride_8);
}

0 commit comments

Comments
 (0)