Skip to content

Commit 479696b

Browse files
junjiang-labcopybara-github
authored andcommitted
Add impl for aten.bitwise_and.Tensor and lowering.
PiperOrigin-RevId: 762031114
1 parent d76075b commit 479696b

File tree

4 files changed

+33
-0
lines changed

4 files changed

+33
-0
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,20 @@ def _aten_pow_tensor_tensor_decomp(x, y):
9696
return torch.ops.tfl.pow(x, y)
9797

9898

99+
@register_decomp(torch.ops.aten.bitwise_and.Tensor)
100+
def _aten_bitwise_and_tensor_decomp(x, y):
101+
if not (
102+
isinstance(x, torch.Tensor)
103+
and x.dtype == torch.bool
104+
and isinstance(y, torch.Tensor)
105+
and y.dtype == torch.bool
106+
):
107+
raise TypeError(
108+
"Input tensors for aten.bitwise_and only supports bool for now."
109+
)
110+
return torch.ops.tfl.logical_and(x, y)
111+
112+
99113
@register_decomp(torch.ops.aten.gt.Tensor)
100114
def _aten_gt_tensor_decomp(x, y):
101115
return torch.ops.tfl.greater(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ def _tfl_pow_lowering(
164164
)
165165

166166

167+
@lower(torch.ops.tfl.logical_and.default)
168+
def _tfl_logical_and_lowering(
169+
lctx: LoweringContext,
170+
lhs: ir.Value,
171+
rhs: ir.Value,
172+
) -> ir.Value:
173+
return _ir_operation(
174+
"tfl.logical_and",
175+
results=lowering_utils.node_meta_to_ir_types(lctx.node),
176+
operands=[lhs, rhs],
177+
)
178+
179+
167180
@lower(torch.ops.tfl.greater.default)
168181
def _tfl_greater_lowering(
169182
lctx: LoweringContext,

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def tfl_pow(x: Any, y: Any) -> torch.Tensor:
6363
return torch.pow(x, y)
6464

6565

66+
@custom_op_with_fake("tfl::logical_and")
67+
def tfl_logical_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
68+
return torch.logical_and(x, y)
69+
70+
6671
@custom_op_with_fake("tfl::greater")
6772
def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6873
return torch.gt(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _assert_export_and_close(
133133
("aten_pow_Scalar_0", torch.ops.aten.pow.Scalar, (np.random.rand(), rnd(torch.float32, (10, 10)),), dict()),
134134
("aten_pow_Tensor_Scalar_0", torch.ops.aten.pow.Tensor_Scalar, (rnd(torch.float32, (10, 10)), np.random.rand(),), dict()),
135135
("aten_pow_Tensor_Tensor_0", torch.ops.aten.pow.Tensor_Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
136+
("aten_bitwise_and_Tensor_0", torch.ops.aten.bitwise_and.Tensor, (rnd(torch.bool, (10, 10)), rnd(torch.bool, (10, 10)),), dict()),
136137
("aten_gt_Tensor_0", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
137138
("aten_gt_Tensor_1", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
138139
("aten_lt_Tensor_0", torch.ops.aten.lt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),

0 commit comments

Comments
 (0)