imitation.data.wrappers#

Environment wrappers for collecting rollouts.

Classes

BufferingWrapper(venv[, ...])

Saves transitions of underlying VecEnv.

RolloutInfoWrapper(env)

Add the entire episode's rewards and observations to info at episode end.

class imitation.data.wrappers.BufferingWrapper(venv, error_on_premature_reset=True)[source]#

Bases: VecEnvWrapper

Saves transitions of underlying VecEnv.

Retrieve saved transitions using pop_transitions().
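
A minimal usage sketch, not taken from the library's docs: it assumes CartPole-v1, Stable Baselines3's DummyVecEnv, and random actions standing in for a trained policy (depending on your imitation/SB3 versions, you may need gymnasium instead of gym):

    import gym
    import numpy as np
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.data.wrappers import BufferingWrapper

    # Wrap any VecEnv; every transition stepped through the wrapper is buffered.
    venv = BufferingWrapper(DummyVecEnv([lambda: gym.make("CartPole-v1")]))

    obs = venv.reset()
    for _ in range(10):
        # Random actions stand in for a real policy here.
        actions = np.stack([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)

    # Returns a TransitionsWithRew holding all 10 * num_envs transitions
    # recorded since the last pop (and empties the buffer).
    transitions = venv.pop_transitions()
    print(len(transitions))  # 10 with a single env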

__init__(venv, error_on_premature_reset=True)[source]#

Builds BufferingWrapper.

Parameters
  • venv (VecEnv) – The wrapped VecEnv.

  • error_on_premature_reset (bool) – Error if reset() is called on this wrapper and there are saved samples that haven’t yet been accessed.

error_on_premature_reset: bool#
n_transitions: Optional[int]#
pop_finished_trajectories()[source]#

Pops recorded complete trajectories trajs and episode lengths ep_lens.

Return type

Tuple[Sequence[TrajectoryWithRew], Sequence[int]]

Returns

A tuple (trajs, ep_lens), where trajs is a sequence of complete trajectories, each including its terminal state (but possibly missing initial states, if pop_trajectories was previously called), and ep_lens is a sequence of the corresponding episode lengths. Note that an episode length exceeds the matching trajectory length when that trajectory is missing initial states.

pop_trajectories()[source]#

Pops recorded trajectories trajs and episode lengths ep_lens.

Return type

Tuple[Sequence[TrajectoryWithRew], Sequence[int]]

Returns

A tuple (trajs, ep_lens). trajs is a sequence of trajectory fragments consisting of data collected since the last call to pop_trajectories. Fragments may be missing initial states (if pop_trajectories previously returned a fragment for that episode) and terminal states (if the episode has yet to complete). ep_lens gives the total length of each completed episode.
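
For illustration, a hedged sketch of the two trajectory pops (same assumed CartPole-v1 / DummyVecEnv setup as above; exactly what each call returns depends on where episodes happen to end):

    import gym
    import numpy as np
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.data.wrappers import BufferingWrapper

    venv = BufferingWrapper(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    obs = venv.reset()
    for _ in range(5):
        actions = np.stack([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)

    # Everything recorded so far, likely including a fragment of the
    # still-running episode (missing its terminal state).
    fragments, ep_lens = venv.pop_trajectories()

    # Alternatively, pop only episodes that have reached done=True; commented
    # out here because the call above has already consumed the recorded data.
    # finished, finished_ep_lens = venv.pop_finished_trajectories()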

pop_transitions()[source]#

Pops recorded transitions, returning them as an instance of Transitions.

Return type

TransitionsWithRew

Returns

All transitions recorded since the last call.

Raises

RuntimeError – if the buffer is empty (no transitions have been recorded since the last pop).

reset(**kwargs)[source]#

Reset all the environments and return an array of observations, or a tuple of observation arrays.

If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again.

Returns

observation

step_async(actions)[source]#

Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step.

You should not call this if a step_async run is already pending.

step_wait()[source]#

Wait for the step taken with step_async().

Returns

observation, reward, done, information

class imitation.data.wrappers.RolloutInfoWrapper(env)[source]#

Bases: Wrapper

Add the entire episode’s rewards and observations to info at episode end.

Whenever done=True, info["rollout"] is a dict with keys "obs" and "rews", whose corresponding values hold the NumPy arrays containing the raw observations and rewards seen during this episode.
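
The intended use is to wrap each environment inside a VecEnv so that imitation's rollout helpers can recover the raw per-episode data. A hedged sketch (rollout.rollout's exact signature has changed across imitation versions; the rng argument, for instance, is only required in recent releases):

    import gym
    import numpy as np
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.data import rollout
    from imitation.data.wrappers import RolloutInfoWrapper

    # Wrap each underlying env so rollout collection can read info["rollout"].
    venv = DummyVecEnv([lambda: RolloutInfoWrapper(gym.make("CartPole-v1"))])

    trajs = rollout.rollout(
        policy=None,  # None samples random actions
        venv=venv,
        sample_until=rollout.make_sample_until(min_episodes=2),
        rng=np.random.default_rng(0),
    )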

__init__(env)[source]#

Builds RolloutInfoWrapper.

Parameters

env (Env) – Environment to wrap.

reset(**kwargs)[source]#

Uses the reset() of the wrapped env, which can be overridden to change the returned data.

step(action)[source]#

Uses the step() of the wrapped env, which can be overridden to change the returned data.
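
For completeness, a manual episode loop reading info["rollout"] directly, assuming the classic gym step() API documented above (newer Gymnasium versions split done into terminated/truncated):

    import gym

    from imitation.data.wrappers import RolloutInfoWrapper

    env = RolloutInfoWrapper(gym.make("CartPole-v1"))
    obs = env.reset()
    done = False
    while not done:
        obs, rew, done, info = env.step(env.action_space.sample())

    episode = info["rollout"]
    # "obs" has one more entry than "rews": it includes the initial observation.
    print(episode["obs"].shape, episode["rews"].shape)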