11import contextlib
22import math
3- import sys
43from abc import ABC , abstractmethod
54from collections .abc import Generator , Sequence
65from typing import Callable , Optional , Union
76
87import sqlalchemy as sa
98
109from datachain .data_storage .schema import PARTITION_COLUMN_ID
10+ from datachain .lib .memory_utils import estimate_row_memory , is_memory_usage_high
1111from datachain .query .utils import get_query_column
1212
1313RowsOutputBatch = Sequence [Sequence ]
1414RowsOutput = Union [Sequence , RowsOutputBatch ]
1515
16- OBJECT_OVERHEAD_BYTES = 100
17-
1816
1917class BatchingStrategy (ABC ):
2018 """BatchingStrategy provides means of batching UDF executions."""
@@ -116,35 +114,12 @@ def __init__(
116114 # If we yield individual rows, set is_batching to False
117115 self .is_batching = is_input_batched
118116
119- def _estimate_row_memory (self , row ) -> int :
120- """Estimate memory usage of a row in bytes."""
121- if not row :
122- return 0
123-
124- total_size = 0
125- for item in row :
126- if isinstance (item , (str , bytes , int , float , bool )):
127- total_size += sys .getsizeof (item )
128- elif isinstance (item , (list , tuple )):
129- total_size += sys .getsizeof (item )
130- for subitem in item :
131- total_size += sys .getsizeof (subitem )
132- else :
133- # For complex objects, use a conservative estimate
134- total_size += (
135- sys .getsizeof (item ) + OBJECT_OVERHEAD_BYTES
136- ) # Add buffer for object overhead
137-
138- return total_size
139-
140117 def __call__ (
141118 self ,
142119 execute : Callable ,
143120 query : sa .Select ,
144121 id_col : Optional [sa .ColumnElement ] = None ,
145122 ) -> Generator [RowsOutput , None , None ]:
146- import psutil
147-
148123 from datachain .data_storage .warehouse import SELECT_BATCH_SIZE
149124
150125 ids_only = False
@@ -162,7 +137,7 @@ def __call__(
162137
163138 with contextlib .closing (execute (query , page_size = page_size )) as chunk_rows :
164139 for row in chunk_rows :
165- row_memory = self . _estimate_row_memory (row )
140+ row_memory = estimate_row_memory (row )
166141 row_count += 1
167142
168143 # Check if adding this row would exceed limits
@@ -171,7 +146,7 @@ def __call__(
171146 should_yield = (
172147 len (results ) >= self .max_rows
173148 or current_memory + row_memory > self .max_memory_bytes
174- or (row_count % 100 == 0 and psutil . virtual_memory (). percent > 80 )
149+ or (row_count % 100 == 0 and is_memory_usage_high () )
175150 )
176151
177152 if should_yield and results : # Yield current batch if we have one
0 commit comments