
Learning a Reward Function using Preference Comparisons on Atari

In this case, we will use a convolutional neural network for our policy and reward model. We will also shape the learned reward model with the policy's learned value function, since these shaped rewards are more informative for training - they incentivize the agent to move to high-value states. In the interest of execution time, we will only do a small amount of training - much less than in the previous preference comparison notebook. To run this notebook, be sure to install the atari extras, for example by running pip install imitation[atari].
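
Concretely, potential-based shaping adds the discounted difference in value between successive states to the learned reward:

$$\tilde{r}(s, a, s') = \hat{r}(s, a, s') + \gamma V(s') - V(s)$$

where $\hat{r}$ is the learned reward, $V$ is the policy's value function, and $\gamma$ is the discount factor. Shaping of this form leaves the optimal policy unchanged while making transitions into high-value states more rewarding.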

First, we will set up the environment, reward network, et cetera.

import torch as th
import gym
from gym.wrappers import TimeLimit
import numpy as np

from seals.util import AutoResetWrapper

from stable_baselines3 import PPO
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.ppo import CnnPolicy

from imitation.algorithms import preference_comparisons
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.base import NormalizeFeaturesExtractor
from imitation.rewards.reward_nets import CnnRewardNet


device = th.device("cuda" if th.cuda.is_available() else "cpu")

rng = np.random.default_rng()

# Here we ensure that our environment has constant-length episodes by resetting
# it when done, and running until 100 timesteps have elapsed.
# For real training, you will want a much longer time limit.
def constant_length_asteroids(num_steps):
    atari_env = gym.make("AsteroidsNoFrameskip-v4")
    preprocessed_env = AtariWrapper(atari_env)
    endless_env = AutoResetWrapper(preprocessed_env)
    limited_env = TimeLimit(endless_env, max_episode_steps=num_steps)
    return RolloutInfoWrapper(limited_env)


# For real training, you will want a vectorized environment with 8 environments in parallel.
# This can be done by passing in n_envs=8 as an argument to make_vec_env.
venv = make_vec_env(constant_length_asteroids, env_kwargs={"num_steps": 100})
venv = VecFrameStack(venv, n_stack=4)

reward_net = CnnRewardNet(
    venv.observation_space,
    venv.action_space,
).to(device)

fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
preference_model = preference_comparisons.PreferenceModel(reward_net)
reward_trainer = preference_comparisons.BasicRewardTrainer(
    preference_model=preference_model,
    loss=preference_comparisons.CrossEntropyRewardLoss(),
    epochs=3,
    rng=rng,
)

agent = PPO(
    policy=CnnPolicy,
    env=venv,
    seed=0,
    n_steps=16,  # To train on atari well, set this to 128
    batch_size=16,  # To train on atari well, set this to 256
    ent_coef=0.01,
    learning_rate=0.00025,
    n_epochs=4,
)

trajectory_generator = preference_comparisons.AgentTrainer(
    algorithm=agent,
    reward_fn=reward_net,
    venv=venv,
    exploration_frac=0.0,
    rng=rng,
)

pref_comparisons = preference_comparisons.PreferenceComparisons(
    trajectory_generator,
    reward_net,
    num_iterations=2,
    fragmenter=fragmenter,
    preference_gatherer=gatherer,
    reward_trainer=reward_trainer,
    fragment_length=10,
    transition_oversampling=1,
    initial_comparison_frac=0.1,
    allow_variable_horizon=False,
    initial_epoch_multiplier=1,
)

We are now ready to train the reward model.

pref_comparisons.train(
    total_timesteps=16,
    total_comparisons=15,
)

We can now wrap the environment with the learned reward model, shaped by the policy’s learned value function. Note that if we were training this for real, we would want to normalize the output of the reward net as well as the value function, to ensure their values are on the same scale. To do this, use the NormalizedRewardNet class from src/imitation/rewards/reward_nets.py on reward_net, and modify the potential to add a RunningNorm module from src/imitation/util/networks.py.
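
As a rough sketch (not run in this notebook), that normalization could look like the following; placing a RunningNorm around the potential's output is one reasonable arrangement, not the only one:

from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.util.networks import RunningNorm

# Wrap the learned reward so its output is normalized with running statistics.
normalized_reward_net = NormalizedRewardNet(reward_net, RunningNorm)

# A RunningNorm module that could likewise normalize the scalar value estimates
# returned by the potential.
potential_norm = RunningNorm(1).to(device)

You would then pass normalized_reward_net as the base of the ShapedRewardNet below and apply potential_norm to the output of value_potential.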

from imitation.rewards.reward_nets import ShapedRewardNet, cnn_transpose
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper


def value_potential(state):
    state_ = cnn_transpose(state)
    return agent.policy.predict_values(state_)


shaped_reward_net = ShapedRewardNet(
    base=reward_net,
    potential=value_potential,
    discount_factor=0.99,
)

# GOTCHA: When using the NormalizedRewardNet wrapper, you should deactivate updating
# during evaluation by passing update_stats=False to the predict_processed method.
learned_reward_venv = RewardVecEnvWrapper(venv, shaped_reward_net.predict_processed)
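
If you do wrap the reward net in NormalizedRewardNet as sketched above, one way to apply this gotcha is to bind update_stats=False with functools.partial before handing the reward function to the wrapper (again just a sketch, using the hypothetical normalized_reward_net from above rather than the full shaped net):

import functools

# Freeze the running statistics during evaluation by fixing update_stats=False.
eval_reward_fn = functools.partial(
    normalized_reward_net.predict_processed, update_stats=False
)
eval_venv = RewardVecEnvWrapper(venv, eval_reward_fn)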

Next, we train an agent that sees only the shaped, learned reward.

learner = PPO(
    policy=CnnPolicy,
    env=learned_reward_venv,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
learner.learn(1000)

We now evaluate the learner using the original reward.

from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(learner.policy, venv, 10)
print(reward)

Generating rollouts

When generating rollouts in image environments, be sure to use the agent’s get_env() function rather than using the original environment.

The learner re-arranges the observation space to put the channel dimension first, and get_env() correctly provides a wrapped environment that applies this transposition.
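
For example, you can see the difference by printing the two observation spaces; the shapes shown assume the standard 84x84 Atari preprocessing with four stacked frames used above:

# The raw venv stacks frames channel-last; SB3 transposes images to channel-first.
print(venv.observation_space.shape)  # (84, 84, 4)
print(learner.get_env().observation_space.shape)  # (4, 84, 84)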

from imitation.data import rollout

rollouts = rollout.rollout(
    learner,
    # Note that passing venv instead of learner.get_env()
    # here would fail.
    learner.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=3),
    rng=rng,
)