Source code for imitation.scripts.ingredients.environment

"""This ingredient provides a vectorized gym environment."""
import contextlib
from typing import Any, Generator, Mapping, Optional

import numpy as np
import sacred
from stable_baselines3.common import vec_env

from imitation.data import wrappers
from imitation.util import util

# Sacred ingredient owning the `environment` config section; the config functions
# and captured functions in this module are registered on it via its decorators.
environment_ingredient = sacred.Ingredient(path="environment")


@environment_ingredient.config
def config():
    """Default configuration of the `environment` ingredient.

    Every local variable assigned in this function becomes a config entry of the
    ingredient; the trailing comment on each line describes that entry.
    """
    num_vec = 8  # number of environments in VecEnv
    parallel = True  # Use SubprocVecEnv rather than DummyVecEnv
    max_episode_steps = None  # Set to positive int to limit episode horizons; None = no limit
    env_make_kwargs = {}  # The kwargs passed to `spec.make`.
    gym_id = "seals/CartPole-v0"  # The id of the gym environment to train on

    locals()  # quieten flake8: the locals above are consumed by Sacred, not read here


@contextlib.contextmanager
@environment_ingredient.capture
def make_venv(
    gym_id: str,
    num_vec: int,
    parallel: bool,
    max_episode_steps: Optional[int],
    env_make_kwargs: Mapping[str, Any],
    _run: sacred.run.Run,
    _rnd: np.random.Generator,
    **kwargs,
) -> Generator[vec_env.VecEnv, None, None]:
    """Builds the vector environment.

    Args:
        gym_id: The id of the environment to create.
        num_vec: Number of `gym.Env` instances to combine into a vector environment.
        parallel: Whether to use "true" parallelism. If True, then use `SubProcVecEnv`.
            Otherwise, use `DummyVecEnv` which steps through environments serially.
        max_episode_steps: If not None, then a TimeLimit wrapper is applied to each
            environment to artificially limit the maximum number of timesteps in an
            episode. None (the default config value) means no limit.
        env_make_kwargs: The kwargs passed to `spec.make` of a gym environment.
        _run: Sacred run object. Its config is consulted for `logging.log_dir`; when
            the run config has no `logging` section, no log directory is used.
        _rnd: Random number generator provided by Sacred.
        kwargs: Passed through to `util.make_vec_env`.

    Yields:
        The constructed vector environment. It is closed when the context exits,
        whether normally or because the body raised.
    """
    # The log directory comes from the `logging` config section when the run has one
    # (presumably contributed by a separate logging ingredient -- not visible here).
    log_dir = (
        _run.config["logging"]["log_dir"] if "logging" in _run.config else None
    )
    # Note: we create the venv outside the try -- finally block for the case that env
    # creation fails.
    venv = util.make_vec_env(
        gym_id,
        rng=_rnd,
        n_envs=num_vec,
        parallel=parallel,
        max_episode_steps=max_episode_steps,
        log_dir=log_dir,
        env_make_kwargs=env_make_kwargs,
        **kwargs,
    )
    try:
        yield venv
    finally:
        venv.close()
@contextlib.contextmanager
@environment_ingredient.capture
def make_rollout_venv(
    gym_id: str,
    num_vec: int,
    parallel: bool,
    max_episode_steps: Optional[int],
    env_make_kwargs: Mapping[str, Any],
    _rnd: np.random.Generator,
) -> Generator[vec_env.VecEnv, None, None]:
    """Builds the vector environment for rollouts.

    This environment does no logging, and it is wrapped in a `RolloutInfoWrapper`.

    Args:
        gym_id: The id of the environment to create.
        num_vec: Number of `gym.Env` instances to combine into a vector environment.
        parallel: Whether to use "true" parallelism. If True, then use `SubProcVecEnv`.
            Otherwise, use `DummyVecEnv` which steps through environments serially.
        max_episode_steps: If not None, then a TimeLimit wrapper is applied to each
            environment to artificially limit the maximum number of timesteps in an
            episode. None (the default config value) means no limit.
        env_make_kwargs: The kwargs passed to `spec.make` of a gym environment.
        _rnd: Random number generator provided by Sacred.

    Yields:
        The constructed vector environment. It is closed when the context exits,
        whether normally or because the body raised.
    """
    # Note: we create the venv outside the try -- finally block for the case that env
    # creation fails.
    venv = util.make_vec_env(
        gym_id,
        rng=_rnd,
        n_envs=num_vec,
        parallel=parallel,
        max_episode_steps=max_episode_steps,
        # Rollout environments deliberately write no monitor logs.
        log_dir=None,
        env_make_kwargs=env_make_kwargs,
        # Wrap every sub-environment; the index argument is accepted but unused.
        post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)],
    )
    try:
        yield venv
    finally:
        venv.close()
@environment_ingredient.named_config
def fast():
    """Named config `fast`: a small, single-process environment with short episodes.

    Overrides the ingredient's `num_vec`, `parallel` and `max_episode_steps` entries.
    """
    num_vec = 2  # only two sub-environments
    parallel = False  # easier to debug with everything in one process
    max_episode_steps = 5  # cap episodes at 5 timesteps
    locals()  # quieten flake8