**Summary:** We have a guard preventing users from using a
cuda-quantized model on cpu and vice versa. However, this also
blocks users who load their checkpoints on cpu first and
then move them to cuda later, which is what torchtune does:
```python
quantize_(model.cuda(), Int4WeightOnlyConfig())
# save checkpoint in cuda
torch.save(model.state_dict(), "my_checkpoint.pt")
# load checkpoint on cpu
# This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")
# move checkpoint to cuda
for k, v in sd.items():
    sd[k] = v.to("cuda")
# load state_dict in cuda
model.load_state_dict(sd, assign=True)
```
This use case is safe in that the model was quantized on
cuda and is ultimately used on cuda. This commit relaxes the
error to allow the above use case. More details here:
#1117.
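To illustrate the relaxation, here is a minimal sketch of the idea: instead of raising on *any* cross-device access, only compute ops are blocked when devices mismatch, while device-movement ops (like `.to()`) are always allowed, so a cpu-loaded checkpoint can be moved back to cuda before use. All names below (`check_device`, `MOVEMENT_OPS`) are illustrative assumptions, not torchao's actual internals.

```python
# Hypothetical sketch of the relaxed guard; not torchao's real code.
MOVEMENT_OPS = {"to", "cuda", "cpu"}  # ops that only move data between devices

def check_device(op_name: str, quantized_on: str, running_on: str) -> None:
    """Raise only when a *compute* op runs on a device other than the one
    the weights were quantized for; movement ops are always permitted."""
    if op_name in MOVEMENT_OPS:
        return  # moving a cpu-loaded checkpoint back to cuda is safe
    if quantized_on != running_on:
        raise RuntimeError(
            f"op {op_name!r} ran on {running_on}, but weights were "
            f"quantized for {quantized_on}"
        )
```

Under this sketch, the torchtune pattern above passes: loading on cpu and calling `.to("cuda")` is a movement op, and every subsequent compute op runs on cuda, matching the device the model was quantized on.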
**Test Plan:**
python test/quantization/test_quant_api.py -k test_int4wo_cuda_serialization