Skip to content

Commit eb6feda

Browse files
committed
fix: wrong dtype and device in full_like decomposition
1 parent 5c37931 commit eb6feda

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
_get_decomp_for_cia,
1010
)
1111
from torch._ops import OpOverload
12+
1213
from torch_tensorrt.dynamo._defaults import default_device
1314
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
1415
from torch_tensorrt.dynamo.utils import to_torch_device
@@ -432,8 +433,8 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
432433
input = args[0]
433434
shape = args[0].shape
434435
fill_value = args[1]
435-
kwargs["dtype"] = input.dtype
436-
kwargs["device"] = to_torch_device(default_device())
436+
kwargs["dtype"] = kwargs.get("dtype", None) or input.dtype
437+
kwargs["device"] = input.device
437438
return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])
438439

439440

0 commit comments

Comments
 (0)