Skip to content

Commit ced5ea6

Browse files
authored
Fix double mapper call when using to_storage method (#1258)
1 parent 88e4bd3 commit ced5ea6

2 files changed

Lines changed: 26 additions & 10 deletions

File tree

src/datachain/lib/dc/datachain.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,9 +2419,11 @@ def to_storage(
24192419
ds.to_storage("gs://mybucket", placement="filename")
24202420
```
24212421
"""
2422+
chain = self.persist()
2423+
count = chain.count()
2424+
24222425
if placement == "filename" and (
2423-
self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
2424-
!= self._query.count()
2426+
chain._query.distinct(pathfunc.name(C(f"{signal}__path"))).count() != count
24252427
):
24262428
raise ValueError("Files with the same name found")
24272429

@@ -2433,7 +2435,7 @@ def to_storage(
24332435
unit=" files",
24342436
unit_scale=True,
24352437
unit_divisor=10,
2436-
total=self.count(),
2438+
total=count,
24372439
leave=False,
24382440
)
24392441
file_exporter = FileExporter(
@@ -2444,7 +2446,10 @@ def to_storage(
24442446
max_threads=num_threads or 1,
24452447
client_config=client_config,
24462448
)
2447-
file_exporter.run(self.to_values(signal), progress_bar)
2449+
file_exporter.run(
2450+
(rows[0] for rows in chain.to_iter(signal)),
2451+
progress_bar,
2452+
)
24482453

24492454
def shuffle(self) -> "Self":
24502455
"""Shuffle the rows of the chain deterministically."""

tests/func/test_datachain.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections.abc import Iterator
1010
from datetime import datetime, timedelta, timezone
1111
from pathlib import Path, PurePosixPath
12-
from unittest.mock import patch
12+
from unittest.mock import Mock, patch
1313

1414
import numpy as np
1515
import pandas as pd
@@ -358,15 +358,24 @@ def test_to_storage(
358358
file_type,
359359
num_threads,
360360
):
361+
mapper = Mock(side_effect=lambda file_path: len(file_path))
362+
361363
ctc = cloud_test_catalog
362364
df = dc.read_storage(ctc.src_uri, type=file_type, session=test_session)
363365
if use_map:
364-
df.settings(cache=use_cache).map(
365-
res=lambda file: file.export(tmp_dir / "output", placement=placement)
366-
).exec()
366+
(
367+
df.settings(cache=use_cache)
368+
.map(mapper, params=["file.path"], output={"path_len": int})
369+
.map(res=lambda file: file.export(tmp_dir / "output", placement=placement))
370+
.exec()
371+
)
367372
else:
368-
df.settings(cache=use_cache).to_storage(
369-
tmp_dir / "output", placement=placement, num_threads=num_threads
373+
(
374+
df.settings(cache=use_cache)
375+
.map(mapper, params=["file.path"], output={"path_len": int})
376+
.to_storage(
377+
tmp_dir / "output", placement=placement, num_threads=num_threads
378+
)
370379
)
371380

372381
expected = {
@@ -387,6 +396,8 @@ def test_to_storage(
387396
with open(tmp_dir / "output" / file_path) as f:
388397
assert f.read() == expected[file.name]
389398

399+
assert mapper.call_count == len(expected)
400+
390401

391402
@pytest.mark.parametrize("use_cache", [True, False])
392403
def test_export_images_files(test_session, tmp_dir, tmp_path, use_cache):

0 commit comments

Comments
 (0)