
Speedup model loading by 4-5x ⚡ #11904


Merged
merged 17 commits on Jul 11, 2025
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -24,6 +24,7 @@
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,10 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

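For context, the new comment describes the usual pattern for asynchronous host-to-device copies; a minimal standalone sketch of that pattern (not the loader's actual code, and it assumes a CUDA device is available):

```python
import torch

# Stand-in state dict; the real loader iterates a converted checkpoint.
cpu_tensors = {f"layer_{i}.weight": torch.randn(1024, 1024) for i in range(4)}

gpu_tensors = {}
for name, tensor in cpu_tensors.items():
    # non_blocking=True lets the host queue the copy and continue immediately
    # instead of waiting for each transfer, which is what makes loading faster.
    gpu_tensors[name] = tensor.pin_memory().to("cuda", non_blocking=True)

# Before handing tensors back to the caller, wait for all queued copies to
# finish; otherwise a read could observe a transfer that has not completed yet.
torch.cuda.synchronize()
```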
9 changes: 9 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -46,6 +46,7 @@
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache


if is_transformers_available():
@@ -1689,6 +1690,10 @@ def create_diffusers_clip_model_from_ldm(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -2148,6 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)

12 changes: 7 additions & 5 deletions src/diffusers/loaders/transformer_flux.py
@@ -18,11 +18,8 @@
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache


if is_accelerate_available():
@@ -84,6 +81,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_projection

@@ -158,6 +157,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

key_id += 1

empty_device_cache()
device_synchronize()

return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
6 changes: 6 additions & 0 deletions src/diffusers/loaders/transformer_sd3.py
@@ -18,6 +18,7 @@
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache


logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ def _convert_ip_adapter_attn_to_diffusers(
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)

empty_device_cache()
device_synchronize()

return attn_procs

def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_proj

6 changes: 6 additions & 0 deletions src/diffusers/loaders/unet.py
@@ -43,6 +43,7 @@
is_torch_version,
logging,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -753,6 +754,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_projection

@@ -850,6 +853,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

key_id += 2

empty_device_cache()
device_synchronize()

return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
65 changes: 64 additions & 1 deletion src/diffusers/models/model_loading_utils.py
@@ -16,9 +16,10 @@

import importlib
import inspect
import math
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +39,7 @@
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype

if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

return parsed_parameters


def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
mismatched_keys = []
if not ignore_mismatched_sizes:
return mismatched_keys
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue

if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys


def _expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence from parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map


# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
"""
This function warms up the caching allocator based on the size of the model tensors that will reside on each
device. It allows having one large call to malloc instead of many small calls issued later while loading the
model, which is the actual loading-speed bottleneck. Calling this function can cut the model loading time by a
very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += math.prod(param.shape)

# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
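
The warmup amounts to one large uninitialized allocation per device before any weights are copied; a minimal sketch of the idea, assuming a single CUDA device (the toy model is illustrative; the real code sizes the allocation from the expanded device map):

```python
import math
import torch
import torch.nn as nn

# Toy model standing in for a diffusion transformer.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# Total number of elements that will eventually live on the accelerator.
total_elements = sum(math.prod(p.shape) for p in model.parameters())

# One big empty allocation primes the caching allocator. The tensor is dropped
# right away, but the memory stays in the allocator's pool and is re-used by
# the later per-parameter copies instead of each triggering a device malloc.
_ = torch.empty(total_elements, dtype=torch.float16, device="cuda", requires_grad=False)
```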
66 changes: 27 additions & 39 deletions src/diffusers/models/modeling_utils.py
@@ -62,10 +62,14 @@
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -1469,11 +1473,6 @@ def _load_pretrained_model(
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

mismatched_keys = []

assign_to_params_buffers = None
error_msgs = []

# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1482,18 +1481,27 @@
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True

# If a device map has been used, we can speed up load time by warming up the device caching allocator.
# If we don't warm up, each tensor allocation on device makes its own call to the allocator for memory
# (effectively, a lot of individual calls to device malloc). Instead, we can preallocate the memory the
# tensors will need based on their expected shapes, without initializing it (empty data). When the actual
# device allocations happen, the allocator already has a pool of unused device memory that it can re-use,
# making model loading faster.
# TODO: add support for warmup with hf_quantizer
if device_map is not None and hf_quantizer is None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, dtype)

offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None

if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

mismatched_keys = []
assign_to_params_buffers = None
error_msgs = []

for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)

def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue

if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys

mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)

if low_cpu_mem_usage:
@@ -1554,9 +1538,13 @@ def _find_mismatched_keys(
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()

if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
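End to end, the fast path is exercised through the regular loading API whenever a device_map is passed and no quantizer is involved; a hedged usage sketch (the model id is illustrative, and it assumes the installed diffusers accepts a plain device string for device_map):

```python
import torch
from diffusers import FluxTransformer2DModel

# Passing a device_map routes loading through the allocator warmup and the
# non_blocking copies added in this PR; no other user-facing change is needed.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```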
9 changes: 9 additions & 0 deletions src/diffusers/utils/torch_utils.py
@@ -184,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
if device_type in ["cpu"]:
return
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()


def device_synchronize(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
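
A short usage sketch of the two new helpers (available once this change lands; device detection is automatic):

```python
from diffusers.utils.torch_utils import device_synchronize, empty_device_cache

# After moving many tensors with non_blocking=True:
empty_device_cache()   # no-op on CPU, otherwise calls torch.<backend>.empty_cache()
device_synchronize()   # waits for all queued work/transfers on the detected backend
```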
Member:
I guess all different backends ought to have this method. Just flagging.

@a-r-r-o-w (Member Author), Jul 10, 2025:
afaik, synchronize should be available on all devices. Just the empty_cache function required a special check because it would fail if device was cpu
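
A quick, hedged way to check which backends expose these methods on a given install (the backend names below are just the common torch device modules):

```python
import torch

# Print whether each available backend module provides synchronize/empty_cache.
for backend in ("cuda", "xpu", "mps"):
    mod = getattr(torch, backend, None)
    if mod is not None:
        print(backend, "synchronize:", hasattr(mod, "synchronize"), "empty_cache:", hasattr(mod, "empty_cache"))
```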