Skip to content

Commit f4febed

Browse files
committed
Address comments
1 parent bcba5f3 commit f4febed

File tree

8 files changed

+179
-65
lines changed

8 files changed

+179
-65
lines changed

src/datachain/lib/memory_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Memory estimation utilities for DataChain."""
2+
3+
import sys
4+
from typing import Any, Union
5+
6+
# Default batch sizing: up to 2000 rows or 1000 MB per processed chunk.
DEFAULT_CHUNK_ROWS = 2000
DEFAULT_CHUNK_MB = 1000

# Memory monitoring threshold, as a percentage of total system memory;
# above this, batching yields early (see is_memory_usage_high).
MEMORY_USAGE_THRESHOLD = 80

# System memory check frequency: consult system memory every N rows.
MEMORY_CHECK_FREQUENCY = 100

# Conservative per-object overhead (bytes) added when estimating complex
# objects, since sys.getsizeof does not follow contained references.
OBJECT_OVERHEAD_BYTES = 100
18+
19+
20+
def estimate_memory_recursive(item: Any, _seen: Union[set, None] = None) -> int:
    """Estimate the memory footprint of ``item`` in bytes.

    ``None`` is treated as free. Primitive scalars are measured with
    ``sys.getsizeof``. Lists and tuples are measured recursively — the
    container's own size plus the estimate of every element — matching
    the function's name (the previous implementation only did a shallow
    ``sys.getsizeof`` per element, undercounting nested containers).
    Any other object gets its shallow size plus ``OBJECT_OVERHEAD_BYTES``
    as a conservative buffer.

    Args:
        item: Value whose memory usage should be estimated.
        _seen: Internal set of container ids already visited; guards
            against infinite recursion on self-referential containers.

    Returns:
        Estimated size in bytes (0 for ``None`` or an already-seen
        container).
    """
    if item is None:
        return 0

    if isinstance(item, (str, bytes, int, float, bool)):
        return sys.getsizeof(item)

    if isinstance(item, (list, tuple)):
        if _seen is None:
            _seen = set()
        if id(item) in _seen:
            # Cycle (e.g. lst.append(lst)): already counted, stop here.
            return 0
        _seen.add(id(item))
        return sys.getsizeof(item) + sum(
            estimate_memory_recursive(subitem, _seen) for subitem in item
        )

    # For complex objects, use a conservative shallow estimate plus a
    # fixed overhead buffer (sys.getsizeof does not follow references).
    return sys.getsizeof(item) + OBJECT_OVERHEAD_BYTES
33+
34+
35+
def estimate_row_memory(row: Union[list, tuple]) -> int:
    """Return the estimated memory footprint of ``row`` in bytes.

    An empty (or falsy) row costs nothing; otherwise the per-item
    estimates from :func:`estimate_memory_recursive` are summed.
    """
    if not row:
        return 0

    return sum(estimate_memory_recursive(value) for value in row)
44+
45+
46+
def get_system_memory_percent() -> float:
    """Return system-wide memory usage as a percentage.

    Falls back to ``0.0`` (after emitting a ``UserWarning``) when
    ``psutil`` is not installed, so memory-based checks degrade to
    no-ops instead of crashing.
    """
    try:
        import psutil
    except ImportError:
        import warnings

        warnings.warn(
            "psutil not available. Memory-based checks will be skipped. "
            "Install psutil to enable memory monitoring.",
            UserWarning,
            stacklevel=2,
        )
        return 0.0

    return psutil.virtual_memory().percent
61+
62+
63+
def is_memory_usage_high() -> bool:
    """Report whether system memory usage exceeds ``MEMORY_USAGE_THRESHOLD``."""
    usage_percent = get_system_memory_percent()
    return usage_percent > MEMORY_USAGE_THRESHOLD

src/datachain/lib/settings.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datachain.lib.memory_utils import DEFAULT_CHUNK_MB, DEFAULT_CHUNK_ROWS
12
from datachain.lib.utils import DataChainParamsError
23

34

@@ -63,12 +64,23 @@ def __init__(
6364
f", {chunk_rows.__class__.__name__} was given"
6465
)
6566

67+
if chunk_rows is not None and chunk_rows <= 0:
68+
raise SettingsError(
69+
"'chunk_rows' argument must be positive integer"
70+
f", {chunk_rows} was given"
71+
)
72+
6673
if chunk_mb is not None and not isinstance(chunk_mb, (int, float)):
6774
raise SettingsError(
6875
"'chunk_mb' argument must be int/float or None"
6976
f", {chunk_mb.__class__.__name__} was given"
7077
)
7178

79+
if chunk_mb is not None and chunk_mb <= 0:
80+
raise SettingsError(
81+
f"'chunk_mb' argument must be positive number, {chunk_mb} was given"
82+
)
83+
7284
@property
7385
def cache(self):
7486
return self._cache if self._cache is not None else False
@@ -79,11 +91,11 @@ def workers(self):
7991

8092
@property
8193
def chunk_rows(self):
82-
return self._chunk_rows if self._chunk_rows is not None else 2000
94+
return self._chunk_rows if self._chunk_rows is not None else DEFAULT_CHUNK_ROWS
8395

8496
@property
8597
def chunk_mb(self):
86-
return self._chunk_mb if self._chunk_mb is not None else 1000
98+
return self._chunk_mb if self._chunk_mb is not None else DEFAULT_CHUNK_MB
8799

88100
def to_dict(self):
89101
res = {}
@@ -115,6 +127,6 @@ def add(self, settings: "Settings"):
115127
if settings.prefetch is not None:
116128
self.prefetch = settings.prefetch
117129
if settings._chunk_rows is not None:
118-
self._chunk_rows = settings.chunk_rows
130+
self._chunk_rows = settings._chunk_rows
119131
if settings._chunk_mb is not None:
120132
self._chunk_mb = settings._chunk_mb

src/datachain/lib/udf.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from datachain.dataset import RowDict
1515
from datachain.lib.convert.flatten import flatten
1616
from datachain.lib.file import DataModel, File
17+
from datachain.lib.memory_utils import DEFAULT_CHUNK_MB, DEFAULT_CHUNK_ROWS
1718
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
1819
from datachain.query.batch import (
1920
BatchingStrategy,
@@ -89,14 +90,21 @@ def get_batching(
8990

9091
# If we have explicit chunk_rows/chunk_mb set on this adapter, use them
9192
if self.chunk_rows is not None or self.chunk_mb is not None:
92-
return DynamicBatch(self.chunk_rows, self.chunk_mb, is_input_batched)
93+
return DynamicBatch(
94+
self.chunk_rows if self.chunk_rows is not None else DEFAULT_CHUNK_ROWS,
95+
self.chunk_mb if self.chunk_mb is not None else DEFAULT_CHUNK_MB,
96+
is_input_batched,
97+
)
9398

9499
# If settings are provided and have batch configuration, use appropriate
95100
# batching
96101
if settings:
97102
max_rows: Optional[int] = getattr(settings, "_chunk_rows", None)
98103
max_mem: Optional[Union[int, float]] = getattr(settings, "_chunk_mb", None)
99104
if max_rows is not None or max_mem is not None:
105+
# Use settings values, falling back to defaults if None
106+
max_rows = max_rows if max_rows is not None else DEFAULT_CHUNK_ROWS
107+
max_mem = max_mem if max_mem is not None else DEFAULT_CHUNK_MB
100108
return DynamicBatch(max_rows, max_mem, is_input_batched)
101109

102110
return NoBatching()

src/datachain/query/batch.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import contextlib
22
import math
3-
import sys
43
from abc import ABC, abstractmethod
54
from collections.abc import Generator, Sequence
65
from typing import Callable, Optional, Union
76

87
import sqlalchemy as sa
98

109
from datachain.data_storage.schema import PARTITION_COLUMN_ID
10+
from datachain.lib.memory_utils import estimate_row_memory, is_memory_usage_high
1111
from datachain.query.utils import get_query_column
1212

1313
RowsOutputBatch = Sequence[Sequence]
1414
RowsOutput = Union[Sequence, RowsOutputBatch]
1515

16-
OBJECT_OVERHEAD_BYTES = 100
17-
1816

1917
class BatchingStrategy(ABC):
2018
"""BatchingStrategy provides means of batching UDF executions."""
@@ -116,35 +114,12 @@ def __init__(
116114
# If we yield individual rows, set is_batching to False
117115
self.is_batching = is_input_batched
118116

119-
def _estimate_row_memory(self, row) -> int:
120-
"""Estimate memory usage of a row in bytes."""
121-
if not row:
122-
return 0
123-
124-
total_size = 0
125-
for item in row:
126-
if isinstance(item, (str, bytes, int, float, bool)):
127-
total_size += sys.getsizeof(item)
128-
elif isinstance(item, (list, tuple)):
129-
total_size += sys.getsizeof(item)
130-
for subitem in item:
131-
total_size += sys.getsizeof(subitem)
132-
else:
133-
# For complex objects, use a conservative estimate
134-
total_size += (
135-
sys.getsizeof(item) + OBJECT_OVERHEAD_BYTES
136-
) # Add buffer for object overhead
137-
138-
return total_size
139-
140117
def __call__(
141118
self,
142119
execute: Callable,
143120
query: sa.Select,
144121
id_col: Optional[sa.ColumnElement] = None,
145122
) -> Generator[RowsOutput, None, None]:
146-
import psutil
147-
148123
from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
149124

150125
ids_only = False
@@ -162,7 +137,7 @@ def __call__(
162137

163138
with contextlib.closing(execute(query, page_size=page_size)) as chunk_rows:
164139
for row in chunk_rows:
165-
row_memory = self._estimate_row_memory(row)
140+
row_memory = estimate_row_memory(row)
166141
row_count += 1
167142

168143
# Check if adding this row would exceed limits
@@ -171,7 +146,7 @@ def __call__(
171146
should_yield = (
172147
len(results) >= self.max_rows
173148
or current_memory + row_memory > self.max_memory_bytes
174-
or (row_count % 100 == 0 and psutil.virtual_memory().percent > 80)
149+
or (row_count % 100 == 0 and is_memory_usage_high())
175150
)
176151

177152
if should_yield and results: # Yield current batch if we have one

src/datachain/utils.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@
3131
except ImportError:
3232
psutil = None
3333

34-
# Constants for memory estimation
35-
OBJECT_OVERHEAD_BYTES = 100
36-
34+
# Import shared memory utilities
35+
from datachain.lib.memory_utils import (
36+
DEFAULT_CHUNK_MB,
37+
DEFAULT_CHUNK_ROWS,
38+
estimate_memory_recursive,
39+
is_memory_usage_high,
40+
)
3741

3842
logger = logging.getLogger("datachain")
3943

@@ -246,16 +250,14 @@ def _dynamic_batched_core(
246250
current_memory = 0
247251

248252
for row_count, item in enumerate(iterable):
249-
item_memory = _estimate_item_memory(item)
253+
item_memory = estimate_memory_recursive(item)
250254

251255
# Check if adding this item would exceed limits
252256
# Also check system memory usage every 100 items
253257
should_yield = (
254258
len(batch) >= chunk_rows
255259
or current_memory + item_memory > max_memory_bytes
256-
or (
257-
row_count % 100 == 0 and psutil and psutil.virtual_memory().percent > 80
258-
)
260+
or (row_count % 100 == 0 and is_memory_usage_high())
259261
)
260262

261263
if should_yield and batch: # Yield current batch if we have one
@@ -284,7 +286,9 @@ def batched(
284286

285287

286288
def batched_it(
287-
iterable: Iterable[_T_co], chunk_rows: int = 2000, chunk_mb: float = 1000
289+
iterable: Iterable[_T_co],
290+
chunk_rows: int = DEFAULT_CHUNK_ROWS,
291+
chunk_mb: float = DEFAULT_CHUNK_MB,
288292
) -> Iterator[Iterator[_T_co]]:
289293
"""
290294
Batch data into iterators with dynamic sizing
@@ -295,25 +299,6 @@ def batched_it(
295299
)
296300

297301

298-
def _estimate_item_memory(item) -> int:
299-
"""Estimate memory usage of an item in bytes."""
300-
if item is None:
301-
return 0
302-
303-
total_size = 0
304-
if isinstance(item, (str, bytes, int, float, bool)):
305-
total_size += sys.getsizeof(item)
306-
elif isinstance(item, (list, tuple)):
307-
total_size += sys.getsizeof(item)
308-
for subitem in item:
309-
total_size += sys.getsizeof(subitem)
310-
else:
311-
# For complex objects, use a conservative estimate
312-
total_size += sys.getsizeof(item) + OBJECT_OVERHEAD_BYTES
313-
314-
return total_size
315-
316-
317302
def flatten(items):
318303
for item in items:
319304
if isinstance(item, (list, tuple)):

tests/func/test_datachain.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,9 +2470,10 @@ def add_one_with_batch_size(x):
24702470
chain = dc.read_values(x=list(range(100)), session=test_session)
24712471
chain_with_settings = chain.settings(chunk_rows=50, chunk_mb=1000)
24722472

2473-
result = chain_with_settings.batch_map(
2474-
add_one_with_batch_size, output={"result": Result}, batch=15
2475-
)
2473+
with pytest.warns(DeprecationWarning):
2474+
result = chain_with_settings.batch_map(
2475+
add_one_with_batch_size, output={"result": Result}, batch=15
2476+
)
24762477

24772478
results = [r[0] for r in result.to_iter("result")]
24782479

@@ -2485,8 +2486,8 @@ def add_one_with_batch_size(x):
24852486

24862487
assert len(results) == 100
24872488

2488-
expected_values = list(range(1, 101))
2489-
actual_values = [r.result for r in results]
2489+
expected_values = set(range(1, 101))
2490+
actual_values = {r.result for r in results}
24902491
assert actual_values == expected_values
24912492

24922493

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Tests for memory utility functions."""
2+
3+
from datachain.lib.memory_utils import (
4+
OBJECT_OVERHEAD_BYTES,
5+
estimate_memory_recursive,
6+
estimate_row_memory,
7+
get_system_memory_percent,
8+
)
9+
10+
11+
def test_estimate_memory_recursive():
    """Memory estimates: None costs nothing, real values cost something."""
    assert estimate_memory_recursive(None) == 0
    for sample in (42, "test", [1, 2, 3]):
        assert estimate_memory_recursive(sample) > 0
17+
18+
19+
def test_estimate_row_memory():
    """Row estimation: empty rows are free, populated rows are not."""
    empty_estimate = estimate_row_memory([])
    assert empty_estimate == 0

    populated_estimate = estimate_row_memory([1, "test", 3.14])
    assert populated_estimate > 0
23+
24+
25+
def test_system_memory_functions():
    """System memory usage is reported as a percentage in [0, 100]."""
    usage = get_system_memory_percent()
    assert isinstance(usage, (int, float))
    assert 0.0 <= usage <= 100.0
30+
31+
32+
def test_object_overhead_constant():
    """The shared overhead constant is a positive integer."""
    overhead = OBJECT_OVERHEAD_BYTES
    assert isinstance(overhead, int)
    assert overhead > 0

0 commit comments

Comments
 (0)