
Commit c903527

a-r-r-o-w and SunMarc authored
Speedup model loading by 4-5x ⚡ (#11904)
* update
* update
* update
* pin accelerate version
* add comment explanations
* update docstring
* make style
* non_blocking does not matter for dtype cast
* _empty_cache -> clear_cache
* update
* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/diffusers/models/model_loading_utils.py

---------

Co-authored-by: Marc Sun <[email protected]>
1 parent 7a935a0 commit c903527

8 files changed: +133 additions, -45 deletions


src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions

@@ -24,6 +24,7 @@
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,10 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
+            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+            empty_device_cache()
+            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

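The `empty_device_cache()` + `device_synchronize()` pair added above (and repeated in the loader files below) is the flip side of moving tensors with `non_blocking=True`: asynchronous copies can return before the data has actually landed on the device, so the loader synchronizes once before handing the model back to the user. A minimal standalone sketch of the pattern, assuming a CUDA device and illustrative tensors (not diffusers code):

```python
import torch

# Pinned (page-locked) host memory is what makes host-to-device copies truly asynchronous.
cpu_shards = [torch.randn(1024, 1024).pin_memory() for _ in range(8)]

# Queue every copy without blocking the Python thread on each one...
gpu_shards = [t.to("cuda", non_blocking=True) for t in cpu_shards]

# ...then pay for a single synchronization before the tensors are used.
torch.cuda.synchronize()
```

Synchronizing once per load, instead of once per tensor, is why the non-blocking copies are a net win.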
src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 0 deletions

@@ -46,6 +46,7 @@
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import device_synchronize, empty_device_cache


 if is_transformers_available():
@@ -1689,6 +1690,10 @@ def create_diffusers_clip_model_from_ldm(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+        empty_device_cache()
+        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2148,6 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+        empty_device_cache()
+        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)

src/diffusers/loaders/transformer_flux.py

Lines changed: 7 additions & 5 deletions

@@ -18,11 +18,8 @@
     MultiIPAdapterImageProjection,
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
-    is_accelerate_available,
-    is_torch_version,
-    logging,
-)
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache


 if is_accelerate_available():
@@ -84,6 +81,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         else:
             device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+        empty_device_cache()
+        device_synchronize()

         return image_projection
@@ -158,6 +157,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

             key_id += 1

+        empty_device_cache()
+        device_synchronize()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/loaders/transformer_sd3.py

Lines changed: 6 additions & 0 deletions

@@ -18,6 +18,7 @@
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache


 logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ def _convert_ip_adapter_attn_to_diffusers(
                 attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
             )

+        empty_device_cache()
+        device_synchronize()
+
         return attn_procs

     def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
         else:
             device_map = {"": self.device}
         load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+        empty_device_cache()
+        device_synchronize()

         return image_proj

src/diffusers/loaders/unet.py

Lines changed: 6 additions & 0 deletions

@@ -43,6 +43,7 @@
     is_torch_version,
     logging,
 )
+from ..utils.torch_utils import device_synchronize, empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -753,6 +754,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         else:
             device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+        empty_device_cache()
+        device_synchronize()

         return image_projection
@@ -850,6 +853,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

             key_id += 2

+        empty_device_cache()
+        device_synchronize()
+
         return attn_procs

     def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/models/model_loading_utils.py

Lines changed: 64 additions & 1 deletion

@@ -16,9 +16,10 @@

 import importlib
 import inspect
+import math
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +39,7 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta(
             param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype

+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the mapping from parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+    """
+    Warm up the caching allocator based on the size of the model tensors that will reside on each device. This
+    allows one large call to malloc instead of many small ones made later while loading the model, which is
+    actually the loading speed bottleneck. Calling this function can cut model loading time by a very large
+    margin.
+    """
+    # Remove disk and cpu devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    parameter_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            param = model.get_parameter(param_name)
+        except AttributeError:
+            param = model.get_buffer(param_name)
+        parameter_count[device] += math.prod(param.shape)

+    # This will warm up the caching allocator so the real allocations avoid calling malloc afterwards
+    for device, param_count in parameter_count.items():
+        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)

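Two details in this file are worth noting. The `is_accelerate_version(">", "1.8.1")` guard only forwards the `non_blocking`/`clear_cache` kwargs on accelerate releases that accept them (hence the version pin mentioned in the commit message). And `_expand_device_map` converts a module-level device map into the per-parameter map the warmup consumes; here is a hedged, self-contained re-implementation with made-up module names to show the behavior:

```python
def expand_device_map(device_map, param_names):
    # Same matching rule as the _expand_device_map added above: a parameter
    # belongs to a module if it equals the module name or extends it with a
    # dot; the "" key matches every parameter.
    return {
        p: device
        for module, device in device_map.items()
        for p in param_names
        if p == module or p.startswith(f"{module}.") or module == ""
    }

params = ["encoder.layers.0.weight", "encoder.layers.0.bias", "proj_out.weight"]
print(expand_device_map({"encoder": 0, "proj_out": "cpu"}, params))
# {'encoder.layers.0.weight': 0, 'encoder.layers.0.bias': 0, 'proj_out.weight': 'cpu'}
```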
src/diffusers/models/modeling_utils.py

Lines changed: 27 additions & 39 deletions

@@ -62,10 +62,14 @@
     load_or_create_model_card,
     populate_model_card,
 )
+from ..utils.torch_utils import device_synchronize, empty_device_cache
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
+    _find_mismatched_keys,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -1469,11 +1473,6 @@ def _load_pretrained_model(
         for pat in cls._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-        mismatched_keys = []
-
-        assign_to_params_buffers = None
-        error_msgs = []
-
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
@@ -1482,18 +1481,27 @@
                     " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                     " offers the weights in this format."
                 )
-            if offload_folder is not None:
+            else:
                 os.makedirs(offload_folder, exist_ok=True)
             if offload_state_dict is None:
                 offload_state_dict = True

+        # If a device map has been used, we can speed up load time by warming up the device caching allocator.
+        # If we don't warm up, each tensor allocation on device calls the allocator for memory (effectively, a
+        # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+        # tensors using their expected shape, without performing any initialization of the memory (empty data).
+        # When the actual device allocations happen, the allocator already has a pool of unused device memory
+        # that it can re-use for faster loading of the model.
+        # TODO: add support for warmup with hf_quantizer
+        if device_map is not None and hf_quantizer is None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, dtype)
+
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+        state_dict_folder, state_dict_index = None, None
         if offload_state_dict:
             state_dict_folder = tempfile.mkdtemp()
             state_dict_index = {}
-        else:
-            state_dict_folder = None
-            state_dict_index = None

         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ def _load_pretrained_model(
         if len(resolved_model_file) > 1:
             resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

+        mismatched_keys = []
+        assign_to_params_buffers = None
+        error_msgs = []
+
         for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-
-            def _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
-            ):
-                mismatched_keys = []
-                if ignore_mismatched_sizes:
-                    for checkpoint_key in loaded_keys:
-                        model_key = checkpoint_key
-                        # If the checkpoint is sharded, we may not have the key here.
-                        if checkpoint_key not in state_dict:
-                            continue
-
-                        if (
-                            model_key in model_state_dict
-                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                        ):
-                            mismatched_keys.append(
-                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                            )
-                            del state_dict[checkpoint_key]
-                return mismatched_keys
-
             mismatched_keys += _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
+                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
             )

             if low_cpu_mem_usage:
@@ -1554,9 +1538,13 @@ def _find_mismatched_keys(
             else:
                 if assign_to_params_buffers is None:
                     assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-
                 error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

+        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+        empty_device_cache()
+        device_synchronize()
+
         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
             offload_index = None

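The warmup comment above is the heart of the 4-5x claim and is easy to sanity-check in isolation. A rough micro-benchmark sketch, assuming a CUDA device with a few GB free (shapes are illustrative, not from the PR):

```python
import math
import time

import torch

def alloc_all(shapes, dtype):
    # Stand-in for the many per-tensor device allocations made while loading a checkpoint.
    return [torch.empty(s, dtype=dtype, device="cuda") for s in shapes]

shapes = [(4096, 4096)] * 64  # ~2 GB total in fp16
dtype = torch.float16

# Cold: the caching allocator has nothing pooled, so allocations hit cudaMalloc.
torch.cuda.empty_cache()
t0 = time.perf_counter()
alloc_all(shapes, dtype)
torch.cuda.synchronize()
print(f"cold: {time.perf_counter() - t0:.4f}s")

# Warm: one big torch.empty primes the allocator and is immediately freed back
# into its pool, so the same allocations below are cheap block re-use.
torch.cuda.empty_cache()
warmup = torch.empty(sum(math.prod(s) for s in shapes), dtype=dtype, device="cuda")
del warmup
t0 = time.perf_counter()
alloc_all(shapes, dtype)
torch.cuda.synchronize()
print(f"warm: {time.perf_counter() - t0:.4f}s")
```

This mirrors what `_caching_allocator_warmup` does with the parameter counts it gathers from the expanded device map.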
src/diffusers/utils/torch_utils.py

Lines changed: 9 additions & 0 deletions

@@ -184,5 +184,14 @@ def get_device():
 def empty_device_cache(device_type: Optional[str] = None):
     if device_type is None:
         device_type = get_device()
+    if device_type in ["cpu"]:
+        return
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.empty_cache()
+
+
+def device_synchronize(device_type: Optional[str] = None):
+    if device_type is None:
+        device_type = get_device()
+    device_mod = getattr(torch, device_type, torch.cuda)
+    device_mod.synchronize()

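Both helpers resolve the backend module dynamically through `get_device()`, so callers stay device-agnostic (`torch.cuda`, `torch.xpu`, and so on, depending on the platform). A short usage sketch, assuming a CUDA machine:

```python
import torch

from diffusers.utils.torch_utils import device_synchronize, empty_device_cache

staged = torch.randn(4096, 4096).pin_memory()
weight = staged.to("cuda", non_blocking=True)  # copy is queued, not necessarily finished

empty_device_cache()  # dispatches to torch.cuda.empty_cache(); early-returns on CPU
device_synchronize()  # dispatches to torch.cuda.synchronize(); `weight` is now fully populated
```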