Skip to content

Commit b7f00d5

Browse files
authored
handle missing TF2_WEIGHTS_NAME for newer transformers versions (#2384)
1 parent 2e8e232 commit b7f00d5

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

optimum/exporters/tasks.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from packaging import version
2828
from requests.exceptions import ConnectionError
2929
from transformers import AutoConfig, PretrainedConfig
30-
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, http_user_agent
30+
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, http_user_agent
3131

3232
from ..utils.import_utils import is_diffusers_available, is_torch_available
3333
from ..utils.logging import get_logger
@@ -627,14 +627,8 @@ def determine_framework(
627627
for file in all_files
628628
]
629629

630-
weight_name = Path(TF2_WEIGHTS_NAME).stem
631-
weight_extension = Path(TF2_WEIGHTS_NAME).suffix
632-
is_tf_weight_file = [file.startswith(weight_name) and file.endswith(weight_extension) for file in all_files]
633-
634630
if any(is_pt_weight_file):
635631
framework = "pt"
636-
elif any(is_tf_weight_file):
637-
framework = "tf"
638632
elif "model_index.json" in all_files and any(
639633
file.endswith((pt_weight_extension, safe_weight_extension)) for file in all_files
640634
):
@@ -649,11 +643,11 @@ def determine_framework(
649643
f"The framework could not be automatically inferred. If using the command-line, please provide the argument --framework (pt,tf) Detailed error: {request_exception}"
650644
)
651645
else:
652-
raise FileNotFoundError(
653-
"Cannot determine framework from given checkpoint location."
654-
f" There should be a {Path(WEIGHTS_NAME).stem}*{Path(WEIGHTS_NAME).suffix} for PyTorch"
655-
f" or {Path(TF2_WEIGHTS_NAME).stem}*{Path(TF2_WEIGHTS_NAME).suffix} for TensorFlow."
646+
msg = (
647+
"Cannot determine framework from given checkpoint location. "
648+
f"There should be a {Path(WEIGHTS_NAME).stem}*{Path(WEIGHTS_NAME).suffix} for PyTorch."
656649
)
650+
raise FileNotFoundError(msg)
657651

658652
if is_torch_available():
659653
framework = framework or "pt"

0 commit comments

Comments
 (0)