diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 3404ae5130fe..032a6e161bbd 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -92,7 +92,7 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool: +def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool: """ Checking for safetensors compatibility: - The model is safetensors compatible only if there is a safetensors file for each model component present in @@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" extension is replaced with ".safetensors" """ + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + # model_pytorch, diffusion_model_pytorch, ... + weight_prefixes = [w.split(".")[0] for w in weight_names] + # .bin, .safetensors, ... + weight_suffixs = [w.split(".")[-1] for w in weight_names] + # -00001-of-00002 + transformers_index_format = r"\d{5}-of-\d{5}" + # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` + variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" + ) + non_variant_file_re = re.compile( + rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" + ) + passed_components = passed_components or [] if folder_names: filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} @@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No # If there are no component folders check the main directory for safetensors files if not components: - return any(".safetensors" in filename for filename in filenames) + if variant is not None: + filtered_filenames = filter_with_regex(filenames, variant_file_re) + else: + filtered_filenames = filter_with_regex(filenames, non_variant_file_re) + return any(".safetensors" in filename for filename in filtered_filenames) # iterate over all files of a component # check if safetensor files exist for that component # if variant is provided check if the variant of the safetensors exists for component, component_filenames in components.items(): matches = [] - for component_filename in component_filenames: + if variant is not None: + filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re) + else: + filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re) + for component_filename in filtered_component_filenames: filename, extension = os.path.splitext(component_filename) match_exists = extension == ".safetensors" @@ -159,6 +192,10 @@ def filter_model_files(filenames): return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] +def filter_with_regex(filenames, pattern_re): + return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} + + def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, @@ -207,9 +244,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None): # interested in the extension name return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)} - def filter_with_regex(filenames, pattern_re): - return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} - # Group files by component components = {} for filename in filenames: @@ -997,7 +1031,7 @@ def _get_ignore_patterns( use_safetensors and not allow_pickle and not is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names + model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant ) ): raise EnvironmentError( @@ -1008,7 +1042,7 @@ def _get_ignore_patterns( ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] elif use_safetensors and is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names + model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant ): ignore_patterns = ["*.bin", "*.msgpack"] diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 423c2b8ab146..f680cf2dcf18 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -87,21 +87,24 @@ def test_all_is_compatible_variant(self): "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_compatible_variant(self): filenames = [ "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_compatible_variant_mixed(self): filenames = [ "unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_model_is_not_compatible_variant(self): filenames = [ @@ -121,7 +124,8 @@ def test_transformer_model_is_compatible_variant(self): "text_encoder/pytorch_model.fp16.bin", "text_encoder/model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_transformer_model_is_not_compatible_variant(self): filenames = [ @@ -145,7 +149,8 @@ def test_transformer_model_is_compatible_variant_extra_folder(self): "unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) + self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) + self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16")) def test_transformer_model_is_not_compatible_variant_extra_folder(self): filenames = [ @@ -173,7 +178,8 @@ def test_transformers_is_compatible_variant_sharded(self): "text_encoder/model.fp16-00001-of-00002.safetensors", "text_encoder/model.fp16-00001-of-00002.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_sharded(self): filenames = [ @@ -189,13 +195,15 @@ def test_diffusers_is_compatible_variant_sharded(self): "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_only_variants(self): filenames = [ "unet/diffusion_pytorch_model.fp16.safetensors", ] - self.assertTrue(is_safetensors_compatible(filenames)) + self.assertFalse(is_safetensors_compatible(filenames)) + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) def test_diffusers_is_compatible_no_components(self): filenames = [ diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index caa7755904a5..a2241236da20 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -538,26 +538,38 @@ def test_download_variant_partly(self): variant = "no_ema" with tempfile.TemporaryDirectory() as tmpdirname: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - all_root_files = [t[-1] for t in os.walk(tmpdirname)] - files = [item for sublist in all_root_files for item in sublist] - - unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + if use_safetensors: + with self.assertRaises(OSError) as error_context: + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) + else: + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + all_root_files = [t[-1] for t in os.walk(tmpdirname)] + files = [item for sublist in all_root_files for item in sublist] - # Some of the downloaded files should be a non-variant file, check: - # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet - assert len(files) == 15, f"We should only download 15 files, not {len(files)}" - # only unet has "no_ema" variant - assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files - assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 - # vae, safety_checker and text_encoder should have no variant - assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 - assert not any(f.endswith(other_format) for f in files) + unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + + # Some of the downloaded files should be a non-variant file, check: + # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet + assert len(files) == 15, f"We should only download 15 files, not {len(files)}" + # only unet has "no_ema" variant + assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files + assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 + # vae, safety_checker and text_encoder should have no variant + assert ( + sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 + ) + assert not any(f.endswith(other_format) for f in files) def test_download_variants_with_sharded_checkpoints(self): # Here we test for downloading of "variant" files belonging to the `unet` and @@ -588,20 +600,17 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): logger = logging.get_logger("diffusers.pipelines.pipeline_utils") deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" - for is_local in [True, False]: - with CaptureLogger(logger) as cap_logger: - with tempfile.TemporaryDirectory() as tmpdirname: - local_repo_id = repo_id - if is_local: - local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname) + with CaptureLogger(logger) as cap_logger: + with tempfile.TemporaryDirectory() as tmpdirname: + local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname) - _ = DiffusionPipeline.from_pretrained( - local_repo_id, - safety_checker=None, - variant="fp16", - use_safetensors=True, - ) - assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" + _ = DiffusionPipeline.from_pretrained( + local_repo_id, + safety_checker=None, + variant="fp16", + use_safetensors=True, + ) + assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" def test_download_safetensors_only_variant_exists_for_model(self): variant = None @@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self): variant=variant, use_safetensors=use_safetensors, ) - assert "Error no file name" in str(error_context.exception) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) # text encoder has fp16 variants so we can load it with tempfile.TemporaryDirectory() as tmpdirname: @@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self): use_safetensors=use_safetensors, ) - assert "Error no file name" in str(error_context.exception) + assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) def test_download_bin_variant_does_not_exist_for_model(self): variant = "no_ema"