IterableDatasetDict.map() call removes column_names (in fact info.features) #7568

@mombip

Description

When calling IterableDatasetDict.map(), each split's IterableDataset.map() is invoked without a features argument. Omitting the argument is not incorrect in itself, but the implementation then sets info.features = features unconditionally, which overwrites the original features with None. Since IterableDataset.column_names is derived from info.features, it ends up as None as well.

Reproduction

  1. Define an IterableDatasetDict with a non-None features schema.
  2. Make sure my_iterable_dataset_dict contains a "text" column.
  3. Call:
new_dict = my_iterable_dataset_dict.map(
    function=my_fn,
    with_indices=False,
    batched=True,
    batch_size=16,
)
  4. Observe:
new_dict["train"].info.features  # {'text': Value(dtype='string', id=None)}
new_dict["train"].column_names   # ['text']
  5. Call again, this time with remove_columns:
new_dict = my_iterable_dataset_dict.map(
    function=my_fn,
    with_indices=False,
    batched=True,
    batch_size=16,
    remove_columns=["foo"]
)
  6. Observe:
new_dict["train"].info.features  # → None
new_dict["train"].column_names   # → None
  7. Internally, in dataset_dict.py, this loop omits features (code):
for split, dataset in self.items():
    dataset_dict[split] = dataset.map(
        function=function,
        with_indices=with_indices,
        input_columns=input_columns,
        batched=batched,
        batch_size=batch_size,
        drop_last_batch=drop_last_batch,
        remove_columns=remove_columns,
        fn_kwargs=fn_kwargs,
        # features omitted → defaults to None
    )
  8. Then, inside IterableDataset.map() (code), the correct info.features is replaced by features, which is None:
info = self.info.copy()
info.features = features  # features is None here
return IterableDataset(..., info=info, ...)
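The mechanism in the last two steps can be illustrated without the library. The following is a minimal sketch, assuming column names are derived from the schema and that map() assigns the schema unconditionally; ToyInfo and ToyIterableDataset are hypothetical stand-ins, not the real datasets classes.

```python
import copy
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyInfo:
    # Stand-in for the dataset info object; only the schema matters here.
    features: Optional[dict] = None


class ToyIterableDataset:
    """Hypothetical stand-in, not the real datasets.IterableDataset."""

    def __init__(self, info: ToyInfo):
        self.info = info

    @property
    def column_names(self) -> Optional[list]:
        # Column names come from the schema, so no schema means None.
        features = self.info.features
        return list(features) if features is not None else None

    def map(self, function: Callable, features: Optional[dict] = None) -> "ToyIterableDataset":
        info = copy.copy(self.info)
        info.features = features  # unconditional: None overwrites the schema
        return ToyIterableDataset(info)


original = ToyIterableDataset(ToyInfo(features={"text": "string"}))
mapped = original.map(lambda batch: batch)  # features not passed

print(original.column_names)
print(mapped.column_names)
```

The source dataset keeps its schema, while the mapped copy loses it, so any caller that iterates over column_names on the mapped copy gets None instead of a list.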

Suggestion
It looks like this replacement was added intentionally, but it should probably be applied only when features is not None.
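As a rough sketch of that guard, assuming the schema can be treated as a plain mapping: resolve_features is a hypothetical helper used only for illustration, not a function in the library.

```python
from typing import Optional


def resolve_features(current: Optional[dict], requested: Optional[dict]) -> Optional[dict]:
    """Return the schema a mapped dataset should carry (hypothetical helper)."""
    # An explicit schema from the caller always wins.
    if requested is not None:
        return requested
    # Otherwise keep whatever schema the source dataset already had.
    return current


schema = {"text": "string"}
kept = resolve_features(schema, None)
overridden = resolve_features(schema, {"label": "int64"})
```

One caveat with this approach: if the mapped function adds or removes columns, the inherited schema may no longer describe the output, so a real fix would likely need to account for that case.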

Workaround:
SFTTrainer calls dataset.map() several times and then fails with a NoneType error when iterating over dataset.column_names.
I wrote this patch, which works for me.

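from datasets import IterableDataset

def patch_iterable_dataset_map():
    # Keep a reference to the original method so the wrapper can delegate to it.
    _orig_map = IterableDataset.map

    def _patched_map(self, *args, **kwargs):
        # If no features were passed (or None), reuse the existing schema.
        # Assumes features is always passed by keyword, never positionally.
        if kwargs.get("features") is None:
            kwargs["features"] = self.info.features
        return _orig_map(self, *args, **kwargs)

    IterableDataset.map = _patched_map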
