Description
🚀 Feature
When sampling from a RolloutBuffer, we return RolloutBufferSamples containing tensors of observations, actions, etc. (see stable-baselines3/stable_baselines3/common/buffers.py, lines 473 to 479 in 69b94dd).
It would be nice if RolloutBufferSamples could also contain a batch of next observations, alongside a mask that tells us, for each observation, whether it has a successor.
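To make the request concrete, here is a minimal sketch of what such a sample could look like. It is not the SB3 implementation: SamplesWithNext, sample_with_next and has_next are hypothetical names, and the sketch assumes a single environment, a flat buffer of scalar observations held in plain Python lists (real buffers hold arrays or tensors), and a per-step episode_starts flag that marks the first step of each episode.

```python
import random
from typing import List, NamedTuple, Sequence


class SamplesWithNext(NamedTuple):
    """Hypothetical sample type; not the actual SB3 RolloutBufferSamples."""

    observations: List[float]
    next_observations: List[float]
    has_next: List[bool]


def sample_with_next(
    observations: Sequence[float],
    episode_starts: Sequence[bool],
    batch_size: int,
    rng: random.Random,
) -> SamplesWithNext:
    """Sample a random batch together with successors and a validity mask."""
    n = len(observations)
    indices = rng.sample(range(n), batch_size)  # sampling without replacement
    obs, next_obs, mask = [], [], []
    for i in indices:
        # A successor exists only if step i + 1 is stored in the buffer
        # and does not begin a new episode.
        valid = i + 1 < n and not episode_starts[i + 1]
        obs.append(observations[i])
        # Where no successor exists we store a placeholder (the observation
        # itself); the mask tells the caller to ignore it.
        next_obs.append(observations[i + 1] if valid else observations[i])
        mask.append(valid)
    return SamplesWithNext(obs, next_obs, mask)
```

A custom loss would then use only the pairs where has_next is true, for example by multiplying the per-sample loss by the mask before averaging.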
Motivation
I'm implementing an RL pipeline in which I extend PPO with a custom loss. For this custom loss, I need access to (observation, next observation) pairs.
In the PPO implementation (stable-baselines3/stable_baselines3/ppo/ppo.py, lines 192 to 197 in 69b94dd), each batch of rollout data over which we compute the PPO loss is a RolloutBufferSamples instance. Because each batch is a random subset of the observations in the RolloutBuffer, we do not have enough information to compute the next observation for each observation in the batch.
Pitch
I have already implemented this feature and submitted it as a PR [to be linked after submission].
Alternatives
Alternatively, we could return the indices of the sampled elements with respect to the original buffer. While this may allow for more general buffer manipulation, it feels less pleasant to use.
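For comparison, the index-based alternative might look roughly like the sketch below. Again, this is illustrative only: sample_indices and successor are hypothetical helpers, not SB3 functions, and the same simplifying assumptions apply (one environment, a flat list of scalar observations, a per-step episode_starts flag). It shows both the extra generality (any lookahead k) and the extra work left to the caller.

```python
import random
from typing import List, Optional, Sequence


def sample_indices(buffer_size: int, batch_size: int, rng: random.Random) -> List[int]:
    """Hypothetical buffer method: return positions instead of gathered data."""
    return rng.sample(range(buffer_size), batch_size)


def successor(
    observations: Sequence[float],
    episode_starts: Sequence[bool],
    index: int,
    k: int = 1,
) -> Optional[float]:
    """Return the observation k steps after `index`, or None if the buffer
    or the episode ends before that step."""
    target = index + k
    if target >= len(observations):
        return None
    # Any episode start in (index, target] means the trajectory was cut.
    if any(episode_starts[index + 1 : target + 1]):
        return None
    return observations[target]
```

With indices in hand, the caller has to repeat this boundary check for every lookup, which is the part that feels less pleasant than receiving next observations and a mask directly.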
Additional context
No response
Checklist
- I have checked that there is no similar issue in the repo