Skip to content

Lowering a custom op with a list-of-tensors input #9641

@tlsdmstn56

Description

@tlsdmstn56

❓ Lowering a custom op that takes a list of tensors as input

Hi, I want to export a PyTorch custom op that takes a list of tensors as input, but I get the error below. The error message suggests that list-of-tensor inputs are supported to some extent: the first argument of `_xla_custom_call` is typed `list[torch.Tensor]`. Do I need to enable an option to lower this custom op, or is this currently not supported?

# Repro script.
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

import torch
from torch.library import Library, impl, register_fake

# Create a new custom namespace
my_lib = Library("my_ops", "FRAGMENT")

# Define a custom operator with a list of tensors as input
my_lib.define("a_func(Tensor[] inputs) -> Tensor")

@impl(f"{my_lib.ns}::a_func", "default")
def layer_id_op(t):
    return t[0]

@register_fake(f"{my_lib.ns}::a_func")
def layer_id_op_meta(t):
    return t[0]

from torch.library import impl

# Wrap the custom op in an nn.Module so it can be passed to torch.export.
class AModule(nn.Module):
    """Module whose forward pass is a single call to ``my_ops::a_func``."""

    def forward(self, inputs: list[torch.Tensor]):
        # Resolve the op through the ``torch.ops`` registry, then hand it the
        # whole list as its single ``Tensor[]`` argument.
        a_func = torch.ops.my_ops.a_func
        return a_func(inputs)

# Three same-shaped example tensors, passed to the module as one list.
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)
x3 = torch.randn(2, 3)
inputs = [x1, x2, x3]

# Export to ExportedProgram. The list is the module's single positional
# argument, so it is wrapped in a 1-tuple.
ep = export(AModule(), (inputs,))

# Export to StableHLO, keeping ops from the "my_ops" namespace in the graph.
# NOTE(review): in the reported traceback this call fails because the
# Tensor[] argument reaches `_xla_custom_call` still nested as
# `([t0, t1, t2],)` instead of as a flat list of tensors. The failure appears
# to be in torch_xla's custom-op lowering, not in this script.
# Exceptions are intentionally not caught. A bare `except:` also swallows
# KeyboardInterrupt/SystemExit, and a missing `ipdb` would hide the real
# error. Run `python -m pdb <script>` for post-mortem debugging instead.
stablehlo_program = exported_program_to_stablehlo(
    ep,
    torch_xla.stablehlo.StableHLOExportOptions(
        custom_ops_allowed_in_graph={"my_ops"},
    ),
)

# Show the lowered program instead of stopping at an unconditional
# breakpoint(), so the script also works in non-interactive runs.
print(stablehlo_program.get_stablehlo_text("forward"))

Error:

Traceback (most recent call last):
  File "/test/xla_test.py", line 42, in <module>
    a = exported_program_to_stablehlo(ep, torch_xla.stablehlo.StableHLOExportOptions(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 629, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 373, in _exported_program_to_stablehlo_bundle
    res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 284, in run_node
    res = super().run_node(n)
          ^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 265, in call_function
    return super().call_function(target, args, new_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 320, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/test-venv/lib/python3.11/site-packages/torch_xla/experimental/stablehlo_custom_call.py", line 16, in stablehlo_custom_call
    res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: _xla_custom_call(): incompatible function arguments. The following argument types are supported:
    1. (arg0: list[torch.Tensor], arg1: str, arg2: list[list[int]], arg3: list[object], arg4: bool, arg5: str, arg6: int, arg7: dict[str, str]) -> list[torch.Tensor]

Invoked with: ([tensor([[-1.6685, -1.7139, -2.0131],
        [-0.1968,  0.8239,  1.9200]], device='xla:0'), tensor([[-0.3895,  0.7437, -0.5434],
        [ 0.0462,  0.1798, -3.2928]], device='xla:0'), tensor([[-0.1212,  0.4554, -0.2454],
        [ 0.1532, -1.0724,  1.1181]], device='xla:0')],), 'my_ops.a_func.default', [(2, 3)], [torch.float32], False, '', 0, {}

While executing %a_func : [num_users=1] = call_function[target=torch_xla.experimental.stablehlo_custom_call.stablehlo_custom_call](args = (([%inputs_0, %inputs_1, %inputs_2],), my_ops.a_func.default, [(2, 3)], [torch.float32]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, inputs_0: "f32[2, 3][3, 1]", inputs_1: "f32[2, 3][3, 1]", inputs_2: "f32[2, 3][3, 1]"):
         # File: /test/xla_test.py:30 in forward, code: return torch.ops.my_ops.a_func(inputs)
        a_func: "f32[2, 3][3, 1]" = torch_xla_experimental_stablehlo_custom_call_stablehlo_custom_call(([inputs_0, inputs_1, inputs_2],), 'my_ops.a_func.default', [(2, 3)], [torch.float32]);  inputs_0 = inputs_1 = inputs_2 = None
        return (a_func,)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions