29 changes: 24 additions & 5 deletions src/lerobot/datasets/aggregate.py
@@ -19,6 +19,7 @@
import shutil
from pathlib import Path

import datasets
import pandas as pd
import tqdm

@@ -32,6 +33,7 @@
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
@@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
}

unique_chunk_file_ids = sorted(unique_chunk_file_ids)
contains_images = len(dst_meta.image_keys) > 0

# retrieve features schema for proper image typing in parquet
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None

for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
)
df = pd.read_parquet(src_path)
if contains_images:
# Use HuggingFace datasets to read source data to preserve image format
src_ds = datasets.Dataset.from_parquet(str(src_path))
Copilot AI commented on Dec 24, 2025

When reading image datasets using datasets.Dataset.from_parquet, the features parameter should be passed to ensure image columns are properly loaded with the correct schema. Without this, the image data might not be correctly preserved during the read-update-write cycle. Consider using datasets.Dataset.from_parquet(str(src_path), features=hf_features) to maintain schema consistency.

Suggested change
src_ds = datasets.Dataset.from_parquet(str(src_path))
src_ds = datasets.Dataset.from_parquet(str(src_path), features=hf_features)
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)

data_idx = append_or_create_parquet_file(
@@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_files_size_in_mb,
chunk_size,
DEFAULT_DATA_PATH,
contains_images=len(dst_meta.image_keys) > 0,
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
)

return data_idx
@@ -488,6 +500,7 @@ def append_or_create_parquet_file(
default_path: str,
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
):
"""Appends data to an existing parquet file or creates a new one based on size constraints.

@@ -503,6 +516,7 @@
default_path: Format string for generating file paths.
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.

Returns:
dict: Updated index dictionary with current chunk and file indices.
@@ -512,7 +526,7 @@
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path)
to_parquet_with_hf_images(df, dst_path, features=hf_features)
Copilot AI commented on Dec 24, 2025

The function to_parquet_with_hf_images is being called with a features parameter, but the current function signature in utils.py only accepts (df: pandas.DataFrame, path: Path) and does not have a features parameter. This will cause a TypeError at runtime. The function signature needs to be updated to accept and use the features parameter to properly preserve the HuggingFace Image schema.
else:
df.to_parquet(dst_path)
return idx
@@ -527,12 +541,17 @@ def append_or_create_parquet_file(
final_df = df
target_path = new_path
else:
existing_df = pd.read_parquet(dst_path)
if contains_images:
# Use HuggingFace datasets to read existing data to preserve image format
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
Copilot AI commented on Dec 24, 2025

When reading existing image datasets using datasets.Dataset.from_parquet, the features parameter should be passed to ensure image columns are properly loaded with the correct schema. Without this, the image data might not be correctly preserved during the read-merge-write cycle. Consider using datasets.Dataset.from_parquet(str(dst_path), features=hf_features) to maintain schema consistency.

Suggested change
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
existing_ds = datasets.Dataset.from_parquet(str(dst_path), features=hf_features)
existing_df = existing_ds.to_pandas()
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path

if contains_images:
to_parquet_with_hf_images(final_df, target_path)
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
Copilot AI commented on Dec 24, 2025

The function to_parquet_with_hf_images is being called with a features parameter, but the current function signature in utils.py only accepts (df: pandas.DataFrame, path: Path) and does not have a features parameter. This will cause a TypeError at runtime. The function signature needs to be updated to accept and use the features parameter to properly preserve the HuggingFace Image schema.
else:
final_df.to_parquet(target_path)

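On the review comments above about passing features to datasets.Dataset.from_parquet, here is a minimal, self-contained sketch of the difference, assuming a single image column. The column name and file path are illustrative and not taken from the PR.

import io

import datasets
import pandas as pd
from PIL import Image

# Build one dummy frame in the encoded form HF datasets uses for images:
# a dict holding raw PNG bytes and an optional source path.
buf = io.BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
df = pd.DataFrame({"observation.images.cam": [{"bytes": buf.getvalue(), "path": None}], "index": [0]})

# A plain pandas write stores the column as a generic struct with no HF feature metadata.
df.to_parquet("frames.parquet")

# Read without `features`: the column comes back as a struct, not Image().
plain = datasets.Dataset.from_parquet("frames.parquet")
assert not isinstance(plain.features["observation.images.cam"], datasets.Image)

# Read with `features` (as suggested in the review): the column is cast back to Image().
hf_features = datasets.Features({"observation.images.cam": datasets.Image(), "index": datasets.Value("int64")})
typed = datasets.Dataset.from_parquet("frames.parquet", features=hf_features)
assert isinstance(typed.features["observation.images.cam"], datasets.Image)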
13 changes: 11 additions & 2 deletions src/lerobot/datasets/utils.py
@@ -1172,12 +1172,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)


def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.

Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path)


def item_to_torch(item: dict) -> dict:
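A short usage sketch of the updated helper, assuming an image column named observation.images.cam (illustrative, not from the PR): when features is passed, the Image() type is written into the file's schema metadata, so a plain re-read recovers it without re-specifying the schema.

import io
from pathlib import Path

import datasets
import pandas as pd
from PIL import Image

from lerobot.datasets.utils import to_parquet_with_hf_images

# One dummy frame stored the way images are kept in data frames: PNG bytes plus an optional path.
buf = io.BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
df = pd.DataFrame({"observation.images.cam": [{"bytes": buf.getvalue(), "path": None}], "index": [0]})

features = datasets.Features({"observation.images.cam": datasets.Image(), "index": datasets.Value("int64")})
to_parquet_with_hf_images(df, Path("chunk.parquet"), features=features)

# The Image() feature survives in the parquet schema metadata, so reading back needs no extra hints.
ds = datasets.Dataset.from_parquet("chunk.parquet")
assert isinstance(ds.features["observation.images.cam"], datasets.Image)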
145 changes: 145 additions & 0 deletions tests/datasets/test_aggregate.py
@@ -16,6 +16,7 @@

from unittest.mock import patch

import datasets
import torch

from lerobot.datasets.aggregate import aggregate_datasets
@@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"


def assert_image_schema_preserved(aggr_ds):
"""Test that HuggingFace Image feature schema is preserved in aggregated parquet files.

This verifies the fix for a bug where image columns were written with a generic
struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of
the proper Image() feature type, causing HuggingFace Hub viewer to display
raw dict objects instead of image thumbnails.
"""
image_keys = aggr_ds.meta.image_keys
if not image_keys:
return

# Check that parquet files have proper Image schema
data_dir = aggr_ds.root / "data"
parquet_files = list(data_dir.rglob("*.parquet"))
assert len(parquet_files) > 0, "No parquet files found in aggregated dataset"

for parquet_file in parquet_files:
# Load with HuggingFace datasets to check schema
ds = datasets.Dataset.from_parquet(str(parquet_file))

for image_key in image_keys:
feature = ds.features.get(image_key)
assert feature is not None, f"Image key '{image_key}' not found in parquet schema"
assert isinstance(feature, datasets.Image), (
f"Image key '{image_key}' should have Image() feature type, "
f"but got {type(feature).__name__}: {feature}. "
"This indicates image schema was not preserved during aggregation."
)


def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
"""Test that image frames are correctly preserved after aggregation."""
image_keys = aggr_ds.meta.image_keys
if not image_keys:
return

def images_equal(img1, img2):
return torch.allclose(img1, img2)

# Test the section corresponding to the first dataset (ds_0)
for i in range(len(ds_0)):
assert aggr_ds[i]["index"] == i, (
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
)
for key in image_keys:
assert images_equal(aggr_ds[i][key], ds_0[i][key]), (
f"Image frames at position {i} should be equal between aggregated and ds_0"
)

# Test the section corresponding to the second dataset (ds_1)
for i in range(len(ds_0), len(ds_0) + len(ds_1)):
assert aggr_ds[i]["index"] == i, (
f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
)
for key in image_keys:
assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
f"Image frames at position {i} should be equal between aggregated and ds_1"
)


def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.

This test specifically verifies that:
1. Image-based datasets can be aggregated correctly
2. The HuggingFace Image() feature type is preserved in parquet files
3. Image data integrity is maintained across aggregation
4. Images can be properly decoded after aggregation

This catches the bug where to_parquet_with_hf_images() was not passing
the features schema, causing image columns to be written as generic
struct types instead of Image() types.
"""
ds_0_num_frames = 50
ds_1_num_frames = 75
ds_0_num_episodes = 2
ds_1_num_episodes = 3

# Create two image-based datasets (use_videos=False)
ds_0 = lerobot_dataset_factory(
root=tmp_path / "image_0",
repo_id=f"{DUMMY_REPO_ID}_image_0",
total_episodes=ds_0_num_episodes,
total_frames=ds_0_num_frames,
use_videos=False, # Image-based dataset
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "image_1",
repo_id=f"{DUMMY_REPO_ID}_image_1",
total_episodes=ds_1_num_episodes,
total_frames=ds_1_num_frames,
use_videos=False, # Image-based dataset
)

# Verify source datasets have image keys
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"

# Aggregate the datasets
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr",
aggr_root=tmp_path / "image_aggr",
)

# Load the aggregated dataset
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "image_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr")

# Verify aggregated dataset has image keys
assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys"
assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets"

# Run standard aggregation assertions
expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
expected_total_frames = ds_0_num_frames + ds_1_num_frames

assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)

# Image-specific assertions
assert_image_schema_preserved(aggr_ds)
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)

# Verify images can be accessed and have correct shape
sample_item = aggr_ds[0]
for image_key in aggr_ds.meta.image_keys:
img = sample_item[image_key]
assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor"
assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)"
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"

assert_dataset_iteration_works(aggr_ds)