
Learning a Reward Function using Kernel Density

This demo shows how to train a Pendulum agent (exciting!) with our simple density-based imitation learning baselines. DensityAlgorithm has a few interesting parameters; the key ones are:

  1. density_type: this governs whether density is measured on \((s,s')\) pairs (db.DensityType.STATE_STATE_DENSITY), \((s,a)\) pairs (db.DensityType.STATE_ACTION_DENSITY), or single states (db.DensityType.STATE_DENSITY).

  2. is_stationary: determines whether a separate density model is used for each time step \(t\) (False), or the same model is used for transitions at all times (True).

  3. standardise_inputs: if True, each dimension of the agent state vectors will be normalised to have zero mean and unit variance over the training dataset. This can be useful when not all elements of the demonstration vector are on the same scale, or when some elements have too wide a variation to be captured by the fixed kernel width (1 for Gaussian kernel).

  4. kernel: changes the kernel used for non-parametric density estimation. gaussian and exponential are the best bets; see the sklearn docs for the rest. (The sketch after this list shows how these options map onto sklearn's KernelDensity.)
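
To make these options concrete, here is a minimal sketch (not imitation's internal code) of the kind of kernel density estimation they configure, assuming sklearn's KernelDensity and StandardScaler and some made-up demonstration data:

import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler

# Made-up demonstrations: 100 transitions with 3-dim states and 1-dim actions.
demo_states = np.random.randn(100, 3)
demo_acts = np.random.randn(100, 1)
demo_sa = np.concatenate([demo_states, demo_acts], axis=1)  # (s, a) pairs

# standardise_inputs=True roughly corresponds to rescaling each dimension first.
scaler = StandardScaler().fit(demo_sa)
kde = KernelDensity(kernel="gaussian", bandwidth=0.2).fit(scaler.transform(demo_sa))

# The log-density of a new (s, a) pair is then used as its imitation reward.
new_sa = np.zeros((1, 4))
print(kde.score_samples(scaler.transform(new_sa)))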

import pprint

from imitation.algorithms import density as db
from imitation.data import types
from imitation.util import util
# Set FAST = False for longer training. Use True for testing and CI.
FAST = True

if FAST:
    N_VEC = 1
    N_TRAJECTORIES = 1
    N_ITERATIONS = 1
    N_RL_TRAIN_STEPS = 100
else:
    N_VEC = 8
    N_TRAJECTORIES = 10
    N_ITERATIONS = 100
    N_RL_TRAIN_STEPS = int(1e5)

import gym
import numpy as np
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper


rng = np.random.default_rng()
env_name = "Pendulum-v1"
expert = PPO.load(
    load_from_hub("HumanCompatibleAI/ppo-Pendulum-v1", "ppo-Pendulum-v1.zip")
).policy
rollout_env = DummyVecEnv(
    [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]
)
rollouts = rollout.rollout(
    expert,
    rollout_env,
    rollout.make_sample_until(min_timesteps=2000, min_episodes=57),
    rng=rng,
)
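
It can help to sanity-check the collected demonstrations before training. The attribute names below assume imitation's Trajectory dataclass, which exposes observations via .obs and its length in transitions via len():

print(f"Collected {len(rollouts)} expert trajectories.")
print(f"First trajectory: {len(rollouts[0])} transitions, obs shape {rollouts[0].obs.shape}.")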

env = util.make_vec_env(env_name, n_envs=N_VEC, rng=rng)


imitation_trainer = PPO(ActorCriticPolicy, env, learning_rate=3e-4, n_steps=2048)
density_trainer = db.DensityAlgorithm(
    venv=env,
    rng=rng,
    demonstrations=rollouts,
    rl_algo=imitation_trainer,
    density_type=db.DensityType.STATE_ACTION_DENSITY,
    is_stationary=True,
    kernel="gaussian",
    kernel_bandwidth=0.2,  # found using divination & some palm reading
    standardise_inputs=True,
)
density_trainer.train()  # fit the density model to the demonstration transitions
def print_stats(density_trainer, n_trajectories, epoch=""):
    stats = density_trainer.test_policy(n_trajectories=n_trajectories)
    print("True reward function stats:")
    pprint.pprint(stats)
    stats_im = density_trainer.test_policy(
        true_reward=False,
        n_trajectories=n_trajectories,
    )
    print(f"Imitation reward function stats, epoch {epoch}:")
    pprint.pprint(stats_im)


print("Stats before training:")
print_stats(density_trainer, N_TRAJECTORIES)

print("Stats after training:")
for i in range(N_ITERATIONS):
    density_trainer.train_policy(N_RL_TRAIN_STEPS)
    print_stats(density_trainer, N_TRAJECTORIES, epoch=str(i))
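
As a final check (not part of the original demo), you can also score the imitation policy against the environment's true reward with Stable Baselines 3's evaluate_policy helper:

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(imitation_trainer, env, n_eval_episodes=10)
print(f"Mean true reward after training: {mean_reward:.1f} +/- {std_reward:.1f}")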