Skip to content

Commit 25c0898

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(torchx/components) Remove dependency to pyre_extensions (#1088)
Summary: There was only a single usage of `pyre_extensions` in the entire code base. Get rid of this dependency by using an `assert` statement rather than `pyre_extensions.none_throws`. Reviewed By: highker Differential Revision: D77621041
1 parent 7fabab4 commit 25c0898

File tree

4 files changed

+49
-49
lines changed

4 files changed

+49
-49
lines changed

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
pyre-extensions
21
docstring-parser>=0.8.1
32
importlib-metadata
43
pyyaml

torchx/components/structured_arg.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
from pathlib import Path
3131
from typing import Optional
3232

33-
from pyre_extensions import none_throws
34-
3533
from torchx import specs
3634

3735

@@ -148,7 +146,8 @@ def parse_from(
148146
if m: # use the last module name
149147
run_name = m.rpartition(".")[2]
150148
else: # use script name w/ no extension
151-
run_name = Path(none_throws(script)).stem
149+
assert script, "`script` can't be `None` here due checks above"
150+
run_name = Path(script).stem
152151
return StructuredNameArgument(
153152
experiment_name or default_experiment_name, run_name
154153
)

torchx/util/test/types_test.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88

99
import inspect
1010
import unittest
11-
from typing import cast, Dict, List, Optional, Union
11+
from typing import cast, Optional, Union
1212

13-
import typing_inspect
1413
from torchx.util.types import (
1514
decode,
1615
decode_from_string,
@@ -26,25 +25,25 @@
2625

2726
def _test_complex_args(
2827
arg1: int,
29-
arg2: Optional[List[str]],
28+
arg2: Optional[list[str]],
3029
arg3: Union[float, int],
3130
) -> int:
3231
return 42
3332

3433

35-
def _test_dict(arg1: Dict[int, float]) -> int:
34+
def _test_dict(arg1: dict[int, float]) -> int:
3635
return 42
3736

3837

39-
def _test_nested_object(arg1: Dict[str, List[str]]) -> int:
38+
def _test_nested_object(arg1: dict[str, list[str]]) -> int:
4039
return 42
4140

4241

43-
def _test_list(arg1: List[float]) -> int:
42+
def _test_list(arg1: list[float]) -> int:
4443
return 42
4544

4645

47-
def _test_complex_list(arg1: List[List[float]]) -> int:
46+
def _test_complex_list(arg1: list[list[float]]) -> int:
4847
return 42
4948

5049

@@ -59,24 +58,21 @@ def test_decode_optional(self) -> None:
5958
arg1_parameter = parameters["arg1"]
6059
arg1_type = decode_optional(arg1_parameter.annotation)
6160
self.assertTrue(arg1_type is int)
62-
63-
arg2_parameter = parameters["arg2"]
6461
arg2_type = decode_optional(parameters["arg2"].annotation)
65-
self.assertTrue(typing_inspect.get_origin(arg2_type) is list)
66-
62+
self.assertTrue(getattr(arg2_type, "__origin__", None) is list)
6763
arg3_parameter = parameters["arg3"]
6864
arg3_type = decode_optional(arg3_parameter.annotation)
69-
self.assertTrue(typing_inspect.get_origin(arg3_type) is Union)
65+
self.assertTrue(
66+
hasattr(arg3_type, "__origin__") and arg3_type.__origin__ is Union
67+
)
7068

7169
def test_is_primitive(self) -> None:
7270
parameters = inspect.signature(_test_complex_args).parameters
7371

7472
arg1_parameter = parameters["arg1"]
75-
arg1_type = decode_optional(arg1_parameter.annotation)
7673
self.assertTrue(is_primitive(arg1_parameter.annotation))
7774

7875
arg2_parameter = parameters["arg2"]
79-
arg2_type = decode_optional(parameters["arg2"].annotation)
8076
self.assertFalse(is_primitive(arg2_parameter.annotation))
8177

8278
def test_is_bool(self) -> None:
@@ -89,7 +85,7 @@ def test_decode_from_string_dict(self) -> None:
8985
encoded_value = "1=1.0,2=42.1,3=10"
9086

9187
value = decode_from_string(encoded_value, parameters["arg1"].annotation)
92-
value = cast(Dict[int, float], value)
88+
value = cast(dict[int, float], value)
9389
self.assertEqual(3, len(value))
9490
self.assertEqual(float(1.0), value[1])
9591
self.assertEqual(float(42.1), value[2])
@@ -101,7 +97,7 @@ def test_decode_from_string_list(self) -> None:
10197
encoded_value = "1.0,42.2,3.9"
10298

10399
value = decode_from_string(encoded_value, parameters["arg1"].annotation)
104-
value = cast(List[float], value)
100+
value = cast(list[float], value)
105101
self.assertEqual(3, len(value))
106102
self.assertEqual(float(1.0), value[0])
107103
self.assertEqual(float(42.2), value[1])
@@ -217,8 +213,8 @@ def fake_component(
217213
f: float,
218214
s: str,
219215
b: bool,
220-
l: List[str],
221-
m: Dict[str, str],
216+
l: list[str],
217+
m: dict[str, str],
222218
o: Optional[int],
223219
) -> None:
224220
# component has to return an AppDef

torchx/util/types.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88

99
import inspect
1010
import re
11-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
11+
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
1212

13-
import typing_inspect
1413

15-
16-
def to_list(arg: str) -> List[str]:
14+
def to_list(arg: str) -> list[str]:
1715
conf = []
1816
if len(arg.strip()) == 0:
1917
return []
@@ -22,9 +20,9 @@ def to_list(arg: str) -> List[str]:
2220
return conf
2321

2422

25-
def to_dict(arg: str) -> Dict[str, str]:
23+
def to_dict(arg: str) -> dict[str, str]:
2624
"""
27-
Parses the given ``arg`` string literal into a ``Dict[str, str]`` of
25+
Parses the given ``arg`` string literal into a ``dict[str, str]`` of
2826
key-value pairs delimited by ``"="`` (equals). The values may be a
2927
list literal where the list elements are delimited by ``","`` (comma)
3028
or ``";"`` (semi-colon). The same delimiters (``","`` and ``";"``) are used
@@ -85,14 +83,14 @@ def to_val(val: str) -> str:
8583
return val[1:-1]
8684
return val if val != '""' and val != "''" else ""
8785

88-
arg_map: Dict[str, str] = {}
86+
arg_map: dict[str, str] = {}
8987

9088
if not arg:
9189
return arg_map
9290

9391
# find quoted values
9492
quoted_pattern = r'([\'"])((?:\\.|(?!\1).)*?)\1'
95-
quoted_values: List[str] = []
93+
quoted_values: list[str] = []
9694

9795
def replace_quoted(match):
9896
quoted_values.append(match.group(0))
@@ -133,19 +131,26 @@ def replace_quoted(match):
133131

134132
# pyre-ignore-all-errors[3, 2]
135133
def _decode_string_to_dict(
136-
encoded_value: str, param_type: Type[Dict[Any, Any]]
137-
) -> Dict[Any, Any]:
138-
key_type, value_type = typing_inspect.get_args(param_type)
134+
encoded_value: str, param_type: type[dict[Any, Any]]
135+
) -> dict[Any, Any]:
136+
# pyre-ignore[16]
137+
if not hasattr(param_type, "__args__") or len(param_type.__args__) != 2:
138+
raise ValueError(f"param_type must be a `dict` type, but was `{param_type}`")
139+
140+
key_type, value_type = param_type.__args__
139141
arg_values = {}
140142
for key, value in to_dict(encoded_value).items():
141143
arg_values[key_type(key)] = value_type(value)
142144
return arg_values
143145

144146

145147
def _decode_string_to_list(
146-
encoded_value: str, param_type: Type[List[Any]]
147-
) -> List[Any]:
148-
value_type = typing_inspect.get_args(param_type)[0]
148+
encoded_value: str, param_type: type[list[Any]]
149+
) -> list[Any]:
150+
# pyre-ignore[16]
151+
if not hasattr(param_type, "__args__") or len(param_type.__args__) != 1:
152+
raise ValueError(f"param_type must be a `list` type, but was `{param_type}`")
153+
value_type = param_type.__args__[0]
149154
if not is_primitive(value_type):
150155
raise ValueError("List types support only primitives: int, str, float")
151156
arg_values = []
@@ -166,7 +171,7 @@ def decode(encoded_value: Any, annotation: Any):
166171

167172
def decode_from_string(
168173
encoded_value: str, annotation: Any
169-
) -> Union[Dict[Any, Any], List[Any], None]:
174+
) -> Union[dict[Any, Any], list[Any], None]:
170175
"""Decodes string representation to the underlying type(Dict or List)
171176
172177
Given a string representation of the value, the method decodes it according
@@ -191,13 +196,13 @@ def decode_from_string(
191196
if not encoded_value:
192197
return None
193198
value_type = annotation
194-
value_origin = typing_inspect.get_origin(value_type)
195-
if value_origin is dict:
196-
return _decode_string_to_dict(encoded_value, value_type)
197-
elif value_origin is list:
198-
return _decode_string_to_list(encoded_value, value_type)
199-
else:
200-
raise ValueError("Unknown")
199+
if hasattr(value_type, "__origin__"):
200+
value_origin = value_type.__origin__
201+
if value_origin is dict:
202+
return _decode_string_to_dict(encoded_value, value_type)
203+
elif value_origin is list:
204+
return _decode_string_to_list(encoded_value, value_type)
205+
raise ValueError("Unknown")
201206

202207

203208
def is_bool(param_type: Any) -> bool:
@@ -229,12 +234,13 @@ def decode_optional(param_type: Any) -> Any:
229234
If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
230235
Otherwise returns ``param_type``
231236
"""
232-
param_origin = typing_inspect.get_origin(param_type)
233-
if param_origin is not Union:
237+
if not hasattr(param_type, "__origin__"):
238+
return param_type
239+
if param_type.__origin__ is not Union:
234240
return param_type
235-
key_type, value_type = typing_inspect.get_args(param_type)
236-
if value_type is type(None):
237-
return key_type
241+
args = param_type.__args__
242+
if len(args) == 2 and args[1] is type(None):
243+
return args[0]
238244
else:
239245
return param_type
240246

0 commit comments

Comments
 (0)