8
8
9
9
import inspect
10
10
import re
11
- from typing import Any , Callable , Dict , List , Optional , Tuple , Type , TypeVar , Union
11
+ from typing import Any , Callable , Optional , Tuple , TypeVar , Union
12
12
13
- import typing_inspect
14
13
15
-
16
- def to_list (arg : str ) -> List [str ]:
14
+ def to_list (arg : str ) -> list [str ]:
17
15
conf = []
18
16
if len (arg .strip ()) == 0 :
19
17
return []
@@ -22,9 +20,9 @@ def to_list(arg: str) -> List[str]:
22
20
return conf
23
21
24
22
25
- def to_dict (arg : str ) -> Dict [str , str ]:
23
+ def to_dict (arg : str ) -> dict [str , str ]:
26
24
"""
27
- Parses the given ``arg`` string literal into a ``Dict [str, str]`` of
25
+ Parses the given ``arg`` string literal into a ``dict [str, str]`` of
28
26
key-value pairs delimited by ``"="`` (equals). The values may be a
29
27
list literal where the list elements are delimited by ``","`` (comma)
30
28
or ``";"`` (semi-colon). The same delimiters (``","`` and ``";"``) are used
@@ -85,14 +83,14 @@ def to_val(val: str) -> str:
85
83
return val [1 :- 1 ]
86
84
return val if val != '""' and val != "''" else ""
87
85
88
- arg_map : Dict [str , str ] = {}
86
+ arg_map : dict [str , str ] = {}
89
87
90
88
if not arg :
91
89
return arg_map
92
90
93
91
# find quoted values
94
92
quoted_pattern = r'([\'"])((?:\\.|(?!\1).)*?)\1'
95
- quoted_values : List [str ] = []
93
+ quoted_values : list [str ] = []
96
94
97
95
def replace_quoted (match ):
98
96
quoted_values .append (match .group (0 ))
@@ -133,19 +131,26 @@ def replace_quoted(match):
133
131
134
132
# pyre-ignore-all-errors[3, 2]
135
133
def _decode_string_to_dict (
136
- encoded_value : str , param_type : Type [Dict [Any , Any ]]
137
- ) -> Dict [Any , Any ]:
138
- key_type , value_type = typing_inspect .get_args (param_type )
134
+ encoded_value : str , param_type : type [dict [Any , Any ]]
135
+ ) -> dict [Any , Any ]:
136
+ # pyre-ignore[16]
137
+ if not hasattr (param_type , "__args__" ) or len (param_type .__args__ ) != 2 :
138
+ raise ValueError (f"param_type must be a `dict` type, but was `{ param_type } `" )
139
+
140
+ key_type , value_type = param_type .__args__
139
141
arg_values = {}
140
142
for key , value in to_dict (encoded_value ).items ():
141
143
arg_values [key_type (key )] = value_type (value )
142
144
return arg_values
143
145
144
146
145
147
def _decode_string_to_list (
146
- encoded_value : str , param_type : Type [List [Any ]]
147
- ) -> List [Any ]:
148
- value_type = typing_inspect .get_args (param_type )[0 ]
148
+ encoded_value : str , param_type : type [list [Any ]]
149
+ ) -> list [Any ]:
150
+ # pyre-ignore[16]
151
+ if not hasattr (param_type , "__args__" ) or len (param_type .__args__ ) != 1 :
152
+ raise ValueError (f"param_type must be a `list` type, but was `{ param_type } `" )
153
+ value_type = param_type .__args__ [0 ]
149
154
if not is_primitive (value_type ):
150
155
raise ValueError ("List types support only primitives: int, str, float" )
151
156
arg_values = []
@@ -166,7 +171,7 @@ def decode(encoded_value: Any, annotation: Any):
166
171
167
172
def decode_from_string (
168
173
encoded_value : str , annotation : Any
169
- ) -> Union [Dict [Any , Any ], List [Any ], None ]:
174
+ ) -> Union [dict [Any , Any ], list [Any ], None ]:
170
175
"""Decodes string representation to the underlying type(Dict or List)
171
176
172
177
Given a string representation of the value, the method decodes it according
@@ -191,13 +196,13 @@ def decode_from_string(
191
196
if not encoded_value :
192
197
return None
193
198
value_type = annotation
194
- value_origin = typing_inspect . get_origin (value_type )
195
- if value_origin is dict :
196
- return _decode_string_to_dict ( encoded_value , value_type )
197
- elif value_origin is list :
198
- return _decode_string_to_list ( encoded_value , value_type )
199
- else :
200
- raise ValueError ("Unknown" )
199
+ if hasattr (value_type , "__origin__" ):
200
+ value_origin = value_type . __origin__
201
+ if value_origin is dict :
202
+ return _decode_string_to_dict ( encoded_value , value_type )
203
+ elif value_origin is list :
204
+ return _decode_string_to_list ( encoded_value , value_type )
205
+ raise ValueError ("Unknown" )
201
206
202
207
203
208
def is_bool (param_type : Any ) -> bool :
@@ -229,12 +234,13 @@ def decode_optional(param_type: Any) -> Any:
229
234
If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
230
235
Otherwise returns ``param_type``
231
236
"""
232
- param_origin = typing_inspect .get_origin (param_type )
233
- if param_origin is not Union :
237
+ if not hasattr (param_type , "__origin__" ):
238
+ return param_type
239
+ if param_type .__origin__ is not Union :
234
240
return param_type
235
- key_type , value_type = typing_inspect . get_args ( param_type )
236
- if value_type is type (None ):
237
- return key_type
241
+ args = param_type . __args__
242
+ if len ( args ) == 2 and args [ 1 ] is type (None ):
243
+ return args [ 0 ]
238
244
else :
239
245
return param_type
240
246
0 commit comments