Learning a Reward Function using Preference Comparisons on Atari#

In this tutorial, we will use a convolutional neural network for our policy and reward model. We will also shape the learned reward model with the policy's learned value function, since these shaped rewards are more informative for training: they incentivize agents to move to high-value states. In the interest of execution time, we will do only a small amount of training, much less than in the previous preference comparison notebook. To run this notebook, be sure to install the atari extras, for example by running pip install imitation[atari].
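
Concretely, potential-based shaping adds gamma * V(s') - V(s) to the learned reward for each transition from state s to s', where V is the policy's value function and gamma is the discount factor. This leaves the optimal policy unchanged while making progress toward high-value states immediately rewarding.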

First, we will set up the environment, reward network, et cetera.

import torch as th
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import numpy as np

from seals.util import AutoResetWrapper

from stable_baselines3 import PPO
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.ppo import CnnPolicy

from imitation.algorithms import preference_comparisons
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.base import NormalizeFeaturesExtractor
from imitation.rewards.reward_nets import CnnRewardNet


device = th.device("cuda" if th.cuda.is_available() else "cpu")

rng = np.random.default_rng()


# Here we ensure that our environment has constant-length episodes by resetting
# it when done, and running until 100 timesteps have elapsed.
# For real training, you will want a much longer time limit.
def constant_length_asteroids(num_steps):
    atari_env = gym.make("AsteroidsNoFrameskip-v4")
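    # AtariWrapper applies the standard Atari preprocessing: by default a frame
    # skip of 4, grayscale 84x84 observations, and reward clipping.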
    preprocessed_env = AtariWrapper(atari_env)
    endless_env = AutoResetWrapper(preprocessed_env)
    limited_env = TimeLimit(endless_env, max_episode_steps=num_steps)
    return RolloutInfoWrapper(limited_env)


# For real training, you will want a vectorized environment with 8 environments
# running in parallel. This can be done by passing n_envs=8 to make_vec_env;
# see the sketch below.
# We set seed=1 for reproducibility, and to avoid a win32 np.random.randint
# high-bound error.
venv = make_vec_env(constant_length_asteroids, env_kwargs={"num_steps": 100}, seed=1)
venv = VecFrameStack(venv, n_stack=4)
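
# For real training, the parallel setup described in the comment above might
# look like the sketch below. The helper name and the longer time limit are
# illustrative assumptions, not part of the tutorial, and the function is not
# used elsewhere in this notebook.
def make_training_venv(n_envs=8, num_steps=10_000):
    training_venv = make_vec_env(
        constant_length_asteroids,
        env_kwargs={"num_steps": num_steps},
        seed=1,
        n_envs=n_envs,
    )
    return VecFrameStack(training_venv, n_stack=4)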

reward_net = CnnRewardNet(
    venv.observation_space,
    venv.action_space,
).to(device)

fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
preference_model = preference_comparisons.PreferenceModel(reward_net)
reward_trainer = preference_comparisons.BasicRewardTrainer(
    preference_model=preference_model,
    loss=preference_comparisons.CrossEntropyRewardLoss(),
    epochs=3,
    rng=rng,
)

agent = PPO(
    policy=CnnPolicy,
    env=venv,
    seed=0,
    n_steps=16,  # To train on atari well, set this to 128
    batch_size=16,  # To train on atari well, set this to 256
    ent_coef=0.01,
    learning_rate=0.00025,
    n_epochs=4,
)

trajectory_generator = preference_comparisons.AgentTrainer(
    algorithm=agent,
    reward_fn=reward_net,
    venv=venv,
    exploration_frac=0.0,
    rng=rng,
)

pref_comparisons = preference_comparisons.PreferenceComparisons(
    trajectory_generator,
    reward_net,
    num_iterations=2,
    fragmenter=fragmenter,
    preference_gatherer=gatherer,
    reward_trainer=reward_trainer,
    fragment_length=10,
    transition_oversampling=1,
    initial_comparison_frac=0.1,
    allow_variable_horizon=False,
    initial_epoch_multiplier=1,
)

We are now ready to train the reward model.

pref_comparisons.train(
    total_timesteps=16,
    total_comparisons=15,
)
Query schedule: [1, 9, 5]
Collecting 2 fragments (20 transitions)
Requested 20 transitions but only 0 in buffer. Sampling 20 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 1 comparisons
Training agent for 8 timesteps
---------------------------------------------------
| raw/                                 |          |
|    agent/rollout/ep_rew_wrapped_mean | 2.91     |
|    agent/time/fps                    | 154      |
|    agent/time/iterations             | 1        |
|    agent/time/time_elapsed           | 0        |
|    agent/time/total_timesteps        | 16       |
---------------------------------------------------
-----------------------------------------------------
| mean/                                  |          |
|    agent/rollout/ep_rew_wrapped_mean   | 2.91     |
|    agent/time/fps                      | 154      |
|    agent/time/iterations               | 1        |
|    agent/time/time_elapsed             | 0        |
|    agent/time/total_timesteps          | 16       |
|    agent/train/approx_kl               | 9.21e-05 |
|    agent/train/clip_fraction           | 0        |
|    agent/train/clip_range              | 0.2      |
|    agent/train/entropy_loss            | -2.64    |
|    agent/train/explained_variance      | -0.0826  |
|    agent/train/learning_rate           | 0.00025  |
|    agent/train/loss                    | -0.0353  |
|    agent/train/n_updates               | 4        |
|    agent/train/policy_gradient_loss    | -0.00515 |
|    agent/train/value_loss              | 0.00965  |
|    preferences/entropy                 | 0.693    |
|    reward/epoch-0/train/accuracy       | 1        |
|    reward/epoch-0/train/gt_reward_loss | 0.693    |
|    reward/epoch-0/train/loss           | 0.435    |
|    reward/epoch-1/train/accuracy       | 1        |
|    reward/epoch-1/train/gt_reward_loss | 0.693    |
|    reward/epoch-1/train/loss           | 0.393    |
|    reward/epoch-2/train/accuracy       | 1        |
|    reward/epoch-2/train/gt_reward_loss | 0.693    |
|    reward/epoch-2/train/loss           | 0.355    |
| reward/                                |          |
|    final/train/accuracy                | 1        |
|    final/train/gt_reward_loss          | 0.693    |
|    final/train/loss                    | 0.355    |
-----------------------------------------------------
Collecting 18 fragments (180 transitions)
Requested 180 transitions but only 0 in buffer. Sampling 180 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 10 comparisons
Training agent for 8 timesteps
------------------------------------------------------
| raw/                                 |             |
|    agent/rollout/ep_rew_wrapped_mean | 3.08        |
|    agent/time/fps                    | 157         |
|    agent/time/iterations             | 1           |
|    agent/time/time_elapsed           | 0           |
|    agent/time/total_timesteps        | 32          |
|    agent/train/approx_kl             | 9.20929e-05 |
|    agent/train/clip_fraction         | 0           |
|    agent/train/clip_range            | 0.2         |
|    agent/train/entropy_loss          | -2.64       |
|    agent/train/explained_variance    | -0.0826     |
|    agent/train/learning_rate         | 0.00025     |
|    agent/train/loss                  | -0.0353     |
|    agent/train/n_updates             | 4           |
|    agent/train/policy_gradient_loss  | -0.00515    |
|    agent/train/value_loss            | 0.00965     |
------------------------------------------------------
-----------------------------------------------------
| mean/                                  |          |
|    agent/rollout/ep_rew_wrapped_mean   | 3.08     |
|    agent/time/fps                      | 157      |
|    agent/time/iterations               | 1        |
|    agent/time/time_elapsed             | 0        |
|    agent/time/total_timesteps          | 32       |
|    agent/train/approx_kl               | 0.00014  |
|    agent/train/clip_fraction           | 0        |
|    agent/train/clip_range              | 0.2      |
|    agent/train/entropy_loss            | -2.64    |
|    agent/train/explained_variance      | 0.0552   |
|    agent/train/learning_rate           | 0.00025  |
|    agent/train/loss                    | -0.0351  |
|    agent/train/n_updates               | 8        |
|    agent/train/policy_gradient_loss    | -0.00595 |
|    agent/train/value_loss              | 0.0197   |
|    preferences/entropy                 | 0.681    |
|    reward/epoch-0/train/accuracy       | 0.4      |
|    reward/epoch-0/train/gt_reward_loss | 0.655    |
|    reward/epoch-0/train/loss           | 0.776    |
|    reward/epoch-1/train/accuracy       | 0.4      |
|    reward/epoch-1/train/gt_reward_loss | 0.655    |
|    reward/epoch-1/train/loss           | 0.764    |
|    reward/epoch-2/train/accuracy       | 0.4      |
|    reward/epoch-2/train/gt_reward_loss | 0.655    |
|    reward/epoch-2/train/loss           | 0.749    |
| reward/                                |          |
|    final/train/accuracy                | 0.4      |
|    final/train/gt_reward_loss          | 0.655    |
|    final/train/loss                    | 0.749    |
-----------------------------------------------------
Collecting 10 fragments (100 transitions)
Requested 100 transitions but only 0 in buffer. Sampling 100 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 15 comparisons
Training agent for 8 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_rew_wrapped_mean | 2.94         |
|    agent/time/fps                    | 163          |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 48           |
|    agent/train/approx_kl             | 0.0001396425 |
|    agent/train/clip_fraction         | 0            |
|    agent/train/clip_range            | 0.2          |
|    agent/train/entropy_loss          | -2.64        |
|    agent/train/explained_variance    | 0.0552       |
|    agent/train/learning_rate         | 0.00025      |
|    agent/train/loss                  | -0.0351      |
|    agent/train/n_updates             | 8            |
|    agent/train/policy_gradient_loss  | -0.00595     |
|    agent/train/value_loss            | 0.0197       |
-------------------------------------------------------
-----------------------------------------------------
| mean/                                  |          |
|    agent/rollout/ep_rew_wrapped_mean   | 2.94     |
|    agent/time/fps                      | 163      |
|    agent/time/iterations               | 1        |
|    agent/time/time_elapsed             | 0        |
|    agent/time/total_timesteps          | 48       |
|    agent/train/approx_kl               | 0.000128 |
|    agent/train/clip_fraction           | 0        |
|    agent/train/clip_range              | 0.2      |
|    agent/train/entropy_loss            | -2.64    |
|    agent/train/explained_variance      | -1.45    |
|    agent/train/learning_rate           | 0.00025  |
|    agent/train/loss                    | -0.0216  |
|    agent/train/n_updates               | 12       |
|    agent/train/policy_gradient_loss    | -0.00565 |
|    agent/train/value_loss              | 0.0346   |
|    preferences/entropy                 | 0.693    |
|    reward/epoch-0/train/accuracy       | 0.4      |
|    reward/epoch-0/train/gt_reward_loss | 0.668    |
|    reward/epoch-0/train/loss           | 0.719    |
|    reward/epoch-1/train/accuracy       | 0.467    |
|    reward/epoch-1/train/gt_reward_loss | 0.668    |
|    reward/epoch-1/train/loss           | 0.709    |
|    reward/epoch-2/train/accuracy       | 0.467    |
|    reward/epoch-2/train/gt_reward_loss | 0.668    |
|    reward/epoch-2/train/loss           | 0.698    |
| reward/                                |          |
|    final/train/accuracy                | 0.467    |
|    final/train/gt_reward_loss          | 0.668    |
|    final/train/loss                    | 0.698    |
-----------------------------------------------------
{'reward_loss': 0.6978113651275635, 'reward_accuracy': 0.46666666865348816}

We can now wrap the environment with the learned reward model, shaped by the policy's learned value function. Note that if we were training this for real, we would want to normalize the output of the reward net as well as the value function, to ensure their values are on the same scale. To do this, wrap reward_net in the NormalizedRewardNet class from src/imitation/rewards/reward_nets.py, and add a RunningNorm module from src/imitation/util/networks.py to the potential; a sketch of this is shown after the shaping code below.

from imitation.rewards.reward_nets import ShapedRewardNet, cnn_transpose
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper


def value_potential(state):
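    # cnn_transpose moves the channel dimension from last to first,
    # (batch, H, W, C) -> (batch, C, H, W), so the observations match the
    # layout expected by the policy's CNN.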
    state_ = cnn_transpose(state)
    return agent.policy.predict_values(state_)


shaped_reward_net = ShapedRewardNet(
    base=reward_net,
    potential=value_potential,
    discount_factor=0.99,
)

# GOTCHA: When using the NormalizedRewardNet wrapper, you should deactivate updating
# during evaluation by passing update_stats=False to the predict_processed method.
learned_reward_venv = RewardVecEnvWrapper(venv, shaped_reward_net.predict_processed)
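
If we were normalizing as described above, a minimal sketch might look like the following. The constructor arguments shown here (normalize_output_layer for NormalizedRewardNet, a feature count for RunningNorm) reflect the imitation API, but treat this as an illustration rather than a drop-in replacement for the cell above.

from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.util.networks import RunningNorm

# Standardize the reward net's output with running statistics.
normalized_reward_net = NormalizedRewardNet(
    reward_net, normalize_output_layer=RunningNorm
)

# Standardize the scalar value potential with its own RunningNorm module.
potential_norm = RunningNorm(1).to(device)


def normalized_value_potential(state):
    state_ = cnn_transpose(state)
    return potential_norm(agent.policy.predict_values(state_))


normalized_shaped_reward_net = ShapedRewardNet(
    base=normalized_reward_net,
    potential=normalized_value_potential,
    discount_factor=0.99,
)
# Per the GOTCHA above, pass update_stats=False to the normalized net's
# predict_processed at evaluation time so evaluation rollouts do not shift
# the running statistics.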

Next, we train an agent that sees only the shaped, learned reward.

learner = PPO(
    policy=CnnPolicy,
    env=learned_reward_venv,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
learner.learn(1000)
<stable_baselines3.ppo.ppo.PPO at 0x7f38d5f04d00>

We now evaluate the learner using the original reward.

from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(learner.policy, venv, 10)
print(reward)
0.4

Generating rollouts#

When generating rollouts in image environments, be sure to use the agent's get_env() function rather than the original environment.

The learner re-arranges the observation space to put the channel dimension first, and get_env() provides a correctly wrapped environment that does this.

from imitation.data import rollout

rollouts = rollout.rollout(
    learner,
    # Note that passing venv instead of learner.get_env()
    # here would fail.
    learner.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=3),
    rng=rng,
)
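
As an optional sanity check (not part of the original tutorial), imitation's rollout helpers can summarize the collected trajectories:

# Optional: print summary statistics (episode returns and lengths) for the rollouts.
stats = rollout.rollout_stats(rollouts)
print(stats)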