@@ -324,29 +324,32 @@ def settings(
324324 sys : Optional [bool ] = None ,
325325 namespace : Optional [str ] = None ,
326326 project : Optional [str ] = None ,
327+ batch_rows : Optional [int ] = None ,
327328 ) -> "Self" :
328329 """Change settings for chain.
329330
330331 This function changes specified settings without changing not specified ones.
331332 It returns chain, so, it can be chained later with next operation.
332333
333334 Parameters:
334- cache : data caching (default=False)
335+ cache : data caching. (default=False)
335336 parallel : number of thread for processors. True is a special value to
336- enable all available CPUs (default=1)
337+ enable all available CPUs. (default=1)
337338 workers : number of distributed workers. Only for Studio mode. (default=1)
338- min_task_size : minimum number of tasks (default=1)
339- prefetch: number of workers to use for downloading files in advance.
339+ min_task_size : minimum number of tasks. (default=1)
340+ prefetch : number of workers to use for downloading files in advance.
340341 This is enabled by default and uses 2 workers.
341342 To disable prefetching, set it to 0.
342- namespace: namespace name.
343- project: project name.
343+ namespace : namespace name.
344+ project : project name.
345+ batch_rows : row limit per insert to balance speed and memory usage.
346+ (default=2000)
344347
345348 Example:
346349 ```py
347350 chain = (
348351 chain
349- .settings(cache=True, parallel=8)
352+ .settings(cache=True, parallel=8, batch_rows=300 )
350353 .map(laion=process_webdataset(spec=WDSLaion), params="file")
351354 )
352355 ```
@@ -356,7 +359,14 @@ def settings(
356359 settings = copy .copy (self ._settings )
357360 settings .add (
358361 Settings (
359- cache , parallel , workers , min_task_size , prefetch , namespace , project
362+ cache ,
363+ parallel ,
364+ workers ,
365+ min_task_size ,
366+ prefetch ,
367+ namespace ,
368+ project ,
369+ batch_rows ,
360370 )
361371 )
362372 return self ._evolve (settings = settings , _sys = sys )
@@ -711,7 +721,7 @@ def map(
711721
712722 return self ._evolve (
713723 query = self ._query .add_signals (
714- udf_obj .to_udf_wrapper (),
724+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
715725 ** self ._settings .to_dict (),
716726 ),
717727 signal_schema = self .signals_schema | udf_obj .output ,
@@ -749,7 +759,7 @@ def gen(
749759 udf_obj .prefetch = prefetch
750760 return self ._evolve (
751761 query = self ._query .generate (
752- udf_obj .to_udf_wrapper (),
762+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
753763 ** self ._settings .to_dict (),
754764 ),
755765 signal_schema = udf_obj .output ,
@@ -885,7 +895,7 @@ def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
885895 udf_obj = self ._udf_to_obj (Aggregator , func , params , output , signal_map )
886896 return self ._evolve (
887897 query = self ._query .generate (
888- udf_obj .to_udf_wrapper (),
898+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
889899 partition_by = processed_partition_by ,
890900 ** self ._settings .to_dict (),
891901 ),
@@ -919,9 +929,10 @@ def batch_map(
919929 ```
920930 """
921931 udf_obj = self ._udf_to_obj (BatchMapper , func , params , output , signal_map )
932+
922933 return self ._evolve (
923934 query = self ._query .add_signals (
924- udf_obj .to_udf_wrapper (batch ),
935+ udf_obj.to_udf_wrapper(self._settings.batch_rows, batch=batch),
925936 ** self ._settings .to_dict (),
926937 ),
927938 signal_schema = self .signals_schema | udf_obj .output ,
0 commit comments