# Source code for imitation.data.wrappers

"""Environment wrappers for collecting rollouts."""

from typing import List, Optional, Sequence, Tuple

import gymnasium as gym
import numpy as np
import numpy.typing as npt
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper

from imitation.data import rollout, types


class BufferingWrapper(VecEnvWrapper):
    """Saves transitions of underlying VecEnv.

    Retrieve saved transitions using `pop_transitions()`.
    """

    # If True, `reset()` raises when buffered samples have not been accessed.
    # NOTE: renamed from `error_on_premature_event`, which did not match the
    # attribute actually assigned in `__init__` and read in `reset()`.
    error_on_premature_reset: bool
    # Completed trajectories buffered since the last pop.
    _trajectories: List[types.TrajectoryWithRew]
    # Lengths of episodes completed since the last pop.
    _ep_lens: List[int]
    # Whether `reset()` has been called at least once.
    _init_reset: bool
    # Accumulates in-progress trajectories; None until the first `reset()`.
    _traj_accum: Optional[rollout.TrajectoryAccumulator]
    # Per-environment step counters; None until the first `reset()`.
    _timesteps: Optional[npt.NDArray[np.int_]]
    # Number of transitions buffered since the last pop; None until `reset()`.
    n_transitions: Optional[int]
[docs] def __init__(self, venv: VecEnv, error_on_premature_reset: bool = True): """Builds BufferingWrapper. Args: venv: The wrapped VecEnv. error_on_premature_reset: Error if `reset()` is called on this wrapper and there are saved samples that haven't yet been accessed. """ super().__init__(venv) self.error_on_premature_reset = error_on_premature_reset self._trajectories = [] self._ep_lens = [] self._init_reset = False self._traj_accum = None self._saved_acts = None self._timesteps = None self.n_transitions = None
[docs] def reset(self, **kwargs): if ( self._init_reset and self.error_on_premature_reset and self.n_transitions > 0 ): # noqa: E127 raise RuntimeError("BufferingWrapper reset() before samples were accessed") self._init_reset = True self.n_transitions = 0 obs = self.venv.reset(**kwargs) self._traj_accum = rollout.TrajectoryAccumulator() obs = types.maybe_wrap_in_dictobs(obs) for i, ob in enumerate(obs): self._traj_accum.add_step({"obs": ob}, key=i) self._timesteps = np.zeros((len(obs),), dtype=int) obs = types.maybe_unwrap_dictobs(obs) return obs
[docs] def step_async(self, actions): assert self._init_reset assert self._saved_acts is None self.venv.step_async(actions) self._saved_acts = actions
[docs] def step_wait(self): assert self._init_reset assert self._saved_acts is not None acts, self._saved_acts = self._saved_acts, None obs, rews, dones, infos = self.venv.step_wait() self.n_transitions += self.num_envs self._timesteps += 1 ep_lens = self._timesteps[dones] if len(ep_lens) > 0: self._ep_lens += list(ep_lens) self._timesteps[dones] = 0 finished_trajs = self._traj_accum.add_steps_and_auto_finish( acts, obs, rews, dones, infos, ) self._trajectories.extend(finished_trajs) return obs, rews, dones, infos
def _finish_partial_trajectories(self) -> Sequence[types.TrajectoryWithRew]: """Finishes and returns partial trajectories in `self._traj_accum`.""" assert self._traj_accum is not None trajs = [] for i in range(self.num_envs): # Check that we have any transitions at all. # The number of "transitions" or "timesteps" stored for the ith # environment is the number of step dicts stored in # `partial_trajectories[i]` minus one. We need to offset by one because # the first step dict is comes from `reset()`, not from `step()`. n_transitions = len(self._traj_accum.partial_trajectories[i]) - 1 assert n_transitions >= 0, "Invalid TrajectoryAccumulator state" if n_transitions >= 1: traj = self._traj_accum.finish_trajectory(i, terminal=False) trajs.append(traj) # Reinitialize a partial trajectory starting with the final observation. self._traj_accum.add_step({"obs": traj.obs[-1]}, key=i) return trajs
[docs] def pop_finished_trajectories( self, ) -> Tuple[Sequence[types.TrajectoryWithRew], Sequence[int]]: """Pops recorded complete trajectories `trajs` and episode lengths `ep_lens`. Returns: A tuple `(trajs, ep_lens)` where `trajs` is a sequence of trajectories including the terminal state (but possibly missing initial states, if `pop_trajectories` was previously called) and `ep_lens` is a sequence of episode lengths. Note the episode length will be longer than the trajectory length when the trajectory misses initial states. """ trajectories = self._trajectories ep_lens = self._ep_lens self._trajectories = [] self._ep_lens = [] self.n_transitions = 0 return trajectories, ep_lens
[docs] def pop_trajectories( self, ) -> Tuple[Sequence[types.TrajectoryWithRew], Sequence[int]]: """Pops recorded trajectories `trajs` and episode lengths `ep_lens`. Returns: A tuple `(trajs, ep_lens)`. `trajs` is a sequence of trajectory fragments, consisting of data collected after the last call to `pop_trajectories`. They may miss initial states (if `pop_trajectories` previously returned a fragment for that episode), and terminal states (if the episode has yet to complete). `ep_lens` is the total length of completed episodes. """ if self.n_transitions == 0: return [], [] partial_trajs = self._finish_partial_trajectories() self._trajectories.extend(partial_trajs) return self.pop_finished_trajectories()
[docs] def pop_transitions(self) -> types.TransitionsWithRew: """Pops recorded transitions, returning them as an instance of Transitions. Returns: All transitions recorded since the last call. Raises: RuntimeError: empty (no transitions recorded since last pop). """ if self.n_transitions == 0: # It would be better to return an empty `Transitions`, but we would need # to get the non-zero dimensions of every np.ndarray attribute correct to # avoid downstream errors. This is easier and sufficient for now. raise RuntimeError("Called pop_transitions on an empty BufferingWrapper") # make a copy for the assert later n_transitions = self.n_transitions trajectories, _ = self.pop_trajectories() transitions = rollout.flatten_trajectories_with_rew(trajectories) assert len(transitions.obs) == n_transitions return transitions
class RolloutInfoWrapper(gym.Wrapper):
    """Add the entire episode's rewards and observations to `info` at episode end.

    Whenever done=True, `info["rollout"]` is a dict with keys "obs" and "rews",
    whose corresponding values hold the NumPy arrays containing the raw
    observations and rewards seen during this episode.
    """
[docs] def __init__(self, env: gym.Env): """Builds RolloutInfoWrapper. Args: env: Environment to wrap. """ super().__init__(env) self._obs = None self._rews = None
[docs] def reset(self, **kwargs): new_obs, info = super().reset(**kwargs) self._obs = [types.maybe_wrap_in_dictobs(new_obs)] self._rews = [] return new_obs, info
[docs] def step(self, action): obs, rew, terminated, truncated, info = self.env.step(action) done = terminated or truncated self._obs.append(types.maybe_wrap_in_dictobs(obs)) self._rews.append(rew) if done: assert "rollout" not in info info["rollout"] = { "obs": types.stack_maybe_dictobs(self._obs), "rews": np.stack(self._rews), } return obs, rew, terminated, truncated, info