Skip to content

Commit 435172c

Browse files
committed
update ReduceSum for the attention mask subgraph as well
1 parent 2d38382 commit 435172c

File tree

1 file changed

+3
-2
lines changed
  • src/python/py/models/builders

1 file changed

+3
-2
lines changed

src/python/py/models/builders/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4258,6 +4258,7 @@ def make_attention_mask_reformatting_for_sparse_attn(self):
42584258
# attention_mask
42594259
# / \
42604260
# ReduceSum Shape
4261+
# (keepdims=0) |
42614262
# | |
42624263
# Cast to int32 Gather
42634264
# | |
@@ -4272,9 +4273,9 @@ def make_attention_mask_reformatting_for_sparse_attn(self):
42724273
# Left path
42734274
reduce_sum_name = f"{attn_mask_basename}/ReduceSum"
42744275
reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"]
4275-
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1])
4276+
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"], keepdims=False)
42764277
cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast"
4277-
self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1])
4278+
self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"])
42784279

42794280
# Right path
42804281
shape_name = f"{attn_mask_basename}/Shape"

0 commit comments

Comments
 (0)