Commit a776b1f
Relax int4wo device mismatch error (#2254)
**Summary:** We have a guard preventing users from using a model quantized on cuda on cpu, and vice versa. However, this also prevents users who load their checkpoints on cpu first and then move them to cuda later, which is what torchtune does:

```python
quantize_(model.cuda(), Int4WeightOnlyConfig())

# save checkpoint in cuda
torch.save(model.state_dict(), "my_checkpoint.pt")

# load checkpoint on cpu
# This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")

# move checkpoint to cuda
for k, v in sd.items():
    sd[k] = v.to("cuda")

# load state_dict in cuda
model.load_state_dict(sd, assign=True)
```

This use case is safe in that the model was quantized on cuda and is ultimately used on cuda. This commit relaxes the error to a warning to allow the above use case. More details here: #1117.

**Test Plan:**

```
python test/quantization/test_quant_api.py -k test_int4wo_cuda_serialization
```
1 parent 446f07d commit a776b1f
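For completeness, here is a self-contained, runnable version of the snippet from the summary. This is a sketch, not code from the commit: the toy model, tensor sizes, and checkpoint file name are hypothetical stand-ins, and it assumes torchao's public `quantize_`/`Int4WeightOnlyConfig` API on a CUDA machine.

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# Toy stand-in model; any module with Linear layers works the same way.
model = nn.Sequential(nn.Linear(64, 64)).cuda().to(torch.bfloat16)
quantize_(model, Int4WeightOnlyConfig(group_size=32))  # quantize on cuda

# save checkpoint from cuda
torch.save(model.state_dict(), "my_checkpoint.pt")

# load on cpu (as torchtune does), then move every tensor back to cuda;
# the device-mismatch guard used to raise ValueError on this move and
# now only logs a warning, since the data was cuda-packed all along
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")
for k, v in sd.items():
    sd[k] = v.to("cuda")

model.load_state_dict(sd, assign=True)  # re-attach the tensors on cuda
out = model(torch.randn(2, 64, device="cuda", dtype=torch.bfloat16))
```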

2 files changed: +21 −1 lines

test/quantization/test_quant_api.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -1017,6 +1017,25 @@ def test_ao_per_module_config_skip(self):
         assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
         assert not isinstance(model.linear2.weight, AffineQuantizedTensor)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int4wo_cuda_serialization(self):
+        config = Int4WeightOnlyConfig(group_size=32)
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        # quantize in cuda
+        quantize_(model, config)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model(*example_inputs)
+        with tempfile.NamedTemporaryFile() as ckpt:
+            # save checkpoint in cuda
+            torch.save(model.state_dict(), ckpt)
+            # load checkpoint on cpu then move checkpoint to cuda
+            # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
+            sd = torch.load(ckpt.name, weights_only=False, map_location="cpu")
+            for k, v in sd.items():
+                sd[k] = v.to("cuda")
+            # load state_dict in cuda
+            model.load_state_dict(sd, assign=True)
+

 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
```
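One detail worth noting in the test: the final load uses `assign=True`, which makes `load_state_dict` attach the checkpoint tensors to the module directly instead of copying them into the existing parameters in place (documented `torch.nn.Module.load_state_dict` behavior). For a packed tensor subclass like these int4 weights, an in-place copy into the already-quantized cuda parameters is not the intended path, so direct assignment is presumably what keeps this round trip well defined.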

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import logging
 from dataclasses import dataclass
 from typing import Optional, Tuple

@@ -318,7 +319,7 @@ def to(self, *args, **kwargs):
         # between these two devices, in the future we should not use the same layout for
         # cpu and cuda device: https://github.com/pytorch/ao/issues/1117
         if not is_device(torch.device(self.device).type, device):
-            raise ValueError(
+            logging.warning(
                 f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
             )
         return self.__class__(
```
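Since the relaxed guard emits through the module-level `logging.warning` call (i.e. the root logger), the message can repeat once for every quantized tensor that crosses devices when converting a large state dict. If that is too noisy, a standard logging filter can drop it; a minimal sketch, assuming the message text stays exactly as in the diff:

```python
import logging

class _SkipTiledLayoutDeviceWarning(logging.Filter):
    """Drop only the device-conversion warning from the relaxed guard."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False drops the record; everything else passes through.
        return (
            "TensorCoreTiledAQTTensorImpl does not support conversion"
            not in record.getMessage()
        )

# logging.warning(...) routes through the root logger, so attach the filter there.
logging.getLogger().addFilter(_SkipTiledLayoutDeviceWarning())
```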
