Skip to content

Commit 0227a1c

Browse files
Transformers v5 pre release (#2387)
* test main * test * fix * fx * move
1 parent b7f00d5 commit 0227a1c

File tree

9 files changed

+33
-36
lines changed

9 files changed

+33
-36
lines changed

.github/workflows/build_main_documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ jobs:
6767
- name: Make Optimum documentation
6868
run: |
6969
cd optimum
70-
uv pip install . accelerate
70+
uv pip install . accelerate --prerelease=allow
7171
make doc BUILD_DIR=optimum-doc-build VERSION=${{ env.VERSION }}
7272
cd ..
7373

.github/workflows/build_pr_documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
- name: Make Optimum documentation
6565
run: |
6666
cd optimum
67-
uv pip install . accelerate
67+
uv pip install . accelerate --prerelease=allow
6868
make doc BUILD_DIR=optimum-doc-build VERSION=pr_$PR_NUMBER
6969
cd ..
7070

.github/workflows/test_cli.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
run: |
5151
python -m pip install --upgrade pip uv
5252
uv pip install ./optimum-onnx[onnxruntime]
53-
uv pip install ./optimum[tests]
53+
uv pip install ./optimum[tests] --prerelease=allow
5454
5555
- name: Test with pytest
5656
run: |
@@ -61,7 +61,7 @@ jobs:
6161
run: |
6262
uv pip uninstall optimum-onnx optimum
6363
uv pip install -e ./optimum-onnx[onnxruntime]
64-
uv pip install -e ./optimum[tests]
64+
uv pip install -e ./optimum[tests] --prerelease=allow
6565
6666
- name: Test with pytest (editable mode)
6767
run: |

.github/workflows/test_common.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- name: Install dependencies
3939
run: |
4040
python -m pip install --upgrade pip uv
41-
uv pip install .[tests]
41+
uv pip install .[tests] --prerelease=allow
4242
4343
- name: Test with pytest
4444
run: |

.github/workflows/test_exporters_common.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
run: |
5151
python -m pip install --upgrade pip uv
5252
uv pip install ./optimum-onnx
53-
uv pip install ./optimum[tests]
53+
uv pip install ./optimum[tests] --prerelease=allow
5454
5555
- name: Test with pytest
5656
run: |
@@ -61,7 +61,7 @@ jobs:
6161
run: |
6262
uv pip uninstall optimum-onnx optimum
6363
uv pip install -e ./optimum-onnx
64-
uv pip install -e ./optimum[tests]
64+
uv pip install -e ./optimum[tests] --prerelease=allow
6565
6666
- name: Test with pytest (editable mode)
6767
run: |

.github/workflows/test_pipelines.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
run: |
5151
python -m pip install --upgrade pip uv
5252
uv pip install ./optimum-onnx[onnxruntime]
53-
uv pip install ./optimum[tests]
53+
uv pip install ./optimum[tests] --prerelease=allow
5454
5555
- name: Test with pytest
5656
run: |
@@ -61,7 +61,7 @@ jobs:
6161
run: |
6262
uv pip uninstall optimum-onnx optimum
6363
uv pip install -e ./optimum-onnx[onnxruntime]
64-
uv pip install -e ./optimum[tests]
64+
uv pip install -e ./optimum[tests] --prerelease=allow
6565
6666
- name: Test with pytest (editable mode)
6767
run: |

.github/workflows/test_utils.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- name: Install dependencies
3838
run: |
3939
python -m pip install --upgrade pip uv
40-
uv pip install .[tests]
40+
uv pip install .[tests] --prerelease=allow
4141
4242
- name: Tests needing datasets
4343
run: |

optimum/configuration_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from transformers import PretrainedConfig
2525
from transformers import __version__ as transformers_version_str
2626
from transformers.dynamic_module_utils import custom_object_save
27-
from transformers.utils import cached_file, download_url, extract_commit_hash, is_remote_url
27+
from transformers.utils import cached_file, extract_commit_hash
2828

2929
from .utils import logging
3030
from .version import __version__
@@ -192,10 +192,6 @@ def _get_config_dict(
192192
# Special case when pretrained_model_name_or_path is a local file
193193
resolved_config_file = pretrained_model_name_or_path
194194
is_local = True
195-
# TODO: remove condition once transformers release version is way above 4.22.
196-
elif is_remote_url(pretrained_model_name_or_path):
197-
configuration_file = pretrained_model_name_or_path
198-
resolved_config_file = download_url(pretrained_model_name_or_path)
199195
else:
200196
configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME)
201197

optimum/fx/optimization/transformations.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,10 @@
2222
from typing import List
2323

2424
import torch
25-
from torch.fx import GraphModule, Node
25+
from torch.fx import GraphModule, Node, Proxy
2626
from transformers.file_utils import add_end_docstrings
2727

2828

29-
try:
30-
from transformers.utils.fx import _gen_constructor_wrapper
31-
except ImportError:
32-
from transformers.utils.fx import gen_constructor_wrapper
33-
34-
def _gen_constructor_wrapper(*args, **kwargs):
35-
wrapper, target = gen_constructor_wrapper(*args, **kwargs)
36-
37-
def wrapper_with_forced_tracing(*_args, **_kwargs):
38-
import torch.fx._symbolic_trace
39-
40-
original_flag = torch.fx._symbolic_trace._is_fx_tracing_flag
41-
torch.fx._symbolic_trace._is_fx_tracing_flag = True
42-
out = wrapper(*_args, **_kwargs)
43-
torch.fx._symbolic_trace._is_fx_tracing_flag = original_flag
44-
return out
45-
46-
return wrapper_with_forced_tracing, target
47-
48-
4929
_ATTRIBUTES_DOCSTRING = r"""
5030
Attributes:
5131
preserves_computation (`bool`, defaults to `False`):
@@ -819,3 +799,24 @@ def reverse(self, graph_module):
819799
return ComposeTransformation._reverse_composition(graph_module)
820800

821801
return ComposeTransformation()
802+
803+
804+
def _gen_constructor_wrapper(target):
805+
@functools.wraps(target)
806+
def wrapper(*args, **kwargs):
807+
proxy = None
808+
809+
def check_has_proxy(v):
810+
if isinstance(v, Proxy):
811+
nonlocal proxy
812+
proxy = v
813+
814+
torch.fx.node.map_aggregate(args, check_has_proxy)
815+
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
816+
817+
if proxy is not None:
818+
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
819+
else:
820+
return target(*args, **kwargs)
821+
822+
return wrapper, target

0 commit comments

Comments
 (0)