imitation.policies.replay_buffer_wrapper#

Wrapper that relabels rewards for transitions sampled from a replay buffer.

Classes

ReplayBufferRewardWrapper(buffer_size, ...)

Relabel the rewards in transitions sampled from a ReplayBuffer.

class imitation.policies.replay_buffer_wrapper.ReplayBufferRewardWrapper(buffer_size, observation_space, action_space, *, replay_buffer_class, reward_fn, **kwargs)[source]#

Bases: ReplayBuffer

Relabel the rewards in transitions sampled from a ReplayBuffer.

__init__(buffer_size, observation_space, action_space, *, replay_buffer_class, reward_fn, **kwargs)[source]#

Builds ReplayBufferRewardWrapper.

Parameters
  • buffer_size (int) – Max number of elements in the buffer.

  • observation_space (Space) – Observation space.

  • action_space (Space) – Action space.

  • replay_buffer_class (Type[ReplayBuffer]) – Class of the replay buffer.

  • reward_fn (RewardFn) – Reward function for reward relabeling.

  • **kwargs – Keyword arguments for ReplayBuffer.
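
A minimal usage sketch (not part of the documented API): with Stable-Baselines3 off-policy algorithms, the buffer is instantiated by the algorithm itself, so the wrapper is supplied as replay_buffer_class and its own arguments are routed through replay_buffer_kwargs. The constant-zero reward function and the Pendulum environment below are illustrative placeholders; in practice reward_fn would typically come from a learned reward model.

import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper


def zero_reward_fn(obs, acts, next_obs, dones):
    # RewardFn: batched (obs, acts, next_obs, dones) arrays -> array of rewards.
    # Placeholder that relabels every transition with reward 0.
    return np.zeros(len(obs), dtype=np.float32)


model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    buffer_size=10_000,
    # SB3 instantiates replay_buffer_class(buffer_size, observation_space,
    # action_space, ..., **replay_buffer_kwargs), matching __init__ above.
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        replay_buffer_class=ReplayBuffer,  # inner buffer that actually stores data
        reward_fn=zero_reward_fn,
    ),
)
model.learn(total_timesteps=1_000)

Remaining keyword arguments (device, n_envs, etc.) that SB3 passes when constructing the buffer are forwarded to the inner ReplayBuffer via **kwargs.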

actions: ndarray#
add(*args, **kwargs)[source]#

Add elements to the buffer.

dones: ndarray#
property full: bool#
Return type

bool

next_observations: ndarray#
observations: ndarray#
property pos: int#
Return type

int

rewards: ndarray#
sample(*args, **kwargs)[source]#

Sample elements from the replay buffer. Custom sampling is used for the memory-efficient variant, as we should not sample the element with index self.pos. See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

Parameters
  • batch_size – Number of elements to sample.

  • env – Associated gym VecEnv, used to normalize the observations/rewards when sampling.

Returns

Sampled transitions, with rewards relabeled by reward_fn.
timeouts: ndarray#
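
A hedged sketch of how relabeling shows up at sampling time, assuming the model from the construction sketch above has already been trained so its wrapped buffer contains transitions:

# Assumes `model` from the sketch above has run model.learn(), so its
# replay buffer (a ReplayBufferRewardWrapper) already holds transitions.
buffer = model.replay_buffer
batch = buffer.sample(batch_size=64)
# batch.rewards holds rewards recomputed by reward_fn when sampling,
# not the environment rewards that were stored by add().
print(batch.observations.shape, batch.rewards.shape)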