download this notebook here

Train an Agent using Generative Adversarial Imitation Learning

The idea of generative adversarial imitation learning (GAIL) is to train a discriminator network to distinguish between expert trajectories and learner trajectories. The learner is trained with a standard reinforcement learning algorithm, such as PPO, and is rewarded for producing trajectories that the discriminator mistakes for expert trajectories.
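
Concretely, the learner's per-step reward comes from the discriminator's confidence that the transition was produced by the expert. The snippet below is a minimal numpy sketch of this idea, not the exact formula used by the imitation library; discriminator_logit is a hypothetical raw discriminator score where larger values mean "looks like expert data".

import numpy as np

def gail_surrogate_reward(discriminator_logit):
    # Reward the learner with log D(s, a): the log-probability the discriminator
    # assigns to "this transition came from the expert". log(sigmoid(x)) is
    # computed as -log(1 + exp(-x)) for numerical stability.
    return -np.logaddexp(0.0, -discriminator_logit)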

As usual, we first need an expert. Note that we use a variant of the CartPole environment from the seals package, which has a fixed episode length. Read more about why we do this here.

import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
import seals  # needed to load environments

env = gym.make("seals/CartPole-v0")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert
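
As an optional sanity check (not part of the original tutorial), you can measure how well the expert performs before using it to generate demonstrations. A proficient expert on this fixed-horizon CartPole variant should reach an episode reward close to the episode length.

from stable_baselines3.common.evaluation import evaluate_policy

# Mean episode reward of the expert over 10 evaluation episodes.
expert_mean_reward, _ = evaluate_policy(expert, env, n_eval_episodes=10)
print(f"Expert mean reward: {expert_mean_reward:.1f}")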

We generate some expert trajectories that the discriminator will need to distinguish from the learner's trajectories.

from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
import numpy as np

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "seals/CartPole-v0",
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),  # collect at least 60 full episodes
    rng=rng,
)
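
rollout.rollout returns a sequence of complete trajectories. If you want a quick look at what was collected (an optional inspection step, not in the original tutorial), you can count the trajectories and flatten them into individual transitions:

transitions = rollout.flatten_trajectories(rollouts)
print(f"Collected {len(rollouts)} trajectories ({len(transitions)} transitions).")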

Now we are ready to set up our GAIL trainer. Note that the reward_net is in fact the discriminator network: GAIL turns its output into the reward signal for the learner. We evaluate the learner before and after training so we can see whether it made any progress.

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng)
learner = PPO(
    env=venv,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
)
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,  # expert samples per discriminator minibatch
    gen_replay_buffer_capacity=2048,  # recent learner samples kept for discriminator training
    n_disc_updates_per_round=4,  # discriminator updates per round of generator (PPO) training
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)

learner_rewards_before_training, _ = evaluate_policy(
    learner, venv, 100, return_episode_rewards=True
)
gail_trainer.train(20000)  # Note: set to 300000 for better results
learner_rewards_after_training, _ = evaluate_policy(
    learner, venv, 100, return_episode_rewards=True
)
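
If you want to reuse the imitation-trained policy later (an optional step, not shown in the original tutorial), the learner is an ordinary Stable Baselines3 model and can be saved to disk and reloaded:

# Save the GAIL-trained policy and reload it later (hypothetical file name).
learner.save("gail_cartpole_learner")
reloaded_learner = PPO.load("gail_cartpole_learner", env=venv)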

Looking at the histograms of rewards before and after training, we can see that the learner is not perfect yet, but it has made some progress. If it has not, simply re-run the cell above.

import matplotlib.pyplot as plt
import numpy as np

print(np.mean(learner_rewards_after_training))
print(np.mean(learner_rewards_before_training))

plt.hist(
    [learner_rewards_before_training, learner_rewards_after_training],
    label=["untrained", "trained"],
)
plt.legend()
plt.show()