fix: wrong dtype and device in aten.full_like decomposition #3535
Description
This PR addresses a bug in the Torch-TensorRT decomposition of torch.ops.aten.full_like. In the current implementation, the decomposition incorrectly overrides the dtype and device arguments, ignoring explicitly set dtype values and assigning all tensors to the default_device() (typically cuda:0), regardless of the inputs' actual device.

Specifically, the issue occurs in the following decomposition function:
TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py
Lines 428 to 437 in 5c37931
This implementation causes two main issues:
When torch.full_like(..., dtype=torch.bool) is used in the model, the decomposition overwrites the dtype with input.dtype (e.g., float16), resulting in an incorrect output type.
When the input tensor is on a non-default device (e.g., cuda:1), the decomposition forces outputs to be on cuda:0, causing runtime errors or silent bugs due to device mismatch.

To demonstrate the issue, the following test cases are included in this PR:
Results:
test1: Verifies that torch.ones_like returns a tensor with the correct dtype.
test2: Shows that the model exported via torch.export(...).run_decompositions(...) fails to preserve dtype.
test3: Demonstrates the incorrect device assignment after decomposition when using non-default CUDA devices.

This PR fixes the decomposition logic to correctly respect the explicitly passed dtype and device values, or fall back to those inferred from the input tensor only if not explicitly provided.

Type of change
Checklist:
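The fallback rule stated in the Description (explicit dtype and device win, and the input's values are used only when nothing was passed) can be sketched in plain Python. This is a minimal illustration, not the code from this PR: TensorMeta and resolve_full_like_meta are hypothetical names, and dtypes and devices are modelled as strings so the sketch runs without torch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for a tensor's dtype/device (not a torch type)."""

    dtype: str
    device: str


def resolve_full_like_meta(
    input_meta: TensorMeta,
    dtype: Optional[str] = None,
    device: Optional[str] = None,
) -> TensorMeta:
    """Choose the output dtype/device for a full_like-style op.

    Explicit arguments take priority; the input's values are only a fallback.
    """
    return TensorMeta(
        dtype=dtype if dtype is not None else input_meta.dtype,
        device=device if device is not None else input_meta.device,
    )


# An input that is half precision and lives on a non-default device.
x = TensorMeta(dtype="float16", device="cuda:1")

# Explicit dtype is kept; device is inherited from the input, not a global default.
explicit = resolve_full_like_meta(x, dtype="bool")

# Nothing passed explicitly: both values come from the input.
inherited = resolve_full_like_meta(x)
```

Under these assumptions, explicit keeps the requested bool dtype while staying on cuda:1, which is the behaviour the two issues above say the old decomposition did not provide.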