Torchao int4 serialization #11591
Conversation
Thanks! Some questions.
```python
    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
    and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
```
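The device selection in that snippet can be illustrated with a plain-Python sketch. `resolve_map_location` is a hypothetical helper name (echoing the reviewer's suggestion in this thread), and the `"cpu"` fallback for an all-offloaded `device_map` is an assumption for illustration; the actual PR guards this by checking the `device_map` beforehand:

```python
def resolve_map_location(device_map):
    """Pick the first accelerator device in a device_map, skipping CPU/disk offload.

    Hypothetical helper for illustration; the real code wraps the result in
    torch.device(...) and assumes at least one accelerator entry exists.
    """
    accelerators = [d for d in device_map.values() if d not in ["cpu", "disk"]]
    # Fallback to "cpu" here is an added assumption, not the PR's behavior.
    return accelerators[0] if accelerators else "cpu"


# Example device_map in the style produced by accelerate-based dispatch:
device_map = {"transformer": "cuda:0", "text_encoder": "cpu", "vae": "disk"}
print(resolve_map_location(device_map))  # "cuda:0"
```

Note that with multiple CUDA devices in the map, only the first one is returned, which is the indexing behavior the reviewer asks about below.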
Sufficiently safe to say that we would always have a non-None `device_map`?

Also, what happens if the `device_map` has multiple CUDA devices specified? Would the indexing make sense there? Okay for this PR, but we could potentially have a `resolve_map_location()` per quantizer class, maybe.
> Sufficiently safe to say that we would always have a non-None `device_map`?

I check that the `device_map` is not `None`, so this should be safe enough; I took it from transformers. There shouldn't be an issue with the indexing either: if multiple devices are specified, the tensors will be moved again afterwards.

Yeah, I can switch to `update_map_location`.
Once the PR is close to merging, let's also add a test.

Will add a test! Note that in general, I wouldn't recommend saving int4 models with torchao, since the serialized format is hardware-dependent.

Indeed. Then let's also add a note in the docs.
What does this PR do?
This PR fixes torchao int4 checkpoint loading: the checkpoint must be loaded directly onto the device it was quantized on. We assume the model is being loaded on the correct device from the start.
Needed for this model: https://huggingface.co/diffusers/FLUX.1-dev-torchao-int4
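For context, the quantize/save/reload flow this PR targets can be sketched as follows. This is a hedged example: it assumes the diffusers `TorchAoConfig` API, a CUDA device, and that torchao tensor subclasses require pickle (non-safetensors) serialization; the local path is illustrative.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize the transformer to int4 on load. The int4 packed layout is
# hardware-dependent, so the resulting weights are tied to this device class.
quantization_config = TorchAoConfig("int4_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# torchao tensor subclasses are not safetensors-compatible, so the quantized
# checkpoint is saved with pickle serialization (assumption for illustration).
transformer.save_pretrained("flux-int4-transformer", safe_serialization=False)

# With this PR, reloading maps the checkpoint straight onto the accelerator
# it was quantized on, instead of round-tripping through CPU.
transformer = FluxTransformer2DModel.from_pretrained(
    "flux-int4-transformer",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
)
```

Running this end to end requires a GPU and the torchao package installed alongside diffusers.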