Skip to content

Commit 9e66bb8

Browse files
authored
Introduce chunk_rows to settings (#1270)
1 parent ced5ea6 commit 9e66bb8

9 files changed

Lines changed: 214 additions & 59 deletions

File tree

src/datachain/lib/dc/datachain.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,29 +324,32 @@ def settings(
324324
sys: Optional[bool] = None,
325325
namespace: Optional[str] = None,
326326
project: Optional[str] = None,
327+
batch_rows: Optional[int] = None,
327328
) -> "Self":
328329
"""Change settings for chain.
329330
330331
This function changes specified settings without changing not specified ones.
331332
It returns chain, so, it can be chained later with next operation.
332333
333334
Parameters:
334-
cache : data caching (default=False)
335+
cache : data caching. (default=False)
335336
parallel : number of threads for processors. True is a special value to
336-
enable all available CPUs (default=1)
337+
enable all available CPUs. (default=1)
337338
workers : number of distributed workers. Only for Studio mode. (default=1)
338-
min_task_size : minimum number of tasks (default=1)
339-
prefetch: number of workers to use for downloading files in advance.
339+
min_task_size : minimum number of tasks. (default=1)
340+
prefetch : number of workers to use for downloading files in advance.
340341
This is enabled by default and uses 2 workers.
341342
To disable prefetching, set it to 0.
342-
namespace: namespace name.
343-
project: project name.
343+
namespace : namespace name.
344+
project : project name.
345+
batch_rows : row limit per insert to balance speed and memory usage.
346+
(default=2000)
344347
345348
Example:
346349
```py
347350
chain = (
348351
chain
349-
.settings(cache=True, parallel=8)
352+
.settings(cache=True, parallel=8, batch_rows=300)
350353
.map(laion=process_webdataset(spec=WDSLaion), params="file")
351354
)
352355
```
@@ -356,7 +359,14 @@ def settings(
356359
settings = copy.copy(self._settings)
357360
settings.add(
358361
Settings(
359-
cache, parallel, workers, min_task_size, prefetch, namespace, project
362+
cache,
363+
parallel,
364+
workers,
365+
min_task_size,
366+
prefetch,
367+
namespace,
368+
project,
369+
batch_rows,
360370
)
361371
)
362372
return self._evolve(settings=settings, _sys=sys)
@@ -711,7 +721,7 @@ def map(
711721

712722
return self._evolve(
713723
query=self._query.add_signals(
714-
udf_obj.to_udf_wrapper(),
724+
udf_obj.to_udf_wrapper(self._settings.batch_rows),
715725
**self._settings.to_dict(),
716726
),
717727
signal_schema=self.signals_schema | udf_obj.output,
@@ -749,7 +759,7 @@ def gen(
749759
udf_obj.prefetch = prefetch
750760
return self._evolve(
751761
query=self._query.generate(
752-
udf_obj.to_udf_wrapper(),
762+
udf_obj.to_udf_wrapper(self._settings.batch_rows),
753763
**self._settings.to_dict(),
754764
),
755765
signal_schema=udf_obj.output,
@@ -885,7 +895,7 @@ def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
885895
udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
886896
return self._evolve(
887897
query=self._query.generate(
888-
udf_obj.to_udf_wrapper(),
898+
udf_obj.to_udf_wrapper(self._settings.batch_rows),
889899
partition_by=processed_partition_by,
890900
**self._settings.to_dict(),
891901
),
@@ -919,9 +929,10 @@ def batch_map(
919929
```
920930
"""
921931
udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
932+
922933
return self._evolve(
923934
query=self._query.add_signals(
924-
udf_obj.to_udf_wrapper(batch),
935+
udf_obj.to_udf_wrapper(self._settings.batch_rows, batch=batch),
925936
**self._settings.to_dict(),
926937
),
927938
signal_schema=self.signals_schema | udf_obj.output,

src/datachain/lib/dc/records.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
P = ParamSpec("P")
1717

18+
READ_RECORDS_BATCH_SIZE = 10000
19+
1820

1921
def read_records(
2022
to_insert: Optional[Union[dict, Iterable[dict]]],
@@ -41,7 +43,7 @@ def read_records(
4143
Notes:
4244
This call blocks until all records are inserted.
4345
"""
44-
from datachain.query.dataset import INSERT_BATCH_SIZE, adjust_outputs, get_col_types
46+
from datachain.query.dataset import adjust_outputs, get_col_types
4547
from datachain.sql.types import SQLType
4648
from datachain.utils import batched
4749

@@ -94,7 +96,7 @@ def read_records(
9496
{c.name: c.type for c in columns if isinstance(c.type, SQLType)},
9597
)
9698
records = (adjust_outputs(warehouse, record, col_types) for record in to_insert)
97-
for chunk in batched(records, INSERT_BATCH_SIZE):
99+
for chunk in batched(records, READ_RECORDS_BATCH_SIZE):
98100
warehouse.insert_rows(table, chunk)
99101
warehouse.insert_rows_done(table)
100102
return read_dataset(name=dsr.full_name, session=session, settings=settings)

src/datachain/lib/settings.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datachain.lib.utils import DataChainParamsError
2+
from datachain.utils import DEFAULT_CHUNK_ROWS
23

34

45
class SettingsError(DataChainParamsError):
@@ -16,6 +17,7 @@ def __init__(
1617
prefetch=None,
1718
namespace=None,
1819
project=None,
20+
batch_rows=None,
1921
):
2022
self._cache = cache
2123
self.parallel = parallel
@@ -24,6 +26,7 @@ def __init__(
2426
self.prefetch = prefetch
2527
self.namespace = namespace
2628
self.project = project
29+
self._chunk_rows = batch_rows
2730

2831
if not isinstance(cache, bool) and cache is not None:
2932
raise SettingsError(
@@ -53,6 +56,18 @@ def __init__(
5356
f", {min_task_size.__class__.__name__} was given"
5457
)
5558

59+
if batch_rows is not None and not isinstance(batch_rows, int):
60+
raise SettingsError(
61+
"'batch_rows' argument must be int or None"
62+
f", {batch_rows.__class__.__name__} was given"
63+
)
64+
65+
if batch_rows is not None and batch_rows <= 0:
66+
raise SettingsError(
67+
"'batch_rows' argument must be positive integer"
68+
f", {batch_rows} was given"
69+
)
70+
5671
@property
5772
def cache(self):
5873
return self._cache if self._cache is not None else False
@@ -61,6 +76,10 @@ def cache(self):
6176
def workers(self):
6277
return self._workers if self._workers is not None else False
6378

79+
@property
80+
def batch_rows(self):
81+
return self._chunk_rows if self._chunk_rows is not None else DEFAULT_CHUNK_ROWS
82+
6483
def to_dict(self):
6584
res = {}
6685
if self._cache is not None:
@@ -75,6 +94,8 @@ def to_dict(self):
7594
res["namespace"] = self.namespace
7695
if self.project is not None:
7796
res["project"] = self.project
97+
if self._chunk_rows is not None:
98+
res["batch_rows"] = self._chunk_rows
7899
return res
79100

80101
def add(self, settings: "Settings"):
@@ -86,3 +107,5 @@ def add(self, settings: "Settings"):
86107
self.project = settings.project or self.project
87108
if settings.prefetch is not None:
88109
self.prefetch = settings.prefetch
110+
if settings._chunk_rows is not None:
111+
self._chunk_rows = settings._chunk_rows

src/datachain/lib/udf.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,21 @@ def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
6262
return self.udf.get_batching(use_partitioning)
6363

6464
@property
65-
def batch(self):
66-
return self.udf.batch
65+
def batch_rows(self):
66+
return self.udf.batch_rows
6767

6868

6969
@attrs.define(slots=False)
7070
class UDFAdapter:
7171
inner: "UDFBase"
7272
output: UDFOutputSpec
73+
batch_rows: Optional[int] = None
7374
batch: int = 1
7475

7576
def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
7677
if use_partitioning:
7778
return Partition()
79+
7880
if self.batch == 1:
7981
return NoBatching()
8082
if self.batch > 1:
@@ -233,10 +235,15 @@ def verbose_name(self):
233235
def signal_names(self) -> Iterable[str]:
234236
return self.output.to_udf_spec().keys()
235237

236-
def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
238+
def to_udf_wrapper(
239+
self,
240+
batch_rows: Optional[int] = None,
241+
batch: int = 1,
242+
) -> UDFAdapter:
237243
return UDFAdapter(
238244
self,
239245
self.output.to_udf_spec(),
246+
batch_rows,
240247
batch,
241248
)
242249

src/datachain/query/dataset.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -333,32 +333,24 @@ def process_udf_outputs(
333333
udf_table: "Table",
334334
udf_results: Iterator[Iterable["UDFResult"]],
335335
udf: "UDFAdapter",
336-
batch_size: int = INSERT_BATCH_SIZE,
337336
cb: Callback = DEFAULT_CALLBACK,
338337
) -> None:
339-
import psutil
340-
341-
rows: list[UDFResult] = []
342338
# Optimization: Compute row types once, rather than for every row.
343339
udf_col_types = get_col_types(warehouse, udf.output)
340+
batch_rows = udf.batch_rows or INSERT_BATCH_SIZE
344341

345-
for udf_output in udf_results:
346-
if not udf_output:
347-
continue
348-
with safe_closing(udf_output):
349-
for row in udf_output:
350-
cb.relative_update()
351-
rows.append(adjust_outputs(warehouse, row, udf_col_types))
352-
if len(rows) >= batch_size or (
353-
len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
354-
):
355-
for row_chunk in batched(rows, batch_size):
356-
warehouse.insert_rows(udf_table, row_chunk)
357-
rows.clear()
342+
def _insert_rows():
343+
for udf_output in udf_results:
344+
if not udf_output:
345+
continue
346+
347+
with safe_closing(udf_output):
348+
for row in udf_output:
349+
cb.relative_update()
350+
yield adjust_outputs(warehouse, row, udf_col_types)
358351

359-
if rows:
360-
for row_chunk in batched(rows, batch_size):
361-
warehouse.insert_rows(udf_table, row_chunk)
352+
for row_chunk in batched(_insert_rows(), batch_rows):
353+
warehouse.insert_rows(udf_table, row_chunk)
362354

363355
warehouse.insert_rows_done(udf_table)
364356

@@ -401,6 +393,7 @@ class UDFStep(Step, ABC):
401393
min_task_size: Optional[int] = None
402394
is_generator = False
403395
cache: bool = False
396+
batch_rows: Optional[int] = None
404397

405398
@abstractmethod
406399
def create_udf_table(self, query: Select) -> "Table":
@@ -602,6 +595,7 @@ def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
602595
parallel=self.parallel,
603596
workers=self.workers,
604597
min_task_size=self.min_task_size,
598+
batch_rows=self.batch_rows,
605599
)
606600
return self.__class__(self.udf, self.catalog)
607601

@@ -1633,6 +1627,7 @@ def add_signals(
16331627
min_task_size: Optional[int] = None,
16341628
partition_by: Optional[PartitionByType] = None,
16351629
cache: bool = False,
1630+
batch_rows: Optional[int] = None,
16361631
) -> "Self":
16371632
"""
16381633
Adds one or more signals based on the results from the provided UDF.
@@ -1658,6 +1653,7 @@ def add_signals(
16581653
workers=workers,
16591654
min_task_size=min_task_size,
16601655
cache=cache,
1656+
batch_rows=batch_rows,
16611657
)
16621658
)
16631659
return query
@@ -1679,6 +1675,7 @@ def generate(
16791675
namespace: Optional[str] = None,
16801676
project: Optional[str] = None,
16811677
cache: bool = False,
1678+
batch_rows: Optional[int] = None,
16821679
) -> "Self":
16831680
query = self.clone()
16841681
steps = query.steps
@@ -1691,6 +1688,7 @@ def generate(
16911688
workers=workers,
16921689
min_task_size=min_task_size,
16931690
cache=cache,
1691+
batch_rows=batch_rows,
16941692
)
16951693
)
16961694
return query

0 commit comments

Comments
 (0)