Skip to content

Commit 8add786

Browse files
SilvanCodes and lhoestq authored
Add missing property on RepeatExamplesIterable (#7581)
* Add missing property
* Fix shard_data_sources for RepeatExamplesIterable
* Update src/datasets/iterable_dataset.py

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 84963c7 commit 8add786

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/datasets/iterable_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1674,16 +1674,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample
16741674
"""Shuffle the underlying iterable, then repeat."""
16751675
return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)
16761676

1677-
def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
1677+
def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RepeatExamplesIterable":
16781678
"""Shard, then repeat shards."""
16791679
return RepeatExamplesIterable(
1680-
self.ex_iterable.shard_data_sources(worker_id, num_workers),
1680+
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
16811681
num_times=self.num_times,
16821682
)
16831683

16841684
@property
1685-
def n_shards(self) -> int:
1686-
return self.ex_iterable.n_shards
1685+
def num_shards(self) -> int:
1686+
return self.ex_iterable.num_shards
16871687

16881688

16891689
class TakeExamplesIterable(_BaseExamplesIterable):

0 commit comments

Comments
 (0)