
Learning a Reward Function using Preference Comparisons

The preference comparisons algorithm learns a reward function by comparing trajectory segments to each other.
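Under the hood, preferences between two fragments are modeled with a Bradley-Terry-style probability: the fragment with the larger sum of predicted rewards is preferred with higher probability. The snippet below is a minimal, self-contained sketch of that probability (not imitation's actual implementation, which handles additional details); the per-step reward lists are made up purely for illustration.

import numpy as np

def preference_probability(rewards_1, rewards_2):
    # Bradley-Terry-style model: the probability that fragment 1 is preferred
    # is a softmax over the two fragments' summed predicted rewards.
    returns = np.array([np.sum(rewards_1), np.sum(rewards_2)])
    exps = np.exp(returns - returns.max())  # subtract max for numerical stability
    return exps[0] / exps.sum()

# Hypothetical per-step rewards for two 5-step fragments.
print(preference_probability([0.1, 0.2, 0.0, 0.3, 0.1], [-0.2, 0.0, 0.1, -0.1, 0.0]))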

To set up the preference comparisons algorithm, we first need to construct several of its components:

import random
from imitation.algorithms import preference_comparisons
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor
import gymnasium as gym
from stable_baselines3 import PPO
import numpy as np

rng = np.random.default_rng(0)

venv = make_vec_env("Pendulum-v1", rng=rng)

# The reward network to be trained from preferences; a running mean/std
# normalizes its inputs.
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)

# Samples random fragments of trajectories to be compared.
fragmenter = preference_comparisons.RandomFragmenter(
    warning_threshold=0,
    rng=rng,
)
# Synthetic "human": prefers the fragment with the higher ground-truth return.
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
# Turns reward-network outputs into preference probabilities.
preference_model = preference_comparisons.PreferenceModel(reward_net)
# Fits the reward network to the gathered preferences.
reward_trainer = preference_comparisons.BasicRewardTrainer(
    preference_model=preference_model,
    loss=preference_comparisons.CrossEntropyRewardLoss(),
    epochs=3,
    rng=rng,
)


# Several hyperparameters (reward_epochs, ppo_clip_range, ppo_ent_coef,
# ppo_gae_lambda, ppo_n_epochs, discount_factor, use_sde, sde_sample_freq,
# ppo_lr, exploration_frac, num_iterations, initial_comparison_frac,
# initial_epoch_multiplier, query_schedule) used in this example have been
# approximately fine-tuned to reach a reasonable level of performance.
agent = PPO(
    policy=FeedForward32Policy,
    policy_kwargs=dict(
        features_extractor_class=NormalizeFeaturesExtractor,
        features_extractor_kwargs=dict(normalize_class=RunningNorm),
    ),
    env=venv,
    seed=0,
    n_steps=2048 // venv.num_envs,
    batch_size=64,
    ent_coef=0.01,
    learning_rate=2e-3,
    clip_range=0.1,
    gae_lambda=0.95,
    gamma=0.97,
    n_epochs=10,
)

# Generates the trajectories to be fragmented and compared, by training `agent`
# on the current learned reward with a small fraction of extra exploration.
trajectory_generator = preference_comparisons.AgentTrainer(
    algorithm=agent,
    reward_fn=reward_net,
    venv=venv,
    exploration_frac=0.05,
    rng=rng,
)

pref_comparisons = preference_comparisons.PreferenceComparisons(
    trajectory_generator,
    reward_net,
    num_iterations=5,  # Set to 60 for better performance
    fragmenter=fragmenter,
    preference_gatherer=gatherer,
    reward_trainer=reward_trainer,
    fragment_length=100,
    transition_oversampling=1,
    initial_comparison_frac=0.1,
    allow_variable_horizon=False,
    initial_epoch_multiplier=4,
    query_schedule="hyperbolic",
)

Then we can start training the reward model. Note that we need to specify the total number of timesteps for which the agent will be trained and the total number of fragment comparisons to gather.

pref_comparisons.train(
    total_timesteps=5_000,
    total_comparisons=200,
)
Query schedule: [20, 51, 41, 34, 29, 25]
Collecting 40 fragments (4000 transitions)
Requested 3800 transitions but only 0 in buffer. Sampling 3800 additional transitions.
Sampling 200 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 20 comparisons
Training agent for 1000 timesteps
---------------------------------------------------
| raw/                                 |          |
|    agent/rollout/ep_len_mean         | 200      |
|    agent/rollout/ep_rew_mean         | -1.2e+03 |
|    agent/rollout/ep_rew_wrapped_mean | 32.6     |
|    agent/time/fps                    | 3838     |
|    agent/time/iterations             | 1        |
|    agent/time/time_elapsed           | 0        |
|    agent/time/total_timesteps        | 2048     |
---------------------------------------------------
------------------------------------------------------
| mean/                                   |          |
|    agent/rollout/ep_len_mean            | 200      |
|    agent/rollout/ep_rew_mean            | -1.2e+03 |
|    agent/rollout/ep_rew_wrapped_mean    | 32.6     |
|    agent/time/fps                       | 3.84e+03 |
|    agent/time/iterations                | 1        |
|    agent/time/time_elapsed              | 0        |
|    agent/time/total_timesteps           | 2.05e+03 |
|    agent/train/approx_kl                | 0.00269  |
|    agent/train/clip_fraction            | 0.114    |
|    agent/train/clip_range               | 0.1      |
|    agent/train/entropy_loss             | -1.44    |
|    agent/train/explained_variance       | -0.322   |
|    agent/train/learning_rate            | 0.002    |
|    agent/train/loss                     | 0.13     |
|    agent/train/n_updates                | 10       |
|    agent/train/policy_gradient_loss     | -0.00243 |
|    agent/train/std                      | 1.03     |
|    agent/train/value_loss               | 1.27     |
|    preferences/entropy                  | 0.0307   |
|    reward/epoch-0/train/accuracy        | 0.15     |
|    reward/epoch-0/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-0/train/loss            | 3.79     |
|    reward/epoch-1/train/accuracy        | 0.2      |
|    reward/epoch-1/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-1/train/loss            | 3.43     |
|    reward/epoch-10/train/accuracy       | 0.85     |
|    reward/epoch-10/train/gt_reward_loss | 0.0639   |
|    reward/epoch-10/train/loss           | 0.25     |
|    reward/epoch-11/train/accuracy       | 0.85     |
|    reward/epoch-11/train/gt_reward_loss | 0.0639   |
|    reward/epoch-11/train/loss           | 0.227    |
|    reward/epoch-2/train/accuracy        | 0.3      |
|    reward/epoch-2/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-2/train/loss            | 2.58     |
|    reward/epoch-3/train/accuracy        | 0.35     |
|    reward/epoch-3/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-3/train/loss            | 1.98     |
|    reward/epoch-4/train/accuracy        | 0.35     |
|    reward/epoch-4/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-4/train/loss            | 1.39     |
|    reward/epoch-5/train/accuracy        | 0.55     |
|    reward/epoch-5/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-5/train/loss            | 0.9      |
|    reward/epoch-6/train/accuracy        | 0.75     |
|    reward/epoch-6/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-6/train/loss            | 0.601    |
|    reward/epoch-7/train/accuracy        | 0.75     |
|    reward/epoch-7/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-7/train/loss            | 0.436    |
|    reward/epoch-8/train/accuracy        | 0.75     |
|    reward/epoch-8/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-8/train/loss            | 0.343    |
|    reward/epoch-9/train/accuracy        | 0.8      |
|    reward/epoch-9/train/gt_reward_loss  | 0.0639   |
|    reward/epoch-9/train/loss            | 0.286    |
| reward/                                 |          |
|    final/train/accuracy                 | 0.85     |
|    final/train/gt_reward_loss           | 0.0639   |
|    final/train/loss                     | 0.227    |
------------------------------------------------------
Collecting 102 fragments (10200 transitions)
Requested 9690 transitions but only 1600 in buffer. Sampling 8090 additional transitions.
Sampling 510 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 71 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.13e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 47           |
|    agent/time/fps                    | 3863         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 4096         |
|    agent/train/approx_kl             | 0.0026872663 |
|    agent/train/clip_fraction         | 0.114        |
|    agent/train/clip_range            | 0.1          |
|    agent/train/entropy_loss          | -1.44        |
|    agent/train/explained_variance    | -0.322       |
|    agent/train/learning_rate         | 0.002        |
|    agent/train/loss                  | 0.13         |
|    agent/train/n_updates             | 10           |
|    agent/train/policy_gradient_loss  | -0.00243     |
|    agent/train/std                   | 1.03         |
|    agent/train/value_loss            | 1.27         |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.13e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 47        |
|    agent/time/fps                      | 3.86e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 4.1e+03   |
|    agent/train/approx_kl               | 0.00058   |
|    agent/train/clip_fraction           | 0.0301    |
|    agent/train/clip_range              | 0.1       |
|    agent/train/entropy_loss            | -1.46     |
|    agent/train/explained_variance      | 0.436     |
|    agent/train/learning_rate           | 0.002     |
|    agent/train/loss                    | 0.112     |
|    agent/train/n_updates               | 20        |
|    agent/train/policy_gradient_loss    | -0.000273 |
|    agent/train/std                     | 1.05      |
|    agent/train/value_loss              | 0.588     |
|    preferences/entropy                 | 0.00161   |
|    reward/epoch-0/train/accuracy       | 0.838     |
|    reward/epoch-0/train/gt_reward_loss | 0.0135    |
|    reward/epoch-0/train/loss           | 0.35      |
|    reward/epoch-1/train/accuracy       | 0.906     |
|    reward/epoch-1/train/gt_reward_loss | 0.0135    |
|    reward/epoch-1/train/loss           | 0.253     |
|    reward/epoch-2/train/accuracy       | 0.879     |
|    reward/epoch-2/train/gt_reward_loss | 0.0135    |
|    reward/epoch-2/train/loss           | 0.315     |
| reward/                                |           |
|    final/train/accuracy                | 0.879     |
|    final/train/gt_reward_loss          | 0.0135    |
|    final/train/loss                    | 0.315     |
------------------------------------------------------
Collecting 82 fragments (8200 transitions)
Requested 7790 transitions but only 1600 in buffer. Sampling 6190 additional transitions.
Sampling 410 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 112 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.16e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 56           |
|    agent/time/fps                    | 3862         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 6144         |
|    agent/train/approx_kl             | 0.0005802552 |
|    agent/train/clip_fraction         | 0.0301       |
|    agent/train/clip_range            | 0.1          |
|    agent/train/entropy_loss          | -1.46        |
|    agent/train/explained_variance    | 0.436        |
|    agent/train/learning_rate         | 0.002        |
|    agent/train/loss                  | 0.112        |
|    agent/train/n_updates             | 20           |
|    agent/train/policy_gradient_loss  | -0.000273    |
|    agent/train/std                   | 1.05         |
|    agent/train/value_loss            | 0.588        |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.16e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 56        |
|    agent/time/fps                      | 3.86e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 6.14e+03  |
|    agent/train/approx_kl               | 0.00145   |
|    agent/train/clip_fraction           | 0.0564    |
|    agent/train/clip_range              | 0.1       |
|    agent/train/entropy_loss            | -1.47     |
|    agent/train/explained_variance      | 0.694     |
|    agent/train/learning_rate           | 0.002     |
|    agent/train/loss                    | 0.0207    |
|    agent/train/n_updates               | 30        |
|    agent/train/policy_gradient_loss    | -0.00223  |
|    agent/train/std                     | 1.05      |
|    agent/train/value_loss              | 0.198     |
|    preferences/entropy                 | 0.000825  |
|    reward/epoch-0/train/accuracy       | 0.914     |
|    reward/epoch-0/train/gt_reward_loss | 0.0102    |
|    reward/epoch-0/train/loss           | 0.201     |
|    reward/epoch-1/train/accuracy       | 0.938     |
|    reward/epoch-1/train/gt_reward_loss | 0.0102    |
|    reward/epoch-1/train/loss           | 0.148     |
|    reward/epoch-2/train/accuracy       | 0.945     |
|    reward/epoch-2/train/gt_reward_loss | 0.0101    |
|    reward/epoch-2/train/loss           | 0.126     |
| reward/                                |           |
|    final/train/accuracy                | 0.945     |
|    final/train/gt_reward_loss          | 0.0101    |
|    final/train/loss                    | 0.126     |
------------------------------------------------------
Collecting 68 fragments (6800 transitions)
Requested 6460 transitions but only 1600 in buffer. Sampling 4860 additional transitions.
Sampling 340 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 146 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.19e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 57.8         |
|    agent/time/fps                    | 3837         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 8192         |
|    agent/train/approx_kl             | 0.0014491911 |
|    agent/train/clip_fraction         | 0.0564       |
|    agent/train/clip_range            | 0.1          |
|    agent/train/entropy_loss          | -1.47        |
|    agent/train/explained_variance    | 0.694        |
|    agent/train/learning_rate         | 0.002        |
|    agent/train/loss                  | 0.0207       |
|    agent/train/n_updates             | 30           |
|    agent/train/policy_gradient_loss  | -0.00223     |
|    agent/train/std                   | 1.05         |
|    agent/train/value_loss            | 0.198        |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.19e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 57.8      |
|    agent/time/fps                      | 3.84e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 8.19e+03  |
|    agent/train/approx_kl               | 0.00167   |
|    agent/train/clip_fraction           | 0.0817    |
|    agent/train/clip_range              | 0.1       |
|    agent/train/entropy_loss            | -1.47     |
|    agent/train/explained_variance      | 0.89      |
|    agent/train/learning_rate           | 0.002     |
|    agent/train/loss                    | 0.00424   |
|    agent/train/n_updates               | 40        |
|    agent/train/policy_gradient_loss    | -0.00385  |
|    agent/train/std                     | 1.06      |
|    agent/train/value_loss              | 0.13      |
|    preferences/entropy                 | 0.0186    |
|    reward/epoch-0/train/accuracy       | 0.947     |
|    reward/epoch-0/train/gt_reward_loss | 0.0168    |
|    reward/epoch-0/train/loss           | 0.13      |
|    reward/epoch-1/train/accuracy       | 0.958     |
|    reward/epoch-1/train/gt_reward_loss | 0.0106    |
|    reward/epoch-1/train/loss           | 0.13      |
|    reward/epoch-2/train/accuracy       | 0.958     |
|    reward/epoch-2/train/gt_reward_loss | 0.0125    |
|    reward/epoch-2/train/loss           | 0.12      |
| reward/                                |           |
|    final/train/accuracy                | 0.958     |
|    final/train/gt_reward_loss          | 0.0125    |
|    final/train/loss                    | 0.12      |
------------------------------------------------------
Collecting 58 fragments (5800 transitions)
Requested 5510 transitions but only 1600 in buffer. Sampling 3910 additional transitions.
Sampling 290 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 175 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.21e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 57.7         |
|    agent/time/fps                    | 3823         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 10240        |
|    agent/train/approx_kl             | 0.0016703831 |
|    agent/train/clip_fraction         | 0.0817       |
|    agent/train/clip_range            | 0.1          |
|    agent/train/entropy_loss          | -1.47        |
|    agent/train/explained_variance    | 0.89         |
|    agent/train/learning_rate         | 0.002        |
|    agent/train/loss                  | 0.00424      |
|    agent/train/n_updates             | 40           |
|    agent/train/policy_gradient_loss  | -0.00385     |
|    agent/train/std                   | 1.06         |
|    agent/train/value_loss            | 0.13         |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.21e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 57.7      |
|    agent/time/fps                      | 3.82e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 1.02e+04  |
|    agent/train/approx_kl               | 0.00467   |
|    agent/train/clip_fraction           | 0.202     |
|    agent/train/clip_range              | 0.1       |
|    agent/train/entropy_loss            | -1.5      |
|    agent/train/explained_variance      | 0.946     |
|    agent/train/learning_rate           | 0.002     |
|    agent/train/loss                    | 0.0124    |
|    agent/train/n_updates               | 50        |
|    agent/train/policy_gradient_loss    | -0.0108   |
|    agent/train/std                     | 1.08      |
|    agent/train/value_loss              | 0.129     |
|    preferences/entropy                 | 0.00135   |
|    reward/epoch-0/train/accuracy       | 0.947     |
|    reward/epoch-0/train/gt_reward_loss | 0.00886   |
|    reward/epoch-0/train/loss           | 0.115     |
|    reward/epoch-1/train/accuracy       | 0.958     |
|    reward/epoch-1/train/gt_reward_loss | 0.00886   |
|    reward/epoch-1/train/loss           | 0.0979    |
|    reward/epoch-2/train/accuracy       | 0.969     |
|    reward/epoch-2/train/gt_reward_loss | 0.0112    |
|    reward/epoch-2/train/loss           | 0.102     |
| reward/                                |           |
|    final/train/accuracy                | 0.969     |
|    final/train/gt_reward_loss          | 0.0112    |
|    final/train/loss                    | 0.102     |
------------------------------------------------------
Collecting 50 fragments (5000 transitions)
Requested 4750 transitions but only 1600 in buffer. Sampling 3150 additional transitions.
Sampling 250 exploratory transitions.
Creating fragment pairs
Gathering preferences
Dataset now contains 200 comparisons
Training agent for 1000 timesteps
-------------------------------------------------------
| raw/                                 |              |
|    agent/rollout/ep_len_mean         | 200          |
|    agent/rollout/ep_rew_mean         | -1.21e+03    |
|    agent/rollout/ep_rew_wrapped_mean | 56.4         |
|    agent/time/fps                    | 3839         |
|    agent/time/iterations             | 1            |
|    agent/time/time_elapsed           | 0            |
|    agent/time/total_timesteps        | 12288        |
|    agent/train/approx_kl             | 0.0046693617 |
|    agent/train/clip_fraction         | 0.202        |
|    agent/train/clip_range            | 0.1          |
|    agent/train/entropy_loss          | -1.5         |
|    agent/train/explained_variance    | 0.946        |
|    agent/train/learning_rate         | 0.002        |
|    agent/train/loss                  | 0.0124       |
|    agent/train/n_updates             | 50           |
|    agent/train/policy_gradient_loss  | -0.0108      |
|    agent/train/std                   | 1.08         |
|    agent/train/value_loss            | 0.129        |
-------------------------------------------------------
------------------------------------------------------
| mean/                                  |           |
|    agent/rollout/ep_len_mean           | 200       |
|    agent/rollout/ep_rew_mean           | -1.21e+03 |
|    agent/rollout/ep_rew_wrapped_mean   | 56.4      |
|    agent/time/fps                      | 3.84e+03  |
|    agent/time/iterations               | 1         |
|    agent/time/time_elapsed             | 0         |
|    agent/time/total_timesteps          | 1.23e+04  |
|    agent/train/approx_kl               | 0.00137   |
|    agent/train/clip_fraction           | 0.0687    |
|    agent/train/clip_range              | 0.1       |
|    agent/train/entropy_loss            | -1.5      |
|    agent/train/explained_variance      | 0.971     |
|    agent/train/learning_rate           | 0.002     |
|    agent/train/loss                    | 0.21      |
|    agent/train/n_updates               | 60        |
|    agent/train/policy_gradient_loss    | -0.00218  |
|    agent/train/std                     | 1.09      |
|    agent/train/value_loss              | 0.144     |
|    preferences/entropy                 | 0.000231  |
|    reward/epoch-0/train/accuracy       | 0.969     |
|    reward/epoch-0/train/gt_reward_loss | 0.00759   |
|    reward/epoch-0/train/loss           | 0.116     |
|    reward/epoch-1/train/accuracy       | 0.964     |
|    reward/epoch-1/train/gt_reward_loss | 0.00759   |
|    reward/epoch-1/train/loss           | 0.0973    |
|    reward/epoch-2/train/accuracy       | 0.955     |
|    reward/epoch-2/train/gt_reward_loss | 0.00765   |
|    reward/epoch-2/train/loss           | 0.114     |
| reward/                                |           |
|    final/train/accuracy                | 0.955     |
|    final/train/gt_reward_loss          | 0.00765   |
|    final/train/loss                    | 0.114     |
------------------------------------------------------
{'reward_loss': 0.11424736359289714, 'reward_accuracy': 0.9553571428571429}

After training the reward network with the preference comparisons algorithm, we can wrap our environment so that it emits the learned reward instead of the ground-truth one.

from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

learned_reward_venv = RewardVecEnvWrapper(venv, reward_net.predict_processed)
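
As a quick sanity check (a sketch, not part of the original tutorial), the learned reward can also be queried directly on a hand-made batch of transitions. The array shapes below assume Pendulum-v1's 3-dimensional observations and 1-dimensional actions; the specific values are arbitrary.

# Sketch: query the learned reward on a single hypothetical transition.
# Pendulum-v1 observations are [cos(theta), sin(theta), angular velocity].
obs = np.array([[1.0, 0.0, 0.0]], dtype=np.float32)       # pendulum upright
acts = np.array([[0.0]], dtype=np.float32)                # zero torque
next_obs = np.array([[1.0, 0.0, 0.0]], dtype=np.float32)  # still upright
dones = np.array([False])

print(reward_net.predict_processed(obs, acts, next_obs, dones))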

Next, we train a fresh agent that only ever sees the learned reward.

learner = PPO(
    seed=0,
    policy=FeedForward32Policy,
    policy_kwargs=dict(
        features_extractor_class=NormalizeFeaturesExtractor,
        features_extractor_kwargs=dict(normalize_class=RunningNorm),
    ),
    env=learned_reward_venv,
    batch_size=64,
    ent_coef=0.01,
    n_epochs=10,
    n_steps=2048 // learned_reward_venv.num_envs,
    clip_range=0.1,
    gae_lambda=0.95,
    gamma=0.97,
    learning_rate=2e-3,
)
learner.learn(1_000)  # Note: set to 100_000 to train a proficient expert
<stable_baselines3.ppo.ppo.PPO at 0x7f90e10510d0>

Then we can evaluate it using the original reward.

from stable_baselines3.common.evaluation import evaluate_policy

n_eval_episodes = 10
reward_mean, reward_std = evaluate_policy(learner.policy, venv, n_eval_episodes)
reward_stderr = reward_std / np.sqrt(n_eval_episodes)
print(f"Reward: {reward_mean:.0f} +/- {reward_stderr:.0f}")
Reward: -1348 +/- 114
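
For reference (a sketch, not part of the original tutorial), the `agent` that was trained inside the preference comparisons loop can be evaluated on the ground-truth reward in exactly the same way:

# Sketch: evaluate the policy trained during reward learning for comparison.
agent_mean, agent_std = evaluate_policy(agent.policy, venv, n_eval_episodes)
agent_stderr = agent_std / np.sqrt(n_eval_episodes)
print(f"Agent reward: {agent_mean:.0f} +/- {agent_stderr:.0f}")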