From 7a9c448ac2f8b0ff66b9dab4e4efdaee9b98e08a Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 9 Jul 2025 12:06:29 +0200 Subject: [PATCH 01/12] update --- src/diffusers/models/model_loading_utils.py | 45 +++++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ebc7d79aeb28..4e02f870aae8 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -16,9 +16,10 @@ import importlib import inspect +import math import os from array import array -from collections import OrderedDict +from collections import OrderedDict, defaultdict from pathlib import Path from typing import Dict, List, Optional, Union from zipfile import is_zipfile @@ -230,6 +231,16 @@ def load_model_dict_into_meta( is_quantized = hf_quantizer is not None empty_state_dict = model.state_dict() + expanded_device_map = {} + + if device_map is not None: + for param_name, param in state_dict.items(): + if param_name not in empty_state_dict: + continue + param_device = _determine_param_device(param_name, device_map) + expanded_device_map[param_name] = param_device + print(expanded_device_map) + _caching_allocator_warmup(model, expanded_device_map, dtype) for param_name, param in state_dict.items(): if param_name not in empty_state_dict: @@ -243,13 +254,13 @@ def load_model_dict_into_meta( if keep_in_fp32_modules is not None and any( module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules ): - param = param.to(torch.float32) + param = param.to(torch.float32, non_blocking=True) set_module_kwargs["dtype"] = torch.float32 # For quantizers have save weights using torch.float8_e4m3fn elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None): pass else: - param = param.to(dtype) + param = param.to(dtype, non_blocking=True) set_module_kwargs["dtype"] = dtype # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which @@ -265,7 +276,7 @@ def load_model_dict_into_meta( if old_param is not None: if dtype is None: - param = param.to(old_param.dtype) + param = param.to(old_param.dtype, non_blocking=True) if old_param.is_contiguous(): param = param.contiguous() @@ -520,3 +531,29 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights return parsed_parameters + + +# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859 +def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: + """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each + device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model, + which is actually the loading speed botteneck. Calling this function allows to cut the model loading time by a very + large margin. 
+ """ + # Remove disk and cpu devices, and cast to proper torch.device + accelerator_device_map = { + param: torch.device(device) + for param, device in expanded_device_map.items() + if str(device) not in ["cpu", "disk"] + } + parameter_count = defaultdict(lambda: 0) + for param_name, device in accelerator_device_map.items(): + try: + param = model.get_parameter(param_name) + except AttributeError: + param = model.get_buffer(param_name) + parameter_count[device] += math.prod(param.shape) + + # This will kick off the caching allocator to avoid having to Malloc afterwards + for device, param_count in parameter_count.items(): + _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) From 8385f4548017a7f4d649b5fbf9350999942aec4f Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 9 Jul 2025 12:14:20 +0200 Subject: [PATCH 02/12] update --- src/diffusers/models/model_loading_utils.py | 16 ++++++++-------- src/diffusers/models/modeling_utils.py | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 4e02f870aae8..9132667bb720 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -233,14 +233,14 @@ def load_model_dict_into_meta( empty_state_dict = model.state_dict() expanded_device_map = {} - if device_map is not None: - for param_name, param in state_dict.items(): - if param_name not in empty_state_dict: - continue - param_device = _determine_param_device(param_name, device_map) - expanded_device_map[param_name] = param_device - print(expanded_device_map) - _caching_allocator_warmup(model, expanded_device_map, dtype) + # if device_map is not None: + # for param_name, param in state_dict.items(): + # if param_name not in empty_state_dict: + # continue + # param_device = _determine_param_device(param_name, device_map) + # expanded_device_map[param_name] = param_device + # print(expanded_device_map) + # _caching_allocator_warmup(model, expanded_device_map, dtype) for param_name, param in state_dict.items(): if param_name not in empty_state_dict: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8e1ec5f55889..312f0962ae4f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1557,6 +1557,8 @@ def _find_mismatched_keys( error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + torch.cuda.synchronize() + if offload_index is not None and len(offload_index) > 0: save_offload_index(offload_index, offload_folder) offload_index = None From 9e4873becc4a525420df505744a8d30a8a0730b7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 04:49:50 +0200 Subject: [PATCH 03/12] update --- src/diffusers/loaders/single_file_model.py | 3 ++ src/diffusers/loaders/single_file_utils.py | 5 ++ src/diffusers/loaders/transformer_flux.py | 12 +++-- src/diffusers/loaders/transformer_sd3.py | 6 +++ src/diffusers/loaders/unet.py | 6 +++ src/diffusers/models/model_loading_utils.py | 55 +++++++++++++++---- src/diffusers/models/modeling_utils.py | 60 +++++++-------------- src/diffusers/utils/torch_utils.py | 9 ++++ 8 files changed, 99 insertions(+), 57 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 17ac81ca26f6..d78219b560cc 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -24,6 +24,7 @@ 
from .. import __version__ from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, logging +from ..utils.torch_utils import device_synchronize, empty_device_cache from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -430,6 +431,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, ) + empty_device_cache() + device_synchronize() else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ee0786aa2d6a..bd71ff99ac65 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -46,6 +46,7 @@ ) from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.hub_utils import _get_model_file +from ..utils.torch_utils import device_synchronize, empty_device_cache if is_transformers_available(): @@ -1689,6 +1690,8 @@ def create_diffusers_clip_model_from_ldm( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + empty_device_cache() + device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -2148,6 +2151,8 @@ def create_diffusers_t5_model_from_checkpoint( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + empty_device_cache() + device_synchronize() else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index c7d81a8baebd..af03d09029c1 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -18,11 +18,8 @@ MultiIPAdapterImageProjection, ) from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta -from ..utils import ( - is_accelerate_available, - is_torch_version, - logging, -) +from ..utils import is_accelerate_available, is_torch_version, logging +from ..utils.torch_utils import device_synchronize, empty_device_cache if is_accelerate_available(): @@ -84,6 +81,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us else: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) + empty_device_cache() + device_synchronize() return image_projection @@ -158,6 +157,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_ key_id += 1 + empty_device_cache() + device_synchronize() + return attn_procs def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index c58d3280cfe1..4421f46dfc02 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -18,6 +18,7 @@ from ..models.embeddings import IPAdapterTimeImageProjection from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import is_accelerate_available, is_torch_version, logging +from ..utils.torch_utils import device_synchronize, empty_device_cache logger = logging.get_logger(__name__) @@ -80,6 +81,9 @@ def _convert_ip_adapter_attn_to_diffusers( attn_procs[name], 
layer_state_dict[idx], device_map=device_map, dtype=self.dtype ) + empty_device_cache() + device_synchronize() + return attn_procs def _convert_ip_adapter_image_proj_to_diffusers( @@ -147,6 +151,8 @@ def _convert_ip_adapter_image_proj_to_diffusers( else: device_map = {"": self.device} load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) + empty_device_cache() + device_synchronize() return image_proj diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c9b6a7d7d862..250542b17abe 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -44,6 +44,7 @@ is_torch_version, logging, ) +from ..utils.torch_utils import device_synchronize, empty_device_cache from .lora_base import _func_optionally_disable_offloading from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers @@ -752,6 +753,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us else: device_map = {"": self.device} load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) + empty_device_cache() + device_synchronize() return image_projection @@ -849,6 +852,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_ key_id += 2 + empty_device_cache() + device_synchronize() + return attn_procs def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 9132667bb720..b29ad47700b1 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -231,16 +231,6 @@ def load_model_dict_into_meta( is_quantized = hf_quantizer is not None empty_state_dict = model.state_dict() - expanded_device_map = {} - - # if device_map is not None: - # for param_name, param in state_dict.items(): - # if param_name not in empty_state_dict: - # continue - # param_device = _determine_param_device(param_name, device_map) - # expanded_device_map[param_name] = param_device - # print(expanded_device_map) - # _caching_allocator_warmup(model, expanded_device_map, dtype) for param_name, param in state_dict.items(): if param_name not in empty_state_dict: @@ -310,7 +300,15 @@ def load_model_dict_into_meta( model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype ) else: - set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + non_blocking=True, + _empty_cache=False, + **set_module_kwargs, + ) return offload_index, state_dict_index @@ -533,6 +531,41 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): return parsed_parameters +def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, +): + mismatched_keys = [] + if not ignore_mismatched_sizes: + return mismatched_keys + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + # If the checkpoint is sharded, we may not have the key here. 
+ if checkpoint_key not in state_dict: + continue + + if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + +def _expand_device_map(device_map, param_names): + """ + Expand a device map to return the correspondence parameter name to device. + """ + new_device_map = {} + for module, device in device_map.items(): + new_device_map.update( + {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} + ) + return new_device_map + + # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859 def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 312f0962ae4f..512825236a25 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -62,10 +62,14 @@ load_or_create_model_card, populate_model_card, ) +from ..utils.torch_utils import device_synchronize, empty_device_cache from .model_loading_utils import ( + _caching_allocator_warmup, _determine_device_map, + _expand_device_map, _fetch_index_file, _fetch_index_file_legacy, + _find_mismatched_keys, _load_state_dict_into_model, load_model_dict_into_meta, load_state_dict, @@ -1469,11 +1473,6 @@ def _load_pretrained_model( for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - mismatched_keys = [] - - assign_to_params_buffers = None - error_msgs = [] - # Deal with offload if device_map is not None and "disk" in device_map.values(): if offload_folder is None: @@ -1482,18 +1481,21 @@ def _load_pretrained_model( " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" " offers the weights in this format." 
) - if offload_folder is not None: + else: os.makedirs(offload_folder, exist_ok=True) if offload_state_dict is None: offload_state_dict = True + # Caching allocator warmup + if device_map is not None: + expanded_device_map = _expand_device_map(device_map, expected_keys) + _caching_allocator_warmup(model, expanded_device_map, dtype) + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + state_dict_folder, state_dict_index = None, None if offload_state_dict: state_dict_folder = tempfile.mkdtemp() state_dict_index = {} - else: - state_dict_folder = None - state_dict_index = None if state_dict is not None: # load_state_dict will manage the case where we pass a dict instead of a file @@ -1503,38 +1505,14 @@ def _load_pretrained_model( if len(resolved_model_file) > 1: resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") + mismatched_keys = [] + assign_to_params_buffers = None + error_msgs = [] + for shard_file in resolved_model_file: state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - # If the checkpoint is sharded, we may not have the key here. - if checkpoint_key not in state_dict: - continue - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - mismatched_keys += _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, + state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes ) if low_cpu_mem_usage: @@ -1554,11 +1532,11 @@ def _find_mismatched_keys( else: if assign_to_params_buffers is None: assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) - error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) - torch.cuda.synchronize() - + empty_device_cache() + device_synchronize() + if offload_index is not None and len(offload_index) > 0: save_offload_index(offload_index, offload_folder) offload_index = None diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 61a5d95b6926..a97345a761a1 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -182,5 +182,14 @@ def get_device(): def empty_device_cache(device_type: Optional[str] = None): if device_type is None: device_type = get_device() + if device_type in ["cpu"]: + return device_mod = getattr(torch, device_type, torch.cuda) device_mod.empty_cache() + + +def device_synchronize(device_type: Optional[str] = None): + if device_type is None: + device_type = get_device() + device_mod = getattr(torch, device_type, torch.cuda) + device_mod.synchronize() From b776aaa1eb8621f0f2d1b72b8e94c8d2e7bbca5d Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 05:10:17 +0200 Subject: [PATCH 04/12] pin accelerate version --- src/diffusers/models/model_loading_utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index b29ad47700b1..e856b218c358 100644 --- 
a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -39,6 +39,7 @@ _get_model_file, deprecate, is_accelerate_available, + is_accelerate_version, is_gguf_available, is_torch_available, is_torch_version, @@ -253,6 +254,10 @@ def load_model_dict_into_meta( param = param.to(dtype, non_blocking=True) set_module_kwargs["dtype"] = dtype + if is_accelerate_version(">=", "1.9.0.dev0"): + set_module_kwargs["non_blocking"] = True + set_module_kwargs["_empty_cache"] = False + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 @@ -300,15 +305,7 @@ def load_model_dict_into_meta( model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype ) else: - set_module_tensor_to_device( - model, - param_name, - param_device, - value=param, - non_blocking=True, - _empty_cache=False, - **set_module_kwargs, - ) + set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) return offload_index, state_dict_index From ea446b117b1097e670c848a27709d3f65e320370 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 11:18:59 +0200 Subject: [PATCH 05/12] add comment explanations --- src/diffusers/loaders/single_file_model.py | 2 ++ src/diffusers/loaders/single_file_utils.py | 4 ++++ src/diffusers/models/modeling_utils.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index d78219b560cc..b9b86cf480a2 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -431,6 +431,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, ) + # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is + # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() device_synchronize() else: diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index bd71ff99ac65..5fafcb02be6f 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1690,6 +1690,8 @@ def create_diffusers_clip_model_from_ldm( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is + # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() device_synchronize() else: @@ -2151,6 +2153,8 @@ def create_diffusers_t5_model_from_checkpoint( if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is + # required because we move tensors with non_blocking=True, which is slightly faster for model loading. 
empty_device_cache() device_synchronize() else: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 512825236a25..d7b2136b4afc 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1486,8 +1486,14 @@ def _load_pretrained_model( if offload_state_dict is None: offload_state_dict = True - # Caching allocator warmup - if device_map is not None: + # If a device map has been used, we can speedup the load time by warming up the device caching allocator. + # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a + # lot of individual calls to device malloc). We can, however, preallocate the memory required by the + # tensors using their expected shape and not performing any initialization of the memory (empty data). + # When the actual device allocations happen, the allocator already has a pool of unused device memory + # that it can re-use for faster loading of the model. + # TODO: add support for warmup with hf_quantizer + if device_map is not None and hf_quantizer is None: expanded_device_map = _expand_device_map(device_map, expected_keys) _caching_allocator_warmup(model, expanded_device_map, dtype) @@ -1534,6 +1540,8 @@ def _load_pretrained_model( assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is + # required because we move tensors with non_blocking=True, which is slightly faster for model loading. empty_device_cache() device_synchronize() From e736b094e7b6d2801030094f878e1262b7ddb27f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 11:22:22 +0200 Subject: [PATCH 06/12] update docstring --- src/diffusers/models/model_loading_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index e856b218c358..a1ac8f216fec 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -565,10 +565,11 @@ def _expand_device_map(device_map, param_names): # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859 def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: - """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each - device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model, - which is actually the loading speed botteneck. Calling this function allows to cut the model loading time by a very - large margin. + """ + This function warm-ups the caching allocator based on the size of the model tensors that will reside on each + device. It allows to have one large call to Malloc, instead of recursively calling it later when loading + the model, which is actually the loading speed bottleneck. + Calling this function allows to cut the model loading time by a very large margin. 
""" # Remove disk and cpu devices, and cast to proper torch.device accelerator_device_map = { From 4c81c962d8df203a97c7a925d7dc0baca0d7c8c3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 11:22:33 +0200 Subject: [PATCH 07/12] make style --- src/diffusers/models/model_loading_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a1ac8f216fec..691bb3925d7b 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -567,9 +567,9 @@ def _expand_device_map(device_map, param_names): def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: """ This function warm-ups the caching allocator based on the size of the model tensors that will reside on each - device. It allows to have one large call to Malloc, instead of recursively calling it later when loading - the model, which is actually the loading speed bottleneck. - Calling this function allows to cut the model loading time by a very large margin. + device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model, + which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a + very large margin. """ # Remove disk and cpu devices, and cast to proper torch.device accelerator_device_map = { From 582af9b5a1481cdf0a82afb7d8a7bdcde6c455a8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 10 Jul 2025 11:28:53 +0200 Subject: [PATCH 08/12] non_blocking does not matter for dtype cast --- src/diffusers/models/model_loading_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 691bb3925d7b..812027d6d99c 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -245,13 +245,13 @@ def load_model_dict_into_meta( if keep_in_fp32_modules is not None and any( module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules ): - param = param.to(torch.float32, non_blocking=True) + param = param.to(torch.float32) set_module_kwargs["dtype"] = torch.float32 # For quantizers have save weights using torch.float8_e4m3fn elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None): pass else: - param = param.to(dtype, non_blocking=True) + param = param.to(dtype) set_module_kwargs["dtype"] = dtype if is_accelerate_version(">=", "1.9.0.dev0"): @@ -271,7 +271,7 @@ def load_model_dict_into_meta( if old_param is not None: if dtype is None: - param = param.to(old_param.dtype, non_blocking=True) + param = param.to(old_param.dtype) if old_param.is_contiguous(): param = param.contiguous() From a6ee6606c146ca32e236732954e1fe136290ec18 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 11 Jul 2025 06:02:58 +0200 Subject: [PATCH 09/12] _empty_cache -> clear_cache --- src/diffusers/models/model_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 812027d6d99c..5bcac998a7f1 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -256,7 +256,7 @@ def load_model_dict_into_meta( if is_accelerate_version(">=", "1.9.0.dev0"): 
set_module_kwargs["non_blocking"] = True - set_module_kwargs["_empty_cache"] = False + set_module_kwargs["clear_cache"] = False # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. From 39f08509a61183d3fb8441357f624b67f346ea44 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 11 Jul 2025 06:04:34 +0200 Subject: [PATCH 10/12] update --- src/diffusers/models/model_loading_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5bcac998a7f1..104f519756dc 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -528,12 +528,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): return parsed_parameters -def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, -): +def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes): mismatched_keys = [] if not ignore_mismatched_sizes: return mismatched_keys From 58fcfdc8549e5a5a9127fd7a2484f9d29e62eb7d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 11 Jul 2025 20:33:20 +0530 Subject: [PATCH 11/12] Update src/diffusers/models/model_loading_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/diffusers/models/model_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 104f519756dc..c312997bd3fd 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -254,7 +254,7 @@ def load_model_dict_into_meta( param = param.to(dtype) set_module_kwargs["dtype"] = dtype - if is_accelerate_version(">=", "1.9.0.dev0"): + if is_accelerate_version(">", "1.8.0"): set_module_kwargs["non_blocking"] = True set_module_kwargs["clear_cache"] = False From 275e470d4631c2e6d68adf6ba4e75c40be240f7e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 11 Jul 2025 20:51:59 +0530 Subject: [PATCH 12/12] Update src/diffusers/models/model_loading_utils.py --- src/diffusers/models/model_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index c312997bd3fd..4e2d24b75011 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -254,7 +254,7 @@ def load_model_dict_into_meta( param = param.to(dtype) set_module_kwargs["dtype"] = dtype - if is_accelerate_version(">", "1.8.0"): + if is_accelerate_version(">", "1.8.1"): set_module_kwargs["non_blocking"] = True set_module_kwargs["clear_cache"] = False