
Train an Agent using the DAgger Algorithm

The DAgger algorithm is an extension of behavior cloning. In behavior cloning, the training trajectories are recorded directly from an expert. In DAgger, the learner generates the trajectories, and the expert labels each visited state with the action it would have taken there. Because the training data is collected in the states the learner itself reaches, its state distribution matches that of the learner's own policy rather than the expert's.
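To make the loop concrete, here is a minimal, self-contained sketch of DAgger on a made-up 1-D environment. Everything in it (`step`, `expert`, `TabularLearner`, `dagger`, `LIMIT`) is hypothetical and not part of `imitation`; the learner is a lookup table standing in for a neural network. The learner decides which states are visited, the expert only supplies labels, and the data from all rounds is aggregated before each retraining. For simplicity this sketch always acts with the learner; practical implementations may also mix in expert actions during early rounds.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

LIMIT = 5  # hypothetical toy environment: integer positions -LIMIT..LIMIT


def step(state: int, action: int) -> int:
    """Action 1 moves right, action 0 moves left; positions are clipped."""
    move = 1 if action == 1 else -1
    return max(-LIMIT, min(LIMIT, state + move))


def expert(state: int) -> int:
    """Hypothetical expert: always step toward the origin."""
    return 1 if state < 0 else 0


class TabularLearner:
    """Majority-vote lookup table standing in for a neural-network policy."""

    def __init__(self) -> None:
        self.table: Dict[int, int] = {}

    def fit(self, dataset: List[Tuple[int, int]]) -> None:
        votes: Dict[int, Counter] = defaultdict(Counter)
        for state, action in dataset:
            votes[state][action] += 1
        self.table = {s: c.most_common(1)[0][0] for s, c in votes.items()}

    def __call__(self, state: int) -> int:
        # States never seen in training fall back to action 0 (move left).
        return self.table.get(state, 0)


def dagger(n_rounds: int, horizon: int, rng: random.Random) -> TabularLearner:
    learner = TabularLearner()
    dataset: List[Tuple[int, int]] = []  # aggregated over all rounds
    for _ in range(n_rounds):
        state = rng.randint(-LIMIT, LIMIT)
        for _ in range(horizon):
            # Label the state the learner is in with the expert's action...
            dataset.append((state, expert(state)))
            # ...but move according to the learner's own choice.
            state = step(state, learner(state))
        # Retrain on the union of all rounds' data (the "aggregation" in DAgger).
        learner.fit(dataset)
    return learner


learner = dagger(n_rounds=10, horizon=20, rng=random.Random(0))
print([learner(s) for s in range(-LIMIT, LIMIT + 1)])
# → [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
```

In the first round the empty learner drifts left to the boundary at -5, so states that the expert alone would rarely visit (from a positive start it just oscillates between 0 and -1) still get labeled. This is the distribution-matching effect described above.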

First we need an expert to learn from:

import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

env = gym.make("CartPole-v1")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert

Then we can construct a DAgger trainer and use it to train the policy on the CartPole environment.

import tempfile
import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import bc
from imitation.algorithms.dagger import SimpleDAggerTrainer

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])


bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=np.random.default_rng(),
)

with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
    print(tmpdir)
    dagger_trainer = SimpleDAggerTrainer(
        venv=venv,
        scratch_dir=tmpdir,
        expert_policy=expert,
        bc_trainer=bc_trainer,
        rng=np.random.default_rng(),
    )

    dagger_trainer.train(2000)
/tmp/dagger_example_twcgrzsh
---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 72.5      |
|    loss           | 0.693     |
|    neglogp        | 0.693     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 47        |
|    return_mean    | 28.8      |
|    return_min     | 15        |
|    return_std     | 11.3      |
---------------------------------
---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000557 |
|    entropy        | 0.557     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 78.8      |
|    loss           | 0.344     |
|    neglogp        | 0.345     |
|    prob_true_act  | 0.72      |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 59        |
|    return_mean    | 49.2      |
|    return_min     | 38        |
|    return_std     | 7.7       |
---------------------------------
---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000157 |
|    entropy        | 0.157     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 95.6      |
|    loss           | 0.0713    |
|    neglogp        | 0.0715    |
|    prob_true_act  | 0.939     |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 57        |
|    return_mean    | 45.4      |
|    return_min     | 36        |
|    return_std     | 6.83      |
---------------------------------
---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -9.16e-05 |
|    entropy        | 0.0916    |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 109       |
|    loss           | 0.0362    |
|    neglogp        | 0.0363    |
|    prob_true_act  | 0.968     |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 67        |
|    return_mean    | 50.4      |
|    return_min     | 40        |
|    return_std     | 10.4      |
---------------------------------
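The `bc/` columns in these tables are internally consistent with `loss = neglogp + ent_loss + l2_loss`, where `ent_loss = -0.001 * entropy` (for example, 0.345 - 0.000557 ≈ 0.344 in the second table). The 0.001 entropy weight is inferred from the logged numbers rather than taken from the `imitation` source, so treat it as an assumption. The plain-Python sketch below recomputes these quantities for a batch of categorical action distributions; `bc_metrics` is a hypothetical helper, not an `imitation` API.

```python
import math
from typing import Dict, List, Sequence


def bc_metrics(
    logits_batch: Sequence[Sequence[float]],
    true_actions: Sequence[int],
    ent_weight: float = 1e-3,  # assumption: inferred from the logs above
    l2_loss: float = 0.0,  # the logs above report l2_loss = 0
) -> Dict[str, float]:
    """Recompute BC-style log entries for a batch of discrete-action policies (sketch)."""
    neglogps: List[float] = []
    entropies: List[float] = []
    probs: List[float] = []
    for logits, action in zip(logits_batch, true_actions):
        # Numerically stable log-softmax over the discrete actions.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        log_probs = [x - log_z for x in logits]
        neglogps.append(-log_probs[action])  # negative log-likelihood of the expert action
        entropies.append(-sum(math.exp(lp) * lp for lp in log_probs))
        probs.append(math.exp(log_probs[action]))  # probability assigned to the expert action
    n = len(true_actions)
    neglogp = sum(neglogps) / n
    entropy = sum(entropies) / n
    ent_loss = -ent_weight * entropy  # small bonus that rewards higher entropy
    return {
        "neglogp": neglogp,
        "entropy": entropy,
        "ent_loss": ent_loss,
        "l2_loss": l2_loss,
        "loss": neglogp + ent_loss + l2_loss,
        "prob_true_act": sum(probs) / n,
    }


# A policy that is uniform over CartPole's two actions, on a batch of 32 samples:
metrics = bc_metrics([[0.0, 0.0]] * 32, [0, 1] * 16)
print(metrics)  # entropy = neglogp = ln 2 ≈ 0.693, prob_true_act = 0.5
```

A uniform policy reproduces the `entropy`, `ent_loss` and `prob_true_act` entries of the first table, which is consistent with the untrained policy starting out close to uniform. As training proceeds, `prob_true_act` rising toward 1 and `neglogp` falling toward 0 indicate that the learner increasingly agrees with the expert's labels.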

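Finally, we evaluate the trained policy over 10 episodes. CartPole-v1 caps the return at 500, so a policy that solves the environment should score close to that. With a proficient expert (trained for 100000 steps, as the note in the expert cell suggests), this is what the evaluation is meant to show. With the shortened 1000-step expert training used here, the mean reward of 52.3 below is far from solving the environment.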

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(dagger_trainer.policy, env, n_eval_episodes=10)
print(mean_reward)
52.3
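`evaluate_policy` runs the policy for `n_eval_episodes` episodes and, by default, returns the mean and standard deviation of the undiscounted episode returns. In CartPole-v1 the agent earns +1 per step and episodes are truncated after 500 steps, so 52.3 means the policy keeps the pole up for roughly 52 steps on average. The plain-Python sketch below mirrors that averaging on a hypothetical episode model: `make_toy_episode`, `evaluate`, and the per-step failure probability are made up for illustration, and this is not SB3's implementation. We use the population standard deviation, which we believe matches SB3's use of `np.std`.

```python
import random
import statistics
from typing import Callable, Tuple

MAX_EPISODE_STEPS = 500  # CartPole-v1 truncates episodes after 500 steps


def make_toy_episode(fail_prob: float, rng: random.Random) -> Callable[[], float]:
    """Hypothetical episode model: +1 per step until a random failure or the time limit."""

    def run_episode() -> float:
        episode_return = 0.0
        for _ in range(MAX_EPISODE_STEPS):
            episode_return += 1.0  # +1 for every step the pole stays up
            if rng.random() < fail_prob:
                break  # the pole fell: the episode terminates
        return episode_return

    return run_episode


def evaluate(run_episode: Callable[[], float], n_eval_episodes: int = 10) -> Tuple[float, float]:
    """Mean and population standard deviation of the undiscounted episode returns."""
    returns = [run_episode() for _ in range(n_eval_episodes)]
    return statistics.mean(returns), statistics.pstdev(returns)


mean_ret, std_ret = evaluate(make_toy_episode(0.0, random.Random(0)))
print(mean_ret, std_ret)  # → 500.0 0.0 (a policy that never fails hits the cap)
```

Because every episode is capped at 500 steps, 500 is the highest mean reward any policy can reach, which is why it serves as the "solved" reference point for CartPole-v1.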