[Feature Request] Add a next_observations field to RolloutBufferSamples #1328

Closed as not planned
@euanong

🚀 Feature

When sampling from a RolloutBuffer, we return a RolloutBufferSamples object containing tensors of observations, actions, etc.:

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
        data = (
            self.observations[batch_inds],

It would be nice if RolloutBufferSamples could also contain a batch of next observations (alongside a mask that, for each observation, tells us whether that observation has a successor).
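To make the request concrete, here is a minimal NumPy-only sketch of how next observations and the successor mask could be derived before the buffer is flattened. It is not the code from the PR. The function name `next_observations_with_mask` is hypothetical, and the sketch assumes that `observations` has shape `(n_steps, n_envs, *obs_shape)` and that `episode_starts[t, e]` is truthy when step `t` is the first observation of a new episode in environment `e`.

```python
import numpy as np


def next_observations_with_mask(observations, episode_starts):
    """Hypothetical helper: shift observations by one step along the time axis.

    Assumes observations has shape (n_steps, n_envs, *obs_shape) and
    episode_starts has shape (n_steps, n_envs).
    """
    next_obs = np.zeros_like(observations)
    has_next = np.zeros(episode_starts.shape, dtype=bool)

    # The successor of step t is step t + 1, unless step t + 1 starts a new episode.
    next_obs[:-1] = observations[1:]
    has_next[:-1] = ~episode_starts[1:].astype(bool)

    # The final step has no successor stored in the buffer, so its mask stays False.
    # Zero out entries without a successor so they cannot be used by accident.
    next_obs[~has_next] = 0
    return next_obs, has_next
```

Both arrays have the same leading dimensions as the stored observations, so they could be flattened and indexed with `batch_inds` in the same way as the other fields.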

Motivation

I'm implementing an RL pipeline in which I extend PPO with a custom loss. For this custom loss, I need access to (observation, next observation) pairs.

In the PPO implementation

    # train for n_epochs epochs
    for epoch in range(self.n_epochs):
        approx_kl_divs = []
        # Do a complete pass on the rollout buffer
        for rollout_data in self.rollout_buffer.get(self.batch_size):
            actions = rollout_data.actions

each batch of rollout data over which we compute the PPO loss is a RolloutBufferSamples instance. Because each batch is a random subset of the observations in the RolloutBuffer, it does not carry enough information to recover the next observation for each observation in the batch.

Pitch

I have already implemented this feature and submitted it as a PR [to be linked after submission].

Alternatives

Alternatively, we could return the indices of the sampled elements with respect to the original buffer. While this may allow for more general buffer manipulation, this feels less pleasant to use.
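For comparison, a rough sketch of what the index-based alternative might look like on the caller's side. The helper name `gather_next` is hypothetical. The sketch assumes the buffer has been flattened env-major, so that flat index `env * n_steps + step` holds step `step` of environment `env`, and that `flat_episode_starts` marks the first observation of each episode.

```python
import numpy as np


def gather_next(flat_obs, flat_episode_starts, batch_inds, n_steps):
    """Hypothetical helper: look up successors from sampled flat indices.

    Assumes an env-major flat layout: index = env * n_steps + step.
    """
    batch_inds = np.asarray(batch_inds)

    # The last step of each environment's rollout has no successor in the buffer.
    is_last_step = (batch_inds % n_steps) == n_steps - 1

    # Use the index itself as a placeholder where no successor exists,
    # so the lookup never runs past the end or into the next environment.
    next_inds = np.where(is_last_step, batch_inds, batch_inds + 1)

    # A successor is valid only if it does not start a new episode.
    has_next = ~is_last_step & ~flat_episode_starts[next_inds].astype(bool)
    return flat_obs[next_inds], has_next
```

Rows where `has_next` is False contain placeholder observations and should be masked out of any loss. Every caller would have to repeat this bookkeeping, which is part of why this option feels less pleasant to use.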

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo

Labels: duplicate, enhancement
