Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions src/lerobot/policies/diffusion/configuration_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ class DiffusionConfig(PreTrainedConfig):
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision backbone. If None, no resizing is done.
crop_ratio: Ratio to calculate crop size from resize_shape (crop = resize_shape * crop_ratio).
Set to 1.0 to disable cropping.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
Expand Down Expand Up @@ -123,7 +124,8 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
resize_shape: tuple[int, int] | None = (96, 96)
crop_ratio: float = 0.90
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
Expand Down Expand Up @@ -207,14 +209,9 @@ def validate_features(self) -> None:
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")

if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
if self.resize_shape is not None:
# crop_shape is derived from resize_shape, so it always fits
pass

# Check that all input images have the same shape.
if len(self.image_features) > 0:
Expand All @@ -236,3 +233,9 @@ def action_delta_indices(self) -> list:
@property
def reward_delta_indices(self) -> None:
    """Diffusion Policy trains without reward targets, so there are no reward delta indices."""
    return None

@property
def crop_shape(self) -> tuple[int, int] | None:
if self.resize_shape is None or self.crop_ratio >= 1.0:
return None
return (int(self.resize_shape[0] * self.crop_ratio), int(self.resize_shape[1] * self.crop_ratio))
15 changes: 12 additions & 3 deletions src/lerobot/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,12 +446,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None

crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
Expand Down Expand Up @@ -500,6 +506,9 @@ def forward(self, x: Tensor) -> Tensor:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).

if self.resize is not None:
x = self.resize(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
Expand Down