Download this notebook here.

Learning a Reward Function using Preference Comparisons

The preference comparisons algorithm learns a reward function by comparing trajectory segments to each other.

To set up the preference comparisons algorithm, we first need to set up a number of its internal components:

import random
from imitation.algorithms import preference_comparisons
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor
import gym
from stable_baselines3 import PPO
import numpy as np

# One seeded NumPy generator is shared by every component below that takes an
# `rng` argument, so the run is reproducible from seed 0. Construction order
# matters: each constructor that receives `rng` draws from the same stream.
rng = np.random.default_rng(0)

# Vectorized Pendulum-v1 environment. The number of parallel copies is chosen
# by make_vec_env and read back below via `venv.num_envs`.
venv = make_vec_env("Pendulum-v1", rng=rng)

# The reward model to be learned from preferences. Its inputs pass through a
# RunningNorm layer (presumably normalizing with running statistics — confirm).
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)

# Splits collected trajectories into fragments that are then paired up for
# comparison (see "Creating fragment pairs" in the training log).
# NOTE(review): warning_threshold=0 presumably disables the fragmenter's
# warning — confirm against the imitation API docs.
fragmenter = preference_comparisons.RandomFragmenter(
    warning_threshold=0,
    rng=rng,
)
# Produces preference labels for fragment pairs. "Synthetic" presumably means
# the labels are derived from the environment's ground-truth reward instead of
# a human — TODO confirm.
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
# Wraps reward_net so it can be used to model which fragment of a pair is
# preferred (assumed from the class name — confirm).
preference_model = preference_comparisons.PreferenceModel(reward_net)
# Fits the reward net on the gathered comparisons with a cross-entropy loss.
# It trains for 3 epochs per round, matching the epoch-0..epoch-2 entries in
# the log.
reward_trainer = preference_comparisons.BasicRewardTrainer(
    preference_model=preference_model,
    loss=preference_comparisons.CrossEntropyRewardLoss(),
    epochs=3,
    rng=rng,
)

# PPO agent whose rollouts supply the trajectories to compare. Policy features
# are normalized with RunningNorm. n_steps is per environment, so dividing by
# venv.num_envs keeps each rollout at about 2048 transitions in total.
# seed=0 also seeds this agent.
agent = PPO(
    policy=FeedForward32Policy,
    policy_kwargs=dict(
        features_extractor_class=NormalizeFeaturesExtractor,
        features_extractor_kwargs=dict(normalize_class=RunningNorm),
    ),
    env=venv,
    seed=0,
    n_steps=2048 // venv.num_envs,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
)

# Trains `agent` against the reward given by `reward_net` (reward_fn) and
# yields the trajectories that get fragmented.
# NOTE(review): exploration_frac=0.0 presumably means no extra
# random-exploration trajectories are mixed in — confirm.
trajectory_generator = preference_comparisons.AgentTrainer(
    algorithm=agent,
    reward_fn=reward_net,
    venv=venv,
    exploration_frac=0.0,
    rng=rng,
)

# Top-level algorithm. It alternates between gathering comparisons, training
# the reward net, and training the agent. With num_iterations=5 the log shows
# 6 query rounds: an initial one (initial_comparison_frac=0.1 -> 20 of 200
# comparisons) plus one per iteration. Each fragment is 100 transitions
# (40 fragments -> 4000 transitions in the log), and transition_oversampling=1
# requests exactly that many.
# NOTE(review): allow_variable_horizon=False presumably rejects episodes of
# varying length, and initial_epoch_multiplier=1 presumably leaves the first
# reward-training round at the normal epoch count — confirm both.
pref_comparisons = preference_comparisons.PreferenceComparisons(
    trajectory_generator,
    reward_net,
    num_iterations=5,
    fragmenter=fragmenter,
    preference_gatherer=gatherer,
    reward_trainer=reward_trainer,
    fragment_length=100,
    transition_oversampling=1,
    initial_comparison_frac=0.1,
    allow_variable_horizon=False,
    initial_epoch_multiplier=1,
)

Then we can start training the reward model. Note that we need to specify the total number of timesteps for which the agent should be trained and how many fragment comparisons should be made.

# Deliberately small budgets so the demo finishes quickly. For good
# performance, use total_timesteps=1_000_000 and total_comparisons=5_000.
# Kept as the cell's final expression so the notebook displays the returned
# metrics dict.
pref_comparisons.train(
    total_comparisons=200,
    total_timesteps=5_000,
)
Query schedule: [20, 51, 41, 34, 29, 25]
Collecting 40 fragments (4000 transitions)
Requested 4000 transitions but only 0 in buffer. Sampling 4000 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 20 comparisons
Training agent for 1000 timesteps
----------------------------------------------------
| raw/                                 |           |
|    agent/rollout/ep_len_mean         | 200       |
|    agent/rollout/ep_rew_mean         | -1.32e+03 |
|    agent/rollout/ep_rew_wrapped_mean | 70.9      |
|    agent/time/fps                    | 4965      |
|    agent/time/iterations             | 1         |
|    agent/time/time_elapsed           | 0         |
|    agent/time/total_timesteps        | 2048      |
----------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.32e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 70.9      |
|    agent/time/fps                      | 4.96e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 2.05e+03  |
|    agent/train/approx_kl               | 0.00522   |
|    agent/train/clip_fraction           | 0.033     |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.42     |
|    agent/train/explained_variance      | -0.0565   |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 0.538     |
|    agent/train/n_updates               | 10        |
|    agent/train/policy_gradient_loss    | -0.00434  |
|    agent/train/std                     | 1         |
|    agent/train/value_loss              | 6.93      |
|    preferences/entropy                 | 0.00589   |
|    reward/epoch-0/train/accuracy       | 0.7       |
|    reward/epoch-0/train/gt_reward_loss | 0.00125   |
|    reward/epoch-0/train/loss           | 0.663     |
|    reward/epoch-1/train/accuracy       | 0.75      |
|    reward/epoch-1/train/gt_reward_loss | 0.00125   |
|    reward/epoch-1/train/loss           | 0.551     |
|    reward/epoch-2/train/accuracy       | 0.9       |
|    reward/epoch-2/train/gt_reward_loss | 0.00125   |
|    reward/epoch-2/train/loss           | 0.416     |
| reward/                                |           |
|    final/train/accuracy                | 0.9       |
|    final/train/gt_reward_loss          | 0.00125   |
|    final/train/loss                    | 0.416     |
------------------------------------------------------
Collecting 102 fragments (10200 transitions)
Requested 10200 transitions but only 1600 in buffer. Sampling 8600 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 71 comparisons
Training agent for 1000 timesteps
-----------------------------------------------------
| raw/                                 |            |
|    agent/rollout/ep_len_mean         | 200        |
|    agent/rollout/ep_rew_mean         | -1.32e+03  |
|    agent/rollout/ep_rew_wrapped_mean | 54.5       |
|    agent/time/fps                    | 4935       |
|    agent/time/iterations             | 1          |
|    agent/time/time_elapsed           | 0          |
|    agent/time/total_timesteps        | 4096       |
|    agent/train/approx_kl             | 0.00522125 |
|    agent/train/clip_fraction         | 0.033      |
|    agent/train/clip_range            | 0.2        |
|    agent/train/entropy_loss          | -1.42      |
|    agent/train/explained_variance    | -0.0565    |
|    agent/train/learning_rate         | 0.0003     |
|    agent/train/loss                  | 0.538      |
|    agent/train/n_updates             | 10         |
|    agent/train/policy_gradient_loss  | -0.00434   |
|    agent/train/std                   | 1          |
|    agent/train/value_loss            | 6.93       |
-----------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.32e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 54.5      |
|    agent/time/fps                      | 4.94e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 4.1e+03   |
|    agent/train/approx_kl               | 0.00491   |
|    agent/train/clip_fraction           | 0.0267    |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.42     |
|    agent/train/explained_variance      | -0.0464   |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 0.743     |
|    agent/train/n_updates               | 20        |
|    agent/train/policy_gradient_loss    | -0.00231  |
|    agent/train/std                     | 1         |
|    agent/train/value_loss              | 2.25      |
|    preferences/entropy                 | 0.022     |
|    reward/epoch-0/train/accuracy       | 0.811     |
|    reward/epoch-0/train/gt_reward_loss | 0.00811   |
|    reward/epoch-0/train/loss           | 0.339     |
|    reward/epoch-1/train/accuracy       | 0.9       |
|    reward/epoch-1/train/gt_reward_loss | 0.0081    |
|    reward/epoch-1/train/loss           | 0.194     |
|    reward/epoch-2/train/accuracy       | 0.948     |
|    reward/epoch-2/train/gt_reward_loss | 0.0081    |
|    reward/epoch-2/train/loss           | 0.102     |
| reward/                                |           |
|    final/train/accuracy                | 0.948     |
|    final/train/gt_reward_loss          | 0.0081    |
|    final/train/loss                    | 0.102     |
------------------------------------------------------
Collecting 82 fragments (8200 transitions)
Requested 8200 transitions but only 1600 in buffer. Sampling 6600 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 112 comparisons
Training agent for 1000 timesteps
------------------------------------------------------
| raw/                                 |             |
|    agent/rollout/ep_len_mean         | 200         |
|    agent/rollout/ep_rew_mean         | -1.28e+03   |
|    agent/rollout/ep_rew_wrapped_mean | 46.4        |
|    agent/time/fps                    | 4954        |
|    agent/time/iterations             | 1           |
|    agent/time/time_elapsed           | 0           |
|    agent/time/total_timesteps        | 6144        |
|    agent/train/approx_kl             | 0.004912718 |
|    agent/train/clip_fraction         | 0.0267      |
|    agent/train/clip_range            | 0.2         |
|    agent/train/entropy_loss          | -1.42       |
|    agent/train/explained_variance    | -0.0464     |
|    agent/train/learning_rate         | 0.0003      |
|    agent/train/loss                  | 0.743       |
|    agent/train/n_updates             | 20          |
|    agent/train/policy_gradient_loss  | -0.00231    |
|    agent/train/std                   | 1           |
|    agent/train/value_loss            | 2.25        |
------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.28e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 46.4      |
|    agent/time/fps                      | 4.95e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 6.14e+03  |
|    agent/train/approx_kl               | 0.0017    |
|    agent/train/clip_fraction           | 0.00176   |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.42     |
|    agent/train/explained_variance      | 0.214     |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 1.94      |
|    agent/train/n_updates               | 30        |
|    agent/train/policy_gradient_loss    | -0.000454 |
|    agent/train/std                     | 0.991     |
|    agent/train/value_loss              | 3.72      |
|    preferences/entropy                 | 1.11e-06  |
|    reward/epoch-0/train/accuracy       | 0.969     |
|    reward/epoch-0/train/gt_reward_loss | 0.00607   |
|    reward/epoch-0/train/loss           | 0.105     |
|    reward/epoch-1/train/accuracy       | 0.969     |
|    reward/epoch-1/train/gt_reward_loss | 0.0106    |
|    reward/epoch-1/train/loss           | 0.0978    |
|    reward/epoch-2/train/accuracy       | 0.977     |
|    reward/epoch-2/train/gt_reward_loss | 0.00607   |
|    reward/epoch-2/train/loss           | 0.0845    |
| reward/                                |           |
|    final/train/accuracy                | 0.977     |
|    final/train/gt_reward_loss          | 0.00607   |
|    final/train/loss                    | 0.0845    |
------------------------------------------------------
Collecting 68 fragments (6800 transitions)
Requested 6800 transitions but only 1600 in buffer. Sampling 5200 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 146 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.29e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 45           |
|    agent/time/fps                    | 4998         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 8192         |
|    agent/train/approx_kl             | 0.0016984465 |
|    agent/train/clip_fraction         | 0.00176      |
|    agent/train/clip_range            | 0.2          |
|    agent/train/entropy_loss          | -1.42        |
|    agent/train/explained_variance    | 0.214        |
|    agent/train/learning_rate         | 0.0003       |
|    agent/train/loss                  | 1.94         |
|    agent/train/n_updates             | 30           |
|    agent/train/policy_gradient_loss  | -0.000454    |
|    agent/train/std                   | 0.991        |
|    agent/train/value_loss            | 3.72         |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.29e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 45        |
|    agent/time/fps                      | 5e+03     |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 8.19e+03  |
|    agent/train/approx_kl               | 0.00147   |
|    agent/train/clip_fraction           | 0.00308   |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.4      |
|    agent/train/explained_variance      | 0.267     |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 2.95      |
|    agent/train/n_updates               | 40        |
|    agent/train/policy_gradient_loss    | -0.000462 |
|    agent/train/std                     | 0.973     |
|    agent/train/value_loss              | 4.44      |
|    preferences/entropy                 | 0.000701  |
|    reward/epoch-0/train/accuracy       | 0.975     |
|    reward/epoch-0/train/gt_reward_loss | 0.00488   |
|    reward/epoch-0/train/loss           | 0.0804    |
|    reward/epoch-1/train/accuracy       | 0.97      |
|    reward/epoch-1/train/gt_reward_loss | 0.0077    |
|    reward/epoch-1/train/loss           | 0.0931    |
|    reward/epoch-2/train/accuracy       | 0.975     |
|    reward/epoch-2/train/gt_reward_loss | 0.00488   |
|    reward/epoch-2/train/loss           | 0.0702    |
| reward/                                |           |
|    final/train/accuracy                | 0.975     |
|    final/train/gt_reward_loss          | 0.00488   |
|    final/train/loss                    | 0.0702    |
------------------------------------------------------
Collecting 58 fragments (5800 transitions)
Requested 5800 transitions but only 1600 in buffer. Sampling 4200 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 175 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.26e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 46.1         |
|    agent/time/fps                    | 4920         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 10240        |
|    agent/train/approx_kl             | 0.0014707824 |
|    agent/train/clip_fraction         | 0.00308      |
|    agent/train/clip_range            | 0.2          |
|    agent/train/entropy_loss          | -1.4         |
|    agent/train/explained_variance    | 0.267        |
|    agent/train/learning_rate         | 0.0003       |
|    agent/train/loss                  | 2.95         |
|    agent/train/n_updates             | 40           |
|    agent/train/policy_gradient_loss  | -0.000462    |
|    agent/train/std                   | 0.973        |
|    agent/train/value_loss            | 4.44         |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.26e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 46.1      |
|    agent/time/fps                      | 4.92e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 1.02e+04  |
|    agent/train/approx_kl               | 0.00458   |
|    agent/train/clip_fraction           | 0.0319    |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.4      |
|    agent/train/explained_variance      | 0.265     |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 1.72      |
|    agent/train/n_updates               | 50        |
|    agent/train/policy_gradient_loss    | -0.00411  |
|    agent/train/std                     | 0.982     |
|    agent/train/value_loss              | 7.02      |
|    preferences/entropy                 | 0.00086   |
|    reward/epoch-0/train/accuracy       | 0.974     |
|    reward/epoch-0/train/gt_reward_loss | 0.00409   |
|    reward/epoch-0/train/loss           | 0.103     |
|    reward/epoch-1/train/accuracy       | 0.969     |
|    reward/epoch-1/train/gt_reward_loss | 0.00409   |
|    reward/epoch-1/train/loss           | 0.0948    |
|    reward/epoch-2/train/accuracy       | 0.963     |
|    reward/epoch-2/train/gt_reward_loss | 0.0051    |
|    reward/epoch-2/train/loss           | 0.106     |
| reward/                                |           |
|    final/train/accuracy                | 0.963     |
|    final/train/gt_reward_loss          | 0.0051    |
|    final/train/loss                    | 0.106     |
------------------------------------------------------
Collecting 50 fragments (5000 transitions)
Requested 5000 transitions but only 1600 in buffer. Sampling 3400 additional transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 200 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.26e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 50.8         |
|    agent/time/fps                    | 4976         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 12288        |
|    agent/train/approx_kl             | 0.0045790263 |
|    agent/train/clip_fraction         | 0.0319       |
|    agent/train/clip_range            | 0.2          |
|    agent/train/entropy_loss          | -1.4         |
|    agent/train/explained_variance    | 0.265        |
|    agent/train/learning_rate         | 0.0003       |
|    agent/train/loss                  | 1.72         |
|    agent/train/n_updates             | 50           |
|    agent/train/policy_gradient_loss  | -0.00411     |
|    agent/train/std                   | 0.982        |
|    agent/train/value_loss            | 7.02         |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.26e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 50.8      |
|    agent/time/fps                      | 4.98e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 1.23e+04  |
|    agent/train/approx_kl               | 0.0026    |
|    agent/train/clip_fraction           | 0.014     |
|    agent/train/clip_range              | 0.2       |
|    agent/train/entropy_loss            | -1.4      |
|    agent/train/explained_variance      | 0.37      |
|    agent/train/learning_rate           | 0.0003    |
|    agent/train/loss                    | 3.02      |
|    agent/train/n_updates               | 60        |
|    agent/train/policy_gradient_loss    | -0.0024   |
|    agent/train/std                     | 0.974     |
|    agent/train/value_loss              | 7.18      |
|    preferences/entropy                 | 0.00229   |
|    reward/epoch-0/train/accuracy       | 0.969     |
|    reward/epoch-0/train/gt_reward_loss | 0.00355   |
|    reward/epoch-0/train/loss           | 0.0883    |
|    reward/epoch-1/train/accuracy       | 0.969     |
|    reward/epoch-1/train/gt_reward_loss | 0.00355   |
|    reward/epoch-1/train/loss           | 0.084     |
|    reward/epoch-2/train/accuracy       | 0.955     |
|    reward/epoch-2/train/gt_reward_loss | 0.0059    |
|    reward/epoch-2/train/loss           | 0.0948    |
| reward/                                |           |
|    final/train/accuracy                | 0.955     |
|    final/train/gt_reward_loss          | 0.0059    |
|    final/train/loss                    | 0.0948    |
------------------------------------------------------
{'reward_loss': 0.0947954399245126, 'reward_accuracy': 0.9553571428571429}

After training the reward network with the preference comparisons algorithm, we can wrap our environment with the learned reward.

from imitation.rewards.reward_wrapper import RewardVecEnvWrapper


# Wrap `venv` so that the reward an agent receives is the learned reward
# model's processed prediction rather than the environment's own reward.
learned_reward_fn = reward_net.predict_processed
learned_reward_venv = RewardVecEnvWrapper(venv, learned_reward_fn)

Now we can train an agent that only sees the learned reward.

from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# A fresh PPO learner trained only on the learned reward (via
# learned_reward_venv); it never sees the environment's true reward.
learner = PPO(
    policy=MlpPolicy,
    env=learned_reward_venv,
    # Rollout / optimisation schedule.
    n_steps=64,
    batch_size=64,
    n_epochs=10,
    # Optimiser and loss settings.
    learning_rate=3e-4,
    ent_coef=0.0,
    seed=0,
)
# Short demo run; set total_timesteps to 100000 to train a proficient expert.
learner.learn(total_timesteps=1000)
<stable_baselines3.ppo.ppo.PPO at 0x7f3e449969d0>

Then we can evaluate it using the original reward.

from stable_baselines3.common.evaluation import evaluate_policy

# Score the learner on the unwrapped `venv`, i.e. with the environment's true
# reward. evaluate_policy returns (mean, std) of the episode returns; only the
# mean is kept.
reward, _ = evaluate_policy(
    learner.policy,
    venv,
    n_eval_episodes=10,
)
print(reward)
-1015.7399708