# Source code for imitation.scripts.ingredients.policy_evaluation

"""This ingredient performs evaluation of learned policy.

It takes care of the right wrappers, does some rollouts
and computes statistics of the rollouts.
"""

from typing import Mapping, Union

import numpy as np
import sacred
from stable_baselines3.common import base_class, policies, vec_env

from imitation.data import rollout

policy_evaluation_ingredient = sacred.Ingredient("policy_evaluation")


@policy_evaluation_ingredient.config
def config():
    """Default configuration of the `policy_evaluation` ingredient.

    Registered as a Sacred config scope through
    `policy_evaluation_ingredient.config`: the local variables assigned in
    this body become the ingredient's configuration entries. Here that is
    only `n_episodes_eval`, with a default of 50.
    """
    n_episodes_eval = 50  # Num of episodes for final mean ground truth return
    locals()  # quieten flake8


@policy_evaluation_ingredient.named_config
def fast():
    """Named config `fast`: sets `n_episodes_eval` to 1 instead of the default 50.

    Presumably meant for quick smoke-test or debugging runs where a single
    evaluation episode is enough -- confirm against how callers select it.
    """
    n_episodes_eval = 1  # noqa: F841


@policy_evaluation_ingredient.capture
def eval_policy(
    rl_algo: Union[base_class.BaseAlgorithm, policies.BasePolicy],
    venv: vec_env.VecEnv,
    n_episodes_eval: int,
    _rnd: np.random.Generator,
) -> Mapping[str, float]:
    """Roll out a learned policy on `venv` and summarize the episodes.

    When `rl_algo` is a `BaseAlgorithm`, its environment is set to `venv` as a
    side effect, and the rollouts are collected on the environment that
    `rl_algo.get_env()` returns afterwards instead of on `venv` directly.

    Args:
        rl_algo: Algorithm or policy to evaluate.
        venv: Environment to evaluate on.
        n_episodes_eval: Number of episodes handed to
            `rollout.make_min_episodes` to build the stopping condition for
            trajectory collection.
        _rnd: Random number generator provided by Sacred; forwarded as `rng`
            to `rollout.generate_trajectories`.

    Returns:
        The value of `rollout.rollout_stats()` computed on the generated
        trajectories.
    """
    # Build the stopping condition first, before touching `rl_algo`'s env.
    enough_episodes = rollout.make_min_episodes(n_episodes_eval)

    if isinstance(rl_algo, base_class.BaseAlgorithm):
        # Point the algorithm at `venv`, which drops any leftover wrappers its
        # previous environment may have accumulated.
        rl_algo.set_env(venv)
        # Roll out on the algorithm's own view of the env: SB3 may wrap it
        # under the hood (e.g. transposing images so they can be fed to CNNs).
        rollout_env = rl_algo.get_env()
        assert rollout_env is not None
    else:
        # A bare policy carries no environment of its own; use `venv` as given.
        rollout_env = venv

    episodes = rollout.generate_trajectories(
        rl_algo,
        rollout_env,
        sample_until=enough_episodes,
        rng=_rnd,
    )
    return rollout.rollout_stats(episodes)