Generative Adversarial Imitation Learning (GAIL)#
GAIL learns a policy by training it simultaneously with a discriminator that aims to distinguish expert trajectories from trajectories generated by the learned policy.
Note
GAIL paper: Generative Adversarial Imitation Learning
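In the notation of the paper, GAIL finds a saddle point \((\pi, D)\) of

\[\mathbb{E}_{\pi}[\log D(s,a)] + \mathbb{E}_{\pi_E}[\log(1 - D(s,a))] - \lambda H(\pi),\]

where \(\pi_E\) is the expert policy, \(D : S \times A \to (0, 1)\) is the discriminator (in the paper's convention \(D\) assigns high values to policy-generated pairs; the library's logits_expert_is_high below uses the opposite sign convention), and \(H(\pi)\) is a causal-entropy regularizer with weight \(\lambda \ge 0\). The policy minimizes this objective while the discriminator maximizes it.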
Example#
Detailed example notebook: Train an Agent using Generative Adversarial Imitation Learning
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
SEED = 42
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals-CartPole-v0",
    venv=env,
)
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)
# evaluate the learner before training
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)
# train the learner and evaluate again
gail_trainer.train(20000) # Train for 800_000 steps to match expert.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)
print("mean reward after training:", np.mean(learner_rewards_after_training))
print("mean reward before training:", np.mean(learner_rewards_before_training))
API#
- class imitation.algorithms.adversarial.gail.GAIL(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, **kwargs)[source]
Bases: AdversarialTrainer
Generative Adversarial Imitation Learning (GAIL).
- __init__(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, **kwargs)[source]
Generative Adversarial Imitation Learning.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc).
demo_batch_size (int) – The number of samples in each batch of expert data. The discriminator batch size is twice this number because each discriminator batch contains a generator sample for every expert sample.
venv (VecEnv) – The vectorized environment to train in.
gen_algo (BaseAlgorithm) – The generator RL algorithm that is trained to maximize discriminator confusion. Environment and logger will be set to venv and custom_logger.
reward_net (RewardNet) – A Torch module that takes an observation, action and next observation tensor as input and computes the logits. Used as the GAIL discriminator.
**kwargs – Passed through to AdversarialTrainer.__init__.
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
- property logger: HierarchicalLogger
- Return type
HierarchicalLogger
- logits_expert_is_high(state, action, next_state, done, log_policy_act_prob=None)[source]
Compute the discriminator’s logits for each state-action sample.
- Parameters
state (Tensor) – The state of the environment at the time of the action.
action (Tensor) – The action taken by the expert or generator.
next_state (Tensor) – The state of the environment after the action.
done (Tensor) – Whether a terminal state (as defined under the MDP of the task) has been reached.
log_policy_act_prob (Optional[Tensor]) – The log probability of the action taken by the generator, \(\log{P(a|s)}\).
- Return type
Tensor
- Returns
The logits of the discriminator for each state-action sample.
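As a small illustration of this convention (plain PyTorch, not the library's internal code): the discriminator's probability that a sample came from the expert is the sigmoid of these logits.
import torch

# Hypothetical logits for a batch of four transitions.
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])

# High logit => classified as expert; sigmoid converts each logit to the
# discriminator's probability that the transition is expert data.
p_expert = torch.sigmoid(logits)
print(p_expert)  # tensor([0.8808, 0.2689, 0.6225, 0.0474])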
- property policy: BasePolicy
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- property reward_test: RewardNet
Reward used to train policy at “test” time after adversarial training.
- Return type
RewardNet
- set_demonstrations(demonstrations)
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
None
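For example, reusing expert, env, SEED, and gail_trainer from the example above, a fresh batch of demonstrations could be swapped in between training runs (a sketch; the episode count is illustrative):
# Gather a new batch of expert demonstrations and swap them in.
new_rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=30),
    rng=np.random.default_rng(SEED + 1),
)
gail_trainer.set_demonstrations(new_rollouts)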
- train(total_timesteps, callback=None)
Alternates between training the generator and discriminator.
Every “round” consists of a call to train_gen(self.gen_train_timesteps), a call to train_disc, and finally a call to callback(round).
Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.
- Parameters
total_timesteps (int) – An upper bound on the number of transitions to sample from the environment during training.
callback (Optional[Callable[[int], None]]) – A function called at the end of every round which takes in a single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).
- Return type
None
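A minimal usage sketch, reusing gail_trainer from the example above, with a hypothetical callback that logs each round:
def log_round(round_num: int) -> None:
    # round_num lies in range(total_timesteps // gen_train_timesteps).
    print("finished round", round_num)

gail_trainer.train(total_timesteps=100_000, callback=log_round)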
- train_disc(*, expert_samples=None, gen_samples=None)
Perform a single discriminator update, optionally using provided samples.
- Parameters
expert_samples (Optional[Mapping]) – Transition samples from the expert in dictionary form. If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.
gen_samples (Optional[Mapping]) – Transition samples from the generator policy in the same dictionary form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then take len(expert_samples) samples from the generator replay buffer.
- Return type
Mapping[str, float]
- Returns
Statistics for discriminator (e.g. loss, accuracy).
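Called with no arguments, both batches come from stored data; a minimal sketch of inspecting the returned statistics, assuming the generator replay buffer has already been populated by at least one train_gen call (statistic names vary and are not guaranteed here):
gail_trainer.train_gen()  # populate the generator replay buffer first
stats = gail_trainer.train_disc()
for name, value in stats.items():
    print(name, value)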
- train_gen(total_timesteps=None, learn_kwargs=None)
Trains the generator to maximize the discriminator loss.
At the end of training, this populates the generator replay buffer (used in discriminator training) with self.disc_batch_size transitions.
- Parameters
total_timesteps (Optional[int]) – The number of transitions to sample from self.venv_train during training. By default, self.gen_train_timesteps.
learn_kwargs (Optional[Mapping]) – kwargs for the Stable Baselines RLModel.learn() method.
- Return type
None
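The train method is roughly this alternation; a hedged sketch of one manual round, assuming the n_disc_updates_per_round setting is exposed as an attribute of the same name:
# One manual round: a generator update, then several discriminator updates.
gail_trainer.train_gen()  # also refills the generator replay buffer
for _ in range(gail_trainer.n_disc_updates_per_round):
    disc_stats = gail_trainer.train_disc()
print(disc_stats)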
- venv: VecEnv
The original vectorized environment.
- venv_train: VecEnv
Like self.venv, but wrapped with train reward unless in debug mode.
If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.
- venv_wrapped: VecEnvWrapper
- class imitation.algorithms.adversarial.common.AdversarialTrainer(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, demo_minibatch_size=None, n_disc_updates_per_round=2, log_dir='output/', disc_opt_cls=<class 'torch.optim.adam.Adam'>, disc_opt_kwargs=None, gen_train_timesteps=None, gen_replay_buffer_capacity=None, custom_logger=None, init_tensorboard=False, init_tensorboard_graph=False, debug_use_ground_truth=False, allow_variable_horizon=False)[source]
Bases: DemonstrationAlgorithm[Transitions]
Base class for adversarial imitation learning algorithms like GAIL and AIRL.
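A minimal sketch of what a subclass must provide, assuming (as in GAIL) that the base class stores the reward network as self._reward_net; the class name is hypothetical and these are not the library's exact internals:
from imitation.algorithms.adversarial.common import AdversarialTrainer
from imitation.rewards.reward_nets import RewardNet

class MyAdversarialIL(AdversarialTrainer):
    """Hypothetical GAIL-style trainer: reward net output used directly as logits."""

    def logits_expert_is_high(self, state, action, next_state, done, log_policy_act_prob=None):
        # High value = transition classified as expert-like.
        return self._reward_net(state, action, next_state, done)

    @property
    def reward_train(self) -> RewardNet:
        return self._reward_net

    @property
    def reward_test(self) -> RewardNet:
        return self._reward_net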
- __init__(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, demo_minibatch_size=None, n_disc_updates_per_round=2, log_dir='output/', disc_opt_cls=<class 'torch.optim.adam.Adam'>, disc_opt_kwargs=None, gen_train_timesteps=None, gen_replay_buffer_capacity=None, custom_logger=None, init_tensorboard=False, init_tensorboard_graph=False, debug_use_ground_truth=False, allow_variable_horizon=False)[source]
Builds AdversarialTrainer.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc).
demo_batch_size (int) – The number of samples in each batch of expert data. The discriminator batch size is twice this number because each discriminator batch contains a generator sample for every expert sample.
venv (VecEnv) – The vectorized environment to train in.
gen_algo (BaseAlgorithm) – The generator RL algorithm that is trained to maximize discriminator confusion. Environment and logger will be set to venv and custom_logger.
reward_net (RewardNet) – A Torch module that takes an observation, action and next observation tensor as input and computes a reward signal.
demo_minibatch_size (Optional[int]) – Size of minibatch to calculate gradients over. The gradients are accumulated until the entire batch is processed before making an optimization step. This is useful in GPU training to reduce memory usage, since fewer examples are loaded into memory at once, facilitating training with larger batch sizes, but is generally slower. Must be a factor of demo_batch_size. Optional, defaults to demo_batch_size. (A usage sketch follows the Raises entry below.)
n_disc_updates_per_round (int) – The number of discriminator updates after each round of generator updates in AdversarialTrainer.learn().
log_dir (Union[str, bytes, PathLike]) – Directory to store TensorBoard logs, plots, etc. in.
disc_opt_cls (Type[Optimizer]) – The optimizer for discriminator training.
disc_opt_kwargs (Optional[Mapping]) – Parameters for discriminator training.
gen_train_timesteps (Optional[int]) – The number of steps to train the generator policy for each iteration. If None, then defaults to the batch size (for on-policy) or the number of environments (for off-policy).
gen_replay_buffer_capacity (Optional[int]) – The capacity of the generator replay buffer (the number of obs-action-obs samples from the generator that can be stored). By default this is equal to gen_train_timesteps, meaning that we sample only from the most recent batch of generator samples.
custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
init_tensorboard (bool) – If True, makes various discriminator TensorBoard summaries.
init_tensorboard_graph (bool) – If both this and init_tensorboard are True, then write a TensorBoard graph summary to disk.
debug_use_ground_truth (bool) – If True, use the ground truth reward for self.train_env. This disables the reward wrapping that would normally replace the environment reward with the learned reward. This is useful for sanity checking that the policy training is functional.
allow_variable_horizon (bool) – If False (default), algorithm will raise an exception if it detects trajectories of different length during training. If True, overrides this safety check. WARNING: variable horizon episodes leak information about the reward via termination condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html before overriding this.
- Raises
ValueError – if the batch size is not a multiple of the minibatch size.
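For instance, reusing rollouts, env, learner, and reward_net from the example above, gradient accumulation could be enabled as sketched here (hyperparameter values are illustrative):
# demo_minibatch_size must evenly divide demo_batch_size;
# here gradients accumulate over 1024 // 256 = 4 minibatches per update.
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    demo_minibatch_size=256,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)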
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
- property logger: HierarchicalLogger
- Return type
HierarchicalLogger
- abstract logits_expert_is_high(state, action, next_state, done, log_policy_act_prob=None)[source]
Compute the discriminator’s logits for each state-action sample.
A high value corresponds to predicting expert, and a low value corresponds to predicting generator.
- Parameters
state (Tensor) – State at time t, of shape (batch_size,) + state_shape.
action (Tensor) – Action taken at time t, of shape (batch_size,) + action_shape.
next_state (Tensor) – State at time t+1, of shape (batch_size,) + state_shape.
done (Tensor) – Binary episode-completion flag after the action at time t, of shape (batch_size,).
log_policy_act_prob (Optional[Tensor]) – Log probability of the generator policy taking the action at time t.
- Return type
Tensor
- Returns
Discriminator logits of shape (batch_size,). A high output indicates an expert-like transition.
- property policy: BasePolicy
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- abstract property reward_test: RewardNet
Reward used to train policy at “test” time after adversarial training.
- Return type
RewardNet
- abstract property reward_train: RewardNet
Reward used to train generator policy.
- Return type
RewardNet
- set_demonstrations(demonstrations)[source]
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
None
- train(total_timesteps, callback=None)[source]
Alternates between training the generator and discriminator.
Every “round” consists of a call to train_gen(self.gen_train_timesteps), a call to train_disc, and finally a call to callback(round).
Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.
- Parameters
total_timesteps (int) – An upper bound on the number of transitions to sample from the environment during training.
callback (Optional[Callable[[int], None]]) – A function called at the end of every round which takes in a single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).
- Return type
None
- train_disc(*, expert_samples=None, gen_samples=None)[source]
Perform a single discriminator update, optionally using provided samples.
- Parameters
expert_samples (Optional[Mapping]) – Transition samples from the expert in dictionary form. If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.
gen_samples (Optional[Mapping]) – Transition samples from the generator policy in the same dictionary form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then take len(expert_samples) samples from the generator replay buffer.
- Return type
Mapping[str, float]
- Returns
Statistics for discriminator (e.g. loss, accuracy).
- train_gen(total_timesteps=None, learn_kwargs=None)[source]
Trains the generator to maximize the discriminator loss.
At the end of training, this populates the generator replay buffer (used in discriminator training) with self.disc_batch_size transitions.
- Parameters
total_timesteps (Optional[int]) – The number of transitions to sample from self.venv_train during training. By default, self.gen_train_timesteps.
learn_kwargs (Optional[Mapping]) – kwargs for the Stable Baselines RLModel.learn() method.
- Return type
None
- venv: VecEnv
The original vectorized environment.
- venv_train: VecEnv
Like self.venv, but wrapped with train reward unless in debug mode.
If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.
- venv_wrapped: VecEnvWrapper