Skip to content

Fix bug when variant and safetensors files do not match #11587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
15 commits merged into the base branch on May 26, 2025
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]

if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = r"\d{5}-of-\d{5}"
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)

passed_components = passed_components or []
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No

# If there are no component folders check the main directory for safetensors files
if not components:
return any(".safetensors" in filename for filename in filenames)
if variant is not None:
filtered_filenames = filter_with_regex(filenames, variant_file_re)
else:
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
return any(".safetensors" in filename for filename in filtered_filenames)

# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
for component_filename in component_filenames:
if variant is not None:
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
else:
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename)

match_exists = extension == ".safetensors"
@@ -159,6 +192,10 @@ def filter_model_files(filenames):
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]


def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}


def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
@@ -207,9 +244,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}

def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}

# Group files by component
components = {}
for filename in filenames:
@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
)
):
raise EnvironmentError(
@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
):
ignore_patterns = ["*.bin", "*.msgpack"]

24 changes: 16 additions & 8 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
@@ -87,21 +87,24 @@ def test_all_is_compatible_variant(self):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
@@ -121,7 +124,8 @@ def test_transformer_model_is_compatible_variant(self):
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_transformer_model_is_not_compatible_variant(self):
filenames = [
@@ -145,7 +149,8 @@ def test_transformer_model_is_compatible_variant_extra_folder(self):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))

def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
@@ -173,7 +178,8 @@ def test_transformers_is_compatible_variant_sharded(self):
"text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_is_compatible_sharded(self):
filenames = [
@@ -189,13 +195,15 @@ def test_diffusers_is_compatible_variant_sharded(self):
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_is_compatible_only_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_diffusers_is_compatible_no_components(self):
filenames = [
77 changes: 43 additions & 34 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
@@ -538,26 +538,38 @@ def test_download_variant_partly(self):
variant = "no_ema"

with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]

unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
if use_safetensors:
with self.assertRaises(OSError) as error_context:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
else:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]

# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))

# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert (
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
)
assert not any(f.endswith(other_format) for f in files)

def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,20 +600,17 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"

for is_local in [True, False]:
with CaptureLogger(logger) as cap_logger:
with tempfile.TemporaryDirectory() as tmpdirname:
local_repo_id = repo_id
if is_local:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
with CaptureLogger(logger) as cap_logger:
with tempfile.TemporaryDirectory() as tmpdirname:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)

_ = DiffusionPipeline.from_pretrained(
local_repo_id,
safety_checker=None,
variant="fp16",
use_safetensors=True,
)
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
_ = DiffusionPipeline.from_pretrained(
local_repo_id,
safety_checker=None,
variant="fp16",
use_safetensors=True,
)
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"

def test_download_safetensors_only_variant_exists_for_model(self):
variant = None
@@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
variant=variant,
use_safetensors=use_safetensors,
)
assert "Error no file name" in str(error_context.exception)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)

# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self):
use_safetensors=use_safetensors,
)

assert "Error no file name" in str(error_context.exception)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)

def test_download_bin_variant_does_not_exist_for_model(self):
variant = "no_ema"