Skip to content

Commit 29d024c

Browse files
Fix async coroutine limit not respected and add s3/gcs chunk size (#3080) (#3083)
Signed-off-by: Yee Hing Tong <[email protected]>
1 parent 169843a commit 29d024c

File tree

4 files changed

+161
-18
lines changed

4 files changed

+161
-18
lines changed

flytekit/core/data_persistence.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252

5353
Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO]
5454

55+
# Default chunk size, in bytes, that flytekit uses when writing to S3 and GCS. It defaults to 25MB and can be
56+
# overridden with the _F_P_WRITE_CHUNK_SIZE environment variable. It is applied when put() is called on filesystems.
57+
_WRITE_SIZE_CHUNK_BYTES = int(os.environ.get("_F_P_WRITE_CHUNK_SIZE", "26214400")) # 25 * 2**20
58+
5559

5660
def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
5761
kwargs: Dict[str, Any] = {
@@ -108,6 +112,27 @@ def get_fsspec_storage_options(
108112
return {}
109113

110114

115+
def get_additional_fsspec_call_kwargs(protocol: typing.Union[str, tuple], method_name: str) -> Dict[str, Any]:
    """
    Return extra keyword arguments to pass when a method is invoked on an fsspec filesystem.

    These are different from the setup args functions defined in this module. Those kwargs are applied when asking
    fsspec to create the filesystem. The kwargs returned here are for when the filesystem's methods are invoked.

    :param protocol: s3, gcs, etc. Either a single protocol string or the tuple of aliases a filesystem advertises.
    :param method_name: Pass in the __name__ of the fsspec.filesystem function. _'s will be ignored.
    :return: Keyword arguments to merge into the method call; empty when nothing applies.
    """
    kwargs: Dict[str, Any] = {}
    method_name = method_name.replace("_", "")

    # A filesystem may advertise several aliases (e.g. gcsfs registers both "gs" and "gcs") and the order is up
    # to the implementation, so look at every alias instead of only the first one. This also keeps an empty
    # tuple from raising IndexError.
    if isinstance(protocol, (tuple, list)):
        protocols = tuple(protocol)
    else:
        protocols = (protocol,)

    # For s3fs and gcsfs, we feel the default chunksize of 50MB is too big.
    # Re-evaluate these kwargs when we move off of s3fs to obstore.
    if method_name == "put" and any(p in ("s3", "gs", "gcs") for p in protocols):
        kwargs["chunksize"] = _WRITE_SIZE_CHUNK_BYTES

    return kwargs
134+
135+
111136
@decorator
112137
def retry_request(func, *args, **kwargs):
113138
# TODO: Remove this method once s3fs has a new release. https://github.com/fsspec/s3fs/pull/865
@@ -353,6 +378,10 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw
353378
if "metadata" not in kwargs:
354379
kwargs["metadata"] = {}
355380
kwargs["metadata"].update(self._execution_metadata)
381+
382+
additional_kwargs = get_additional_fsspec_call_kwargs(file_system.protocol, file_system.put.__name__)
383+
kwargs.update(additional_kwargs)
384+
356385
if isinstance(file_system, AsyncFileSystem):
357386
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
358387
else:

flytekit/core/type_engine.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
DEFINITIONS = "definitions"
5858
TITLE = "title"
5959

60+
_TYPE_ENGINE_COROS_BATCH_SIZE = int(os.environ.get("_F_TE_MAX_COROS", "10"))
61+
6062

6163
# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
6264
# This is relevant for cases like Dict[int, str].
@@ -1678,10 +1680,9 @@ async def async_to_literal(
16781680
raise TypeTransformerFailedError("Expected a list")
16791681

16801682
t = self.get_sub_type(python_type)
1681-
lit_list = [
1682-
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val
1683-
]
1684-
lit_list = await _run_coros_in_chunks(lit_list)
1683+
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]
1684+
1685+
lit_list = await _run_coros_in_chunks(lit_list, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
16851686

16861687
return Literal(collection=LiteralCollection(literals=lit_list))
16871688

@@ -1703,7 +1704,7 @@ async def async_to_python_value( # type: ignore
17031704

17041705
st = self.get_sub_type(expected_python_type)
17051706
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
1706-
result = await _run_coros_in_chunks(result)
1707+
result = await _run_coros_in_chunks(result, batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
17071708
return result # type: ignore # should be a list, thinks its a tuple
17081709

17091710
def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
@@ -2150,13 +2151,10 @@ async def async_to_literal(
21502151
else:
21512152
_, v_type = self.extract_types_or_metadata(python_type)
21522153

2153-
lit_map[k] = asyncio.create_task(
2154-
TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
2155-
)
2156-
2157-
await _run_coros_in_chunks([c for c in lit_map.values()])
2158-
for k, v in lit_map.items():
2159-
lit_map[k] = v.result()
2154+
lit_map[k] = TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
2155+
vals = await _run_coros_in_chunks([c for c in lit_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
2156+
for idx, k in zip(range(len(vals)), lit_map.keys()):
2157+
lit_map[k] = vals[idx]
21602158

21612159
return Literal(map=LiteralMap(literals=lit_map))
21622160

@@ -2177,12 +2175,11 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
21772175
raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key")
21782176
py_map = {}
21792177
for k, v in lv.map.literals.items():
2180-
fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1])))
2181-
py_map[k] = fut
2178+
py_map[k] = TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1]))
21822179

2183-
await _run_coros_in_chunks([c for c in py_map.values()])
2184-
for k, v in py_map.items():
2185-
py_map[k] = v.result()
2180+
vals = await _run_coros_in_chunks([c for c in py_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
2181+
for idx, k in zip(range(len(vals)), py_map.keys()):
2182+
py_map[k] = vals[idx]
21862183

21872184
return py_map
21882185

tests/flytekit/unit/core/test_data_persistence.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import mock
1111
import pytest
1212
from azure.identity import ClientSecretCredential, DefaultAzureCredential
13+
from mock import AsyncMock
1314

1415
from flytekit.configuration import Config
15-
from flytekit.core.data_persistence import FileAccessProvider
16+
from flytekit.core.data_persistence import FileAccessProvider, get_additional_fsspec_call_kwargs
1617
from flytekit.core.local_fsspec import FlyteLocalFileSystem
1718

1819

@@ -210,6 +211,37 @@ def __init__(self, *args, **kwargs):
210211
fp.get_filesystem("testgetfs", test_arg="test_arg")
211212

212213

214+
def test_get_additional_fsspec_call_kwargs():
    """Check which (protocol, method) pairs receive the patched chunk size and which receive nothing."""
    patched_size = 12345
    cases = [
        # protocol given as an alias tuple, plain method name
        (("s3", "s3a"), "put", {"chunksize": patched_size}),
        # protocol given as a string, underscore-prefixed method name
        ("s3", "_put", {"chunksize": patched_size}),
        # methods other than put get no extra kwargs
        ("s3", "get", {}),
    ]

    with mock.patch("flytekit.core.data_persistence._WRITE_SIZE_CHUNK_BYTES", patched_size):
        for protocol, method_name, expected in cases:
            assert get_additional_fsspec_call_kwargs(protocol, method_name) == expected
224+
225+
226+
@pytest.mark.asyncio
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_async_filesystem_for_path", new_callable=AsyncMock)
@mock.patch("flytekit.core.data_persistence.get_additional_fsspec_call_kwargs")
async def test_chunk_size(mock_call_kwargs, mock_get_fs):
    """Verify that the kwargs from get_additional_fsspec_call_kwargs reach the filesystem's put() call."""
    mock_call_kwargs.return_value = {"chunksize": 1234}
    mock_fs = mock.MagicMock()
    mock_get_fs.return_value = mock_fs

    mock_fs.protocol = ("s3", "s3a")
    fp = FileAccessProvider("/tmp", "s3://container/path/within/container")

    # Keyword arguments of every put() invocation, inspected after the upload finishes.
    put_calls = []

    # This has to be a real function named "put": _put reads file_system.put.__name__.
    def put(*args, **kwargs):
        put_calls.append(kwargs)

    mock_fs.put = put
    upload_location = await fp._put("/tmp/foo", "s3://bar")
    assert upload_location == "s3://bar"

    # Assert here rather than inside the stub, so the test cannot pass vacuously when put() is never
    # invoked or when an AssertionError raised inside the stub gets swallowed along the way.
    assert put_calls, "expected the filesystem's put() to be called"
    assert all(call.get("chunksize") == 1234 for call in put_calls)
243+
244+
213245
@pytest.mark.sandbox_test
214246
def test_put_raw_data_bytes():
215247
dc = Config.for_sandbox().data_config
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import asyncio
2+
import typing
3+
4+
import mock
5+
import pytest
6+
7+
from flytekit.core.context_manager import FlyteContext
8+
from flytekit.core.type_engine import (
9+
AsyncTypeTransformer,
10+
TypeEngine,
11+
)
12+
from flytekit.models.literals import (
13+
Literal,
14+
Primitive,
15+
Scalar,
16+
)
17+
from flytekit.models.types import LiteralType, SimpleType
18+
19+
20+
class MyInt:
    """Minimal wrapper around an int, used as a custom type for the transformer under test."""

    def __init__(self, x: int):
        self.val = x

    def __eq__(self, other):
        # Only another MyInt holding the same value compares equal.
        return isinstance(other, MyInt) and other.val == self.val
28+
29+
30+
class MyIntAsyncTransformer(AsyncTypeTransformer[MyInt]):
    """
    Async transformer for MyInt that counts concurrent to-literal conversions.

    async_to_literal raises ValueError as soon as more than two conversions are in flight at once.
    """

    def __init__(self):
        super().__init__(name="MyAsyncInt", t=MyInt)
        # my_count is the number of conversions currently in flight; my_lock guards updates to it.
        self.my_lock = asyncio.Lock()
        self.my_count = 0

    def assert_type(self, t, v):
        # Accept everything; no type validation is performed here.
        return None

    def get_literal_type(self, t: typing.Type[MyInt]) -> LiteralType:
        return LiteralType(simple=SimpleType.INTEGER)

    async def async_to_literal(
        self,
        ctx: FlyteContext,
        python_val: MyInt,
        python_type: typing.Type[MyInt],
        expected: LiteralType,
    ) -> Literal:
        # Mark this conversion as started and fail if too many are running concurrently.
        async with self.my_lock:
            self.my_count += 1
            if self.my_count > 2:
                raise ValueError("coroutine count exceeded")

        # Yield to the event loop so that other scheduled conversions get a chance to start.
        await asyncio.sleep(0.1)
        primitive = Primitive(integer=python_val.val)
        lit = Literal(scalar=Scalar(primitive=primitive))

        # Mark this conversion as finished.
        async with self.my_lock:
            self.my_count -= 1

        return lit

    async def async_to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[MyInt]
    ) -> MyInt:
        raw_value = lv.scalar.primitive.integer
        return MyInt(raw_value)

    def guess_python_type(self, literal_type: LiteralType) -> typing.Type[MyInt]:
        return MyInt
68+
69+
70+
@pytest.mark.asyncio
async def test_coroutine_batching_of_list_transformer():
    """
    Check that list conversion honours the patched _TYPE_ENGINE_COROS_BATCH_SIZE.

    MyIntAsyncTransformer raises ValueError once more than two conversions are in flight. Converting five
    values is expected to succeed with a batch size of 2 and to raise with a batch size of 5.
    """
    TypeEngine.register(MyIntAsyncTransformer())

    try:
        lt = LiteralType(simple=SimpleType.INTEGER)
        python_val = [MyInt(10), MyInt(11), MyInt(12), MyInt(13), MyInt(14)]
        ctx = FlyteContext.current_context()

        with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 2):
            TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

        with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
            with pytest.raises(ValueError):
                TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)
    finally:
        # Always unregister the transformer, even when an assertion above fails, so that the MyInt
        # registration does not leak into other tests sharing the TypeEngine registry.
        del TypeEngine._REGISTRY[MyInt]

0 commit comments

Comments
 (0)