imitation.rewards.reward_wrapper#

Common wrapper for adding custom reward values to an environment.

Classes

RewardVecEnvWrapper(venv, reward_fn[, ...])

Uses a provided reward_fn to replace the reward function returned by step().

WrappedRewardCallback(episode_rewards, ...)

Logs mean wrapped reward as part of RL (or other) training.

class imitation.rewards.reward_wrapper.RewardVecEnvWrapper(venv, reward_fn, ep_history=100)[source]#

Bases: VecEnvWrapper

Uses a provided reward_fn to replace the reward function returned by step().

Automatically resets the inner VecEnv upon initialization. A tricky part about this class is keeping track of the most recent observation from each environment.

Will also include the previous reward given by the inner VecEnv in the returned info dict under the original_env_rew key.

__init__(venv, reward_fn, ep_history=100)[source]#

Builds RewardVecEnvWrapper.

Parameters
  • venv (VecEnv) – The VecEnv to wrap.

  • reward_fn (RewardFn) – A function that takes in vectorized transitions (obs, act, next_obs) and a vector of episode timesteps, and returns a vector of rewards.

  • ep_history (int) – The number of episode rewards to retain for computing mean reward.
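
Example (a minimal sketch, not from the library docs): wrapping a single-environment DummyVecEnv with a constant reward function. The reward_fn signature below assumes imitation's RewardFn type, where the fourth argument is the batch of done flags; gymnasium and CartPole-v1 are used purely for illustration.

import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

def constant_reward_fn(obs, acts, next_obs, dones):
    # One reward per environment in the batch; a real reward_fn would
    # typically come from a learned reward model.
    return np.ones(len(obs), dtype=np.float32)

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
wrapped_venv = RewardVecEnvWrapper(venv, reward_fn=constant_reward_fn)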

property envs#
make_log_callback()[source]#

Creates WrappedRewardCallback connected to this RewardVecEnvWrapper.

Return type

WrappedRewardCallback
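
A usage sketch: the callback returned here can be passed to a Stable Baselines3 learn() call so that the mean wrapped reward is logged during training. PPO and the wrapped_venv from the example above are assumptions for illustration.

from stable_baselines3 import PPO

model = PPO("MlpPolicy", wrapped_venv)
# Log the running mean of the wrapped (replaced) rewards during training.
model.learn(total_timesteps=10_000, callback=wrapped_venv.make_log_callback())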

reset()[source]#

Reset all the environments and return an array of observations, or a tuple of observation arrays.

If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again.

Returns

observation

step_async(actions)[source]#

Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step.

You should not call this if a step_async run is already pending.

step_wait()[source]#

Wait for the step taken with step_async().

Returns

observation, reward, done, information
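
Illustrative sketch of the step_async/step_wait pair on the wrapper, assuming the wrapped_venv built above: the returned rewards come from reward_fn, while the inner environment's reward remains available under the original_env_rew key in each info dict.

import numpy as np

obs = wrapped_venv.reset()
actions = np.array([wrapped_venv.action_space.sample()])
wrapped_venv.step_async(actions)
obs, rewards, dones, infos = wrapped_venv.step_wait()
print(rewards[0])                    # reward computed by reward_fn
print(infos[0]["original_env_rew"])  # reward from the inner VecEnv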

class imitation.rewards.reward_wrapper.WrappedRewardCallback(episode_rewards, *args, **kwargs)[source]#

Bases: BaseCallback

Logs mean wrapped reward as part of RL (or other) training.

__init__(episode_rewards, *args, **kwargs)[source]#

Builds WrappedRewardCallback.

Parameters
  • episode_rewards (Deque[float]) – A queue that episode rewards will be placed into.

  • *args – Passed through to callbacks.BaseCallback.

  • **kwargs – Passed through to callbacks.BaseCallback.
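
In normal use this callback is created by RewardVecEnvWrapper.make_log_callback(), which passes the wrapper's internal episode-reward queue so that the two share state. The direct construction below is only a sketch of that wiring; the deque contents are made up for illustration.

from collections import deque

from imitation.rewards.reward_wrapper import WrappedRewardCallback

episode_rewards = deque([1.0, 2.0, 3.0], maxlen=100)
# Logs the mean of whatever values are currently in the queue.
callback = WrappedRewardCallback(episode_rewards)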

model: base_class.BaseAlgorithm#