Description
🐛 Bug
MRC 👇
from stable_baselines3.common.preprocessing import is_image_space
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import FrameStackObservation
image_space = gym.spaces.Box(0, 255, (3, 64, 64), np.uint8) # a basic RGB 64x64 image
class DummyEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = image_space
env = DummyEnv()
print(is_image_space(env.observation_space)) # True
env = FrameStackObservation(env, stack_size=2)
print(env.observation_space.shape) # (2, 3, 64, 64), stacked obs on dim=0
print(is_image_space(env.observation_space)) # False
This seems to be because sb3 expects image observations to have exactly 3 dimensions.
I am wondering why the check is not >= 3 instead of strict equality. That way, a stacked observation would still pass the image check (though I reckon there might be problems with NatureCNN, as I am not sure how it would handle 4-dimensional tensors).
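To make the proposal concrete, here is a minimal sketch of the two variants of the dimension check. The helper names are hypothetical and only look at the shape tuple; they are not sb3's is_image_space, which (as far as I can tell) also checks the dtype and the bounds of the space.

```python
# Hypothetical helpers, for illustration only; not part of the sb3 API.

def is_image_shape_strict(shape):
    """Mimics the current behaviour as described above: exactly 3 dimensions."""
    return len(shape) == 3


def is_image_shape_relaxed(shape):
    """The proposed variant: at least 3 dimensions."""
    return len(shape) >= 3


single_frame = (3, 64, 64)   # one RGB 64x64 image, channel-first
stacked = (2, 3, 64, 64)     # two frames stacked on dim=0

for shape in (single_frame, stacked):
    print(shape, is_image_shape_strict(shape), is_image_shape_relaxed(shape))
```

The relaxed check accepts the stacked shape, but whatever consumes the observation afterwards would still need to handle the extra leading dimension.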
Worth adding that:
- A workaround for this bug is to transform the observation using the TransformObservation env wrapper from gymnasium, merging the first two dimensions into one. This is precisely what sb3 does in their VecFrameStack!
Worth saying this can have very unintended consequences: if one passes an image through Gymnasium's frame stacking and then uses sb3, the image won't be recognized as such, and the feature extractor for that image will be set to Flatten (instead of NatureCNN).
Happy to open a PR to change this check, but I wanted to double-check that it makes sense first!
To Reproduce
see above
Relevant log output / Error message
see above
System Info
- OS: macOS-15.0.1-arm64-arm-64bit Darwin Kernel Version 24.0.0: Tue Sep 24 23:36:26 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T8103
- Python: 3.12.9
- Stable-Baselines3: 2.5.0
- PyTorch: 2.6.0
- GPU Enabled: False
- Numpy: 2.2.3
- Cloudpickle: 3.1.1
- Gymnasium: 1.0.0
Checklist
- My issue does not relate to a custom gym environment. (Use the custom gym env template instead)
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I've used the markdown code blocks for both code and stack traces.