Skip to content

Commit a0618c8

Browse files
Binary Comparison Ops
Differential Revision: D76049244 Pull Request resolved: pytorch#12198
1 parent fc435fa commit a0618c8

File tree

9 files changed

+198
-21
lines changed

9 files changed

+198
-21
lines changed

.lintrunner.toml

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ exclude_patterns = [
1010
'exir/serde/**',
1111
]
1212
command = [
13-
'python3',
13+
'python',
1414
'-m',
1515
'lintrunner_adapters',
1616
'run',
@@ -19,7 +19,7 @@ command = [
1919
'@{{PATHSFILE}}'
2020
]
2121
init_command = [
22-
'python3',
22+
'python',
2323
'-m',
2424
'lintrunner_adapters',
2525
'run',
@@ -41,7 +41,7 @@ exclude_patterns = [
4141
'exir/serde/**',
4242
]
4343
command = [
44-
'python3',
44+
'python',
4545
'-m',
4646
'lintrunner_adapters',
4747
'run',
@@ -50,7 +50,7 @@ command = [
5050
'@{{PATHSFILE}}'
5151
]
5252
init_command = [
53-
'python3',
53+
'python',
5454
'-m',
5555
'lintrunner_adapters',
5656
'run',
@@ -84,7 +84,7 @@ exclude_patterns = [
8484
'runtime/core/portable_type/c10/**',
8585
]
8686
command = [
87-
'python3',
87+
'python',
8888
'-m',
8989
'lintrunner_adapters',
9090
'run',
@@ -95,7 +95,7 @@ command = [
9595
'@{{PATHSFILE}}'
9696
]
9797
init_command = [
98-
'python3',
98+
'python',
9999
'-m',
100100
'lintrunner_adapters',
101101
'run',
@@ -117,7 +117,7 @@ exclude_patterns = [
117117
'**/third-party/**',
118118
]
119119
command = [
120-
'python3',
120+
'python',
121121
'-m',
122122
'lintrunner_adapters',
123123
'run',
@@ -127,7 +127,7 @@ command = [
127127
'@{{PATHSFILE}}',
128128
]
129129
init_command = [
130-
'python3',
130+
'python',
131131
'-m',
132132
'lintrunner_adapters',
133133
'run',
@@ -151,7 +151,7 @@ exclude_patterns = [
151151
'**/third-party/**',
152152
]
153153
command = [
154-
'python3',
154+
'python',
155155
'-m',
156156
'lintrunner_adapters',
157157
'run',
@@ -192,7 +192,7 @@ exclude_patterns = [
192192
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',
193193
]
194194
command = [
195-
'python3',
195+
'python',
196196
'-m',
197197
'lintrunner_adapters',
198198
'run',
@@ -234,7 +234,7 @@ exclude_patterns = [
234234
'util/**',
235235
]
236236
command = [
237-
'python3',
237+
'python',
238238
'-m',
239239
'lintrunner_adapters',
240240
'run',
@@ -287,7 +287,7 @@ exclude_patterns = [
287287
'util/**',
288288
]
289289
command = [
290-
'python3',
290+
'python',
291291
'-m',
292292
'lintrunner_adapters',
293293
'run',
@@ -337,7 +337,7 @@ exclude_patterns = [
337337
'backends/arm/test/**',
338338
]
339339
command = [
340-
'python3',
340+
'python',
341341
'-m',
342342
'lintrunner_adapters',
343343
'run',
@@ -349,7 +349,7 @@ command = [
349349
'@{{PATHSFILE}}'
350350
]
351351
init_command = [
352-
'python3',
352+
'python',
353353
'-m',
354354
'lintrunner_adapters',
355355
'run',
@@ -368,7 +368,7 @@ exclude_patterns = [
368368
'.lintrunner.toml',
369369
]
370370
command = [
371-
'python3',
371+
'python',
372372
'-m',
373373
'lintrunner_adapters',
374374
'run',
@@ -397,7 +397,7 @@ exclude_patterns = [
397397
]
398398

399399
command = [
400-
"python3",
400+
"python",
401401
"-m",
402402
"lintrunner_adapters",
403403
"run",

backends/vulkan/op_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ def register_ephemeral_op(features: OpFeatures):
259259
exir_ops.edge.aten.div.Tensor,
260260
exir_ops.edge.aten.div.Tensor_mode,
261261
exir_ops.edge.aten.pow.Tensor_Tensor,
262+
exir_ops.edge.aten.eq.Tensor,
263+
exir_ops.edge.aten.lt.Tensor,
264+
exir_ops.edge.aten.le.Tensor,
265+
exir_ops.edge.aten.gt.Tensor,
266+
exir_ops.edge.aten.ge.Tensor,
262267
]
263268
)
264269
def register_binary_op(features: OpFeatures):

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,16 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
728728
)
729729

730730
for variant in params_dict["shader_variants"]:
731+
default_iterated_params_names = set(
732+
default_iterated_params.keys()
733+
if default_iterated_params is not None
734+
else {}
735+
)
731736
variant_params_names = set(variant.keys())
737+
732738
invalid_keys = (
733739
variant_params_names
740+
- default_iterated_params_names
734741
- params_names
735742
- {"generate_variant_forall"}
736743
)
@@ -758,6 +765,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
758765
variant_name = f"{variant_name}_{param_value[1]}"
759766

760767
default_params_copy["NAME"] = variant_name
768+
default_params_copy["VARIANT_NAME"] = variant["NAME"]
761769

762770
self.shader_template_params[template_name].append(
763771
default_params_copy

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,35 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
// Binary comparison ops require that the output is boolean and not the same as input.
14+
$IS_COMPARISON_OP = (any([name in VARIANT_NAME for name in ["binary_eq", "binary_lt", "binary_le", "binary_gt", "binary_ge"]]))
15+
16+
#define NAME ${VARIANT_NAME}
17+
1318
#define VEC4_T ${texel_type(DTYPE)}
14-
#define T ${buffer_scalar_type(DTYPE)}
19+
$if IS_COMPARISON_OP:
20+
#define T ${buffer_scalar_type("uint8")}
21+
#define VEC4_OUT_T ${texel_type("uint8")}
22+
$else:
23+
#define T ${buffer_scalar_type(DTYPE)}
24+
#define VEC4_OUT_T VEC4_T
1525

1626
#define op(X, Y, A) ${OPERATOR}
1727

1828
${define_active_storage_type(STORAGE)}
1929
${define_required_extensions(DTYPE)}
2030

31+
32+
$if IS_COMPARISON_OP:
33+
${define_required_extensions("uint8")}
34+
2135
layout(std430) buffer;
2236

23-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
37+
$if IS_COMPARISON_OP:
38+
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
39+
$else:
40+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
41+
2442
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2543
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
2644

@@ -121,7 +139,7 @@ void main() {
121139
write_texel_lpos(
122140
t_out,
123141
lpos,
124-
VEC4_T(op(in_texel, other_texel, alpha)),
142+
VEC4_OUT_T(op(in_texel, other_texel, alpha)),
125143
out_axis_map);
126144
}
127145

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,84 @@ binary_op:
3232
OPERATOR: floor(X / Y)
3333
- NAME: binary_minimum
3434
OPERATOR: min(X, Y)
35+
- NAME: binary_eq_int32
36+
OPERATOR: X == Y
37+
DTYPE: int32
38+
- NAME: binary_eq_buffer
39+
OPERATOR: abs(X - Y) < 1e-5
40+
STORAGE: buffer
41+
generate_variant_forall:
42+
DTYPE:
43+
- VALUE: half
44+
- VALUE: float
45+
- NAME: binary_eq_texture3d
46+
OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5)))
47+
STORAGE: texture3d
48+
generate_variant_forall:
49+
DTYPE:
50+
- VALUE: half
51+
- VALUE: float
52+
- NAME: binary_lt_buffer
53+
OPERATOR: X < Y
54+
STORAGE: buffer
55+
generate_variant_forall:
56+
DTYPE:
57+
- VALUE: half
58+
- VALUE: float
59+
- VALUE: int32
60+
- NAME: binary_lt_texture3d
61+
OPERATOR: all(lessThan(X, Y))
62+
STORAGE: texture3d
63+
generate_variant_forall:
64+
DTYPE:
65+
- VALUE: half
66+
- VALUE: float
67+
- VALUE: int32
68+
- NAME: binary_le_buffer
69+
OPERATOR: X <= Y
70+
STORAGE: buffer
71+
generate_variant_forall:
72+
DTYPE:
73+
- VALUE: half
74+
- VALUE: float
75+
- VALUE: int32
76+
- NAME: binary_le_texture3d
77+
OPERATOR: all(lessThanEqual(X, Y))
78+
STORAGE: texture3d
79+
generate_variant_forall:
80+
DTYPE:
81+
- VALUE: half
82+
- VALUE: float
83+
- VALUE: int32
84+
- NAME: binary_gt_buffer
85+
OPERATOR: X > Y
86+
STORAGE: buffer
87+
generate_variant_forall:
88+
DTYPE:
89+
- VALUE: half
90+
- VALUE: float
91+
- VALUE: int32
92+
- NAME: binary_gt_texture3d
93+
OPERATOR: all(greaterThan(X, Y))
94+
STORAGE: texture3d
95+
generate_variant_forall:
96+
DTYPE:
97+
- VALUE: half
98+
- VALUE: float
99+
- VALUE: int32
100+
- NAME: binary_ge_buffer
101+
OPERATOR: X >= Y
102+
STORAGE: buffer
103+
generate_variant_forall:
104+
DTYPE:
105+
- VALUE: half
106+
- VALUE: float
107+
- VALUE: int32
108+
- NAME: binary_ge_texture3d
109+
OPERATOR: all(greaterThanEqual(X, Y))
110+
STORAGE: texture3d
111+
generate_variant_forall:
112+
DTYPE:
113+
- VALUE: half
114+
- VALUE: float
115+
- VALUE: int32

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void add_binary_op_texture_node(
7777
kernel_name.reserve(kShaderNameReserve);
7878
kernel_name += op_name;
7979
add_storage_type_suffix(kernel_name, *t_out);
80-
add_dtype_suffix(kernel_name, *t_out);
80+
add_dtype_suffix(kernel_name, graph.dtype_of(in1));
8181

8282
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8383
graph,
@@ -121,7 +121,8 @@ void add_binary_op_buffer_node(
121121
kernel_name.reserve(kShaderNameReserve);
122122
kernel_name += op_name;
123123
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
124-
add_dtype_suffix(kernel_name, graph.dtype_of(out));
124+
125+
add_dtype_suffix(kernel_name, graph.dtype_of(in1));
125126

126127
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
127128
graph,
@@ -189,6 +190,11 @@ DEFINE_BINARY_OP_FN(mul);
189190
DEFINE_BINARY_OP_FN(div);
190191
DEFINE_BINARY_OP_FN(pow);
191192
DEFINE_BINARY_OP_FN(minimum);
193+
DEFINE_BINARY_OP_FN(eq);
194+
DEFINE_BINARY_OP_FN(lt);
195+
DEFINE_BINARY_OP_FN(le);
196+
DEFINE_BINARY_OP_FN(gt);
197+
DEFINE_BINARY_OP_FN(ge);
192198

193199
REGISTER_OPERATORS {
194200
VK_REGISTER_OP(aten.add.Tensor, add);
@@ -198,6 +204,11 @@ REGISTER_OPERATORS {
198204
VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
199205
VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
200206
VK_REGISTER_OP(aten.minimum.default, minimum);
207+
VK_REGISTER_OP(aten.eq.Tensor, eq);
208+
VK_REGISTER_OP(aten.lt.Tensor, lt);
209+
VK_REGISTER_OP(aten.le.Tensor, le);
210+
VK_REGISTER_OP(aten.gt.Tensor, gt);
211+
VK_REGISTER_OP(aten.ge.Tensor, ge);
201212
}
202213

203214
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,42 @@ def get_binary_elementwise_inputs():
6363
"utils::kBuffer",
6464
"utils::kTexture3D",
6565
]
66+
67+
return test_suite
68+
69+
70+
# Eq requires a different test generator so it was split from the other test case.
71+
@register_test_suite(
72+
[
73+
"aten.eq.Tensor",
74+
"aten.gt.Tensor",
75+
"aten.lt.Tensor",
76+
"aten.ge.Tensor",
77+
"aten.le.Tensor",
78+
]
79+
)
80+
def get_binary_elementwise_compare_inputs():
81+
test_suite = VkTestSuite(
82+
[
83+
((M1, M2), (M1, M2)),
84+
((M1, M2), (M1, 1), 2.0),
85+
((M1, M2), (1, M2)),
86+
((S, S1, S2), (S, S1, S2)),
87+
((S, S1, S2), (S, S1, 1), 2.0),
88+
((S, S1, S2), (S, 1, S2), 2.0),
89+
((XS, S, S1, S2), (XS, S, 1, 1), 2.0),
90+
((3, 64, 1), (1, 64, 1)),
91+
]
92+
)
93+
test_suite.layouts = [
94+
"utils::kWidthPacked",
95+
"utils::kChannelsPacked",
96+
]
97+
test_suite.storage_types = [
98+
"utils::kBuffer",
99+
"utils::kTexture3D",
100+
]
101+
test_suite.data_gen = "make_casted_randint_tensor"
66102
return test_suite
67103

68104

0 commit comments

Comments
 (0)