diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 825be75076..1e1eff527f 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -116,8 +116,6 @@
     aten.select_scatter,
     aten.sgn,
     aten.sigmoid_backward,
-    aten.silu,
-    aten.silu_,
     aten.silu_backward,
     aten.sinc,
     aten.slice_backward,
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 4cead5d0cb..b0cfdee4f0 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -9,7 +9,6 @@
     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -253,6 +252,16 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:  # type: ignor
     )
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.silu.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+@register_torch_trt_decomposition(
+    torch.ops.aten.silu_.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def silu(self: torch.Tensor) -> torch.Tensor:
+    return self * torch.sigmoid(self)
+
+
 @register_torch_trt_decomposition(
     torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
@@ -423,8 +432,8 @@ def instance_norm_decomposition(
 
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-)  # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
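
A minimal sketch (not part of the diff) of how the new registration could be checked, assuming a TensorRT-enabled install and that `get_decompositions` remains exported from `torch_tensorrt.dynamo.lowering`:

    import torch
    from torch_tensorrt.dynamo.lowering import get_decompositions

    # With this change, aten.silu is no longer delegated to the upstream aten
    # decomposition table; the Torch-TRT registry supplies x * sigmoid(x) instead.
    decomps = get_decompositions()
    silu_decomp = decomps[torch.ops.aten.silu.default]

    x = torch.randn(8)
    assert torch.allclose(silu_decomp(x), torch.nn.functional.silu(x))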