-
Notifications
You must be signed in to change notification settings - Fork 563
Open
Description
❓ Lowering a custom op that takes a list of tensors as input
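Hi, I want to export a torch custom op that takes a list of tensors as input, but I get the error below. The error message suggests there is some support for list-of-tensor inputs: `_xla_custom_call` accepts `list[torch.Tensor]` as its first argument. However, it is invoked with `([tensor, tensor, tensor],)`, a tuple wrapping the list, so the `Tensor[]` argument does not seem to be flattened. Do I need to enable an option to lower this custom op, or is it currently not supported?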
# Repro script.
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch
from torch.library import Library, impl, register_fake
# Create a new custom namespace
my_lib = Library("my_ops", "FRAGMENT")
# Define a custom operator with a list of tensors as input
my_lib.define("a_func(Tensor[] inputs) -> Tensor")
@impl(f"{my_lib.ns}::a_func", "default")
def layer_id_op(t):
return t[0]
@register_fake(f"{my_lib.ns}::a_func")
def layer_id_op_meta(t):
return t[0]
from torch.library import impl
# Wrap the custom op in an nn.Module so torch.export can capture it.
class AModule(nn.Module):
    """Module whose forward pass is a single call to ``my_ops::a_func``."""

    def forward(self, inputs: list[torch.Tensor]):
        # The whole list is forwarded as the op's one ``Tensor[]`` argument.
        custom_op = torch.ops.my_ops.a_func
        return custom_op(inputs)
# Example inputs: three independent 2x3 tensors passed as one list.
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)
x3 = torch.randn(2, 3)
inputs = [x1, x2, x3]
# Export to ExportedProgram. ``export`` takes forward()'s positional args as
# a tuple, so the list is wrapped in a 1-tuple.
ep = export(AModule(), (inputs,))
# Export to StableHLO, allowing ops from the ``my_ops`` namespace to stay in
# the graph.
options = torch_xla.stablehlo.StableHLOExportOptions(
    custom_ops_allowed_in_graph={"my_ops"},
)
try:
    a = exported_program_to_stablehlo(ep, options)
except Exception:
    # Inspect the failing frame, then re-raise so the script still exits
    # with the original traceback instead of silently continuing.
    # ``Exception`` (not a bare ``except:``) lets Ctrl-C / SystemExit through.
    # ipdb is optional; fall back to the stdlib debugger when it is absent.
    try:
        import ipdb as debugger
    except ImportError:
        import pdb as debugger
    debugger.post_mortem()
    raise
else:
    # Only reached on success, where ``a`` is bound and can be inspected.
    breakpoint()
Error:
Traceback (most recent call last):
File "/test/xla_test.py", line 42, in <module>
a = exported_program_to_stablehlo(ep, torch_xla.stablehlo.StableHLOExportOptions(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 629, in exported_program_to_stablehlo
bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 373, in _exported_program_to_stablehlo_bundle
res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 284, in run_node
res = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 240, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 265, in call_function
return super().call_function(target, args, new_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 320, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/experimental/stablehlo_custom_call.py", line 16, in stablehlo_custom_call
res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: _xla_custom_call(): incompatible function arguments. The following argument types are supported:
1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object], arg4: bool, arg5: str, arg6: int, arg7: dict[str, str]) -> list[torch.Tensor]
Invoked with: ([tensor([[-1.6685, -1.7139, -2.0131],
[-0.1968, 0.8239, 1.9200]], device='xla:0'), tensor([[-0.3895, 0.7437, -0.5434],
[ 0.0462, 0.1798, -3.2928]], device='xla:0'), tensor([[-0.1212, 0.4554, -0.2454],
[ 0.1532, -1.0724, 1.1181]], device='xla:0')],), 'my_ops.a_func.default', [(2, 3)], [torch.float32], False, '', 0, {}
While executing %a_func : [num_users=1] = call_function[target=torch_xla.experimental.stablehlo_custom_call.stablehlo_custom_call](args = (([%inputs_0, %inputs_1, %inputs_2],), my_ops.a_func.default, [(2, 3)], [torch.float32]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
def forward(self, inputs_0: "f32[2, 3][3, 1]", inputs_1: "f32[2, 3][3, 1]", inputs_2: "f32[2, 3][3, 1]"):
# File: /test/xla_test.py:30 in forward, code: return torch.ops.my_ops.a_func(inputs)
a_func: "f32[2, 3][3, 1]" = torch_xla_experimental_stablehlo_custom_call_stablehlo_custom_call(([inputs_0, inputs_1, inputs_2],), 'my_ops.a_func.default', [(2, 3)], [torch.float32]); inputs_0 = inputs_1 = inputs_2 = None
return (a_func,)
Metadata
Metadata
Assignees
Labels
No labels