Adversarial Inverse Reinforcement Learning (AIRL)

Like GAIL, AIRL adversarially trains a policy against a discriminator that aims to distinguish the expert demonstrations from the learned policy. Unlike GAIL, AIRL recovers a reward function that generalizes better to changes in environment dynamics.

The expert policy must be stochastic.


Detailed example notebook: Train an Agent using Adversarial Inverse Reinforcement Learning

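import numpy as np
import gym
import seals  # noqa: F401  # registers the "seals/..." environments with gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.airl import AIRL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

# Train a (stochastic) PPO expert; more timesteps give a stronger expert.
env = gym.make("seals/CartPole-v0")
expert = PPO(policy=MlpPolicy, env=env)
expert.learn(1000)

# Collect at least 60 expert episodes. RolloutInfoWrapper records the
# complete episode data that rollout.rollout needs.
rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "seals/CartPole-v0",
        rng=rng,
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
    ),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=rng,
)

venv = make_vec_env("seals/CartPole-v0", rng=rng, n_envs=8)
learner = PPO(env=venv, policy=MlpPolicy)
reward_net = BasicShapedRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
airl_trainer = AIRL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)
airl_trainer.train(20000)
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)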


class imitation.algorithms.adversarial.airl.AIRL(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, **kwargs)

Bases: AdversarialTrainer

Adversarial Inverse Reinforcement Learning (AIRL).

__init__(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, **kwargs)

Builds an AIRL trainer.

Parameters

  • demonstrations (Union[Iterable[Trajectory], Iterable[Mapping[str, Union[ndarray, Tensor]]], TransitionsMinimal]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc).

  • demo_batch_size (int) – The number of samples in each batch of expert data. The discriminator batch size is twice this number because each discriminator batch contains a generator sample for every expert sample.

  • venv (VecEnv) – The vectorized environment to train in.

  • gen_algo (BaseAlgorithm) – The generator RL algorithm that is trained to maximize discriminator confusion. Environment and logger will be set to venv and custom_logger.

  • reward_net (RewardNet) – Reward network; used as part of AIRL discriminator.

  • **kwargs – Passed through to AdversarialTrainer.__init__.


Raises

TypeError – If gen_algo.policy does not have an evaluate_actions attribute (present in ActorCriticPolicy), which is needed to compute the log-probability of actions.
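The TypeError above guards against generator policies that cannot report \(\log \pi(a|s)\). A minimal sketch of an equivalent pre-flight check, assuming plain duck typing on the evaluate_actions attribute; both classes and check_policy_supports_log_prob are hypothetical stand-ins, not Stable Baselines3 or imitation code:

```python
class PolicyWithEvaluateActions:
    """Hypothetical stand-in for a policy exposing evaluate_actions (as ActorCriticPolicy does)."""

    def evaluate_actions(self, obs, actions):
        raise NotImplementedError("stand-in only")


class PolicyWithoutEvaluateActions:
    """Hypothetical stand-in for a policy that cannot score actions."""


def check_policy_supports_log_prob(policy) -> None:
    """Raise TypeError if the policy lacks evaluate_actions (hypothetical helper)."""
    if not hasattr(policy, "evaluate_actions"):
        raise TypeError(
            "AIRL needs a policy with an evaluate_actions attribute "
            "to compute log-probabilities of actions."
        )


check_policy_supports_log_prob(PolicyWithEvaluateActions())  # accepted silently
try:
    check_policy_supports_log_prob(PolicyWithoutEvaluateActions())
except TypeError as err:
    print("Rejected:", err)
```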

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.

property logger: HierarchicalLogger
Return type

HierarchicalLogger

logits_expert_is_high(state, action, next_state, done, log_policy_act_prob=None)

Compute the discriminator’s logits for each state-action sample.

In Fu’s AIRL paper, the discriminator output was given as

\[D_{\theta}(s,a) = \frac{ \exp{r_{\theta}(s,a)} } { \exp{r_{\theta}(s,a)} + \pi(a|s) }\]

with a high value corresponding to the expert and a low value corresponding to the generator.

In other words, the discriminator output is the probability that the action is taken by the expert rather than the generator.

The logit of the above is given as

\[\operatorname{logit}(D_{\theta}(s,a)) = r_{\theta}(s,a) - \log{ \pi(a|s) }\]

which is what is returned by this function.
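The step from \(D_{\theta}\) to its logit uses \(\operatorname{logit}(p) = \log \frac{p}{1-p}\); the common denominator \(\exp{r_{\theta}(s,a)} + \pi(a|s)\) cancels:

\[1 - D_{\theta}(s,a) = \frac{ \pi(a|s) }{ \exp{r_{\theta}(s,a)} + \pi(a|s) }\]

\[\operatorname{logit}(D_{\theta}(s,a)) = \log \frac{ D_{\theta}(s,a) }{ 1 - D_{\theta}(s,a) } = \log \frac{ \exp{r_{\theta}(s,a)} }{ \pi(a|s) } = r_{\theta}(s,a) - \log{ \pi(a|s) }\]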

Parameters

  • state (Tensor) – The state of the environment at the time of the action.

  • action (Tensor) – The action taken by the expert or generator.

  • next_state (Tensor) – The state of the environment after the action.

  • done (Tensor) – whether a terminal state (as defined under the MDP of the task) has been reached.

  • log_policy_act_prob (Optional[Tensor]) – The log probability of the action taken by the generator, \(\log{ \pi(a|s) }\).

Return type

Tensor

Returns

The logits of the discriminator for each state-action sample.


Raises

TypeError – If log_policy_act_prob is None.
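As a numerical check of the formulas above, the following sketch computes the logits for a small batch and confirms that applying the sigmoid recovers \(D_{\theta}(s,a)\). NumPy arrays stand in for Torch tensors, and airl_logits and airl_discriminator are hypothetical names, not part of imitation:

```python
import numpy as np


def airl_logits(reward: np.ndarray, log_policy_act_prob) -> np.ndarray:
    """logit(D) = r_theta(s, a) - log pi(a|s), elementwise over a batch.

    reward: learned reward r_theta for each sample, shape (batch_size,).
    log_policy_act_prob: log pi(a|s) for each sample, shape (batch_size,).
    """
    if log_policy_act_prob is None:
        # Mirrors the documented TypeError: the logit needs log pi(a|s).
        raise TypeError("log_policy_act_prob must not be None for AIRL.")
    return reward - log_policy_act_prob


def airl_discriminator(reward: np.ndarray, log_policy_act_prob: np.ndarray) -> np.ndarray:
    """D = exp(r) / (exp(r) + pi), computed directly from the definition."""
    exp_r = np.exp(reward)
    return exp_r / (exp_r + np.exp(log_policy_act_prob))


reward = np.array([1.0, -0.5, 0.2])
log_pi = np.log(np.array([0.9, 0.1, 0.5]))
logits = airl_logits(reward, log_pi)
# The sigmoid of the logits should equal D computed from its definition.
sigmoid = 1.0 / (1.0 + np.exp(-logits))
print(np.allclose(sigmoid, airl_discriminator(reward, log_pi)))
```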

property policy: BasePolicy

Returns a policy imitating the demonstration data.

Return type

BasePolicy

property reward_test: RewardNet

Returns the unshaped version of the reward network, used for testing.

Return type

RewardNet

property reward_train: RewardNet

Reward used to train generator policy.

Return type

RewardNet
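reward_train and reward_test differ by a shaping term. A minimal NumPy sketch, assuming the potential-based form from the AIRL paper, \(f(s,a,s') = g(s,a) + \gamma h(s') - h(s)\), with the next-state potential zeroed at terminal states. base_reward and potential are hypothetical stand-ins for the learned networks, and GAMMA is an assumed discount factor:

```python
import numpy as np

GAMMA = 0.99  # assumed discount factor


def base_reward(state: np.ndarray, action: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the learned base reward g(s, a)."""
    return -np.sum(state**2, axis=1) + 0.1 * action


def potential(state: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the learned potential h(s)."""
    return np.sum(state, axis=1)


def unshaped_reward(state, action, next_state, done):
    """Test-time reward: only g(s, a)."""
    return base_reward(state, action)


def shaped_reward(state, action, next_state, done):
    """Training reward: g(s, a) + gamma * (1 - done) * h(s') - h(s)."""
    next_potential = (1.0 - done) * potential(next_state)
    return base_reward(state, action) + GAMMA * next_potential - potential(state)


state = np.array([[0.1, 0.2], [0.3, -0.1]])
next_state = np.array([[0.2, 0.2], [0.0, 0.0]])
action = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
shaping = shaped_reward(state, action, next_state, done) - unshaped_reward(
    state, action, next_state, done
)
print(shaping)
```

Shaping of this potential-based form does not change which policy is optimal, so dropping it for testing leaves only the base term g(s, a).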


Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.


Parameters

demonstrations (Union[Iterable[Trajectory], Iterable[Mapping[str, Union[ndarray, Tensor]]], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, TransitionKind instance, or a Sequence of Trajectory objects.

Return type

None

train(total_timesteps, callback=None)

Alternates between training the generator and discriminator.

Every “round” consists of a call to train_gen(self.gen_train_timesteps), n_disc_updates_per_round calls to train_disc, and finally a call to callback(round).

Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.

Parameters

  • total_timesteps (int) – An upper bound on the number of transitions to sample from the environment during training.

  • callback (Optional[Callable[[int], None]]) – A function called at the end of every round which takes in a single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).

Return type

None
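The round structure of train can be sketched as a plain loop. This is a hypothetical, framework-free sketch: run_rounds and its callable arguments are illustrative names, and it assumes, as the n_disc_updates_per_round parameter of AdversarialTrainer describes, that each round performs n_disc_updates_per_round discriminator updates:

```python
from typing import Callable, List, Optional


def run_rounds(
    total_timesteps: int,
    gen_train_timesteps: int,
    n_disc_updates_per_round: int,
    train_gen: Callable[[int], None],
    train_disc: Callable[[], None],
    callback: Optional[Callable[[int], None]] = None,
) -> int:
    """Hypothetical sketch of the generator/discriminator round schedule."""
    # Only whole rounds run: stop before another round would exceed the budget.
    n_rounds = total_timesteps // gen_train_timesteps
    for round_num in range(n_rounds):
        train_gen(gen_train_timesteps)
        for _ in range(n_disc_updates_per_round):
            train_disc()
        if callback is not None:
            callback(round_num)
    return n_rounds


log: List[str] = []
n = run_rounds(
    total_timesteps=10_000,
    gen_train_timesteps=2_048,
    n_disc_updates_per_round=2,
    train_gen=lambda steps: log.append(f"gen({steps})"),
    train_disc=lambda: log.append("disc"),
    callback=lambda r: log.append(f"callback({r})"),
)
print(n, log[:4])
```

With total_timesteps=10_000 and gen_train_timesteps=2_048, four rounds run, because a fifth would need 10,240 transitions.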

train_disc(*, expert_samples=None, gen_samples=None)

Perform a single discriminator update, optionally using provided samples.

Parameters

  • expert_samples (Optional[Mapping]) – Transition samples from the expert in dictionary form. If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.

  • gen_samples (Optional[Mapping]) – Transition samples from the generator policy in same dictionary form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then take len(expert_samples) samples from the generator replay buffer.

Return type

Mapping[str, float]


Returns

Statistics for the discriminator (e.g. loss, accuracy).
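A sketch of the dictionary form train_disc accepts, assuming the Transitions fields other than “infos” are obs, acts, next_obs, and dones. make_samples and check_samples are hypothetical helpers filled with random placeholder data. The last lines show why the discriminator batch is twice demo_batch_size: expert and generator rows are stacked, here labelled 1 for expert and 0 for generator (an assumed convention matching “high logit means expert”):

```python
import numpy as np

# Assumed keys: every Transitions field except "infos".
REQUIRED_KEYS = ("obs", "acts", "next_obs", "dones")


def make_samples(batch_size: int, obs_dim: int, rng: np.random.Generator) -> dict:
    """Hypothetical helper building a batch in the documented dictionary form."""
    return {
        "obs": rng.normal(size=(batch_size, obs_dim)).astype(np.float32),
        "acts": rng.integers(0, 2, size=batch_size),
        "next_obs": rng.normal(size=(batch_size, obs_dim)).astype(np.float32),
        "dones": np.zeros(batch_size, dtype=bool),
    }


def check_samples(samples: dict, demo_batch_size: int) -> None:
    """Hypothetical check: required keys present, each with demo_batch_size rows."""
    missing = [k for k in REQUIRED_KEYS if k not in samples]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    for key in REQUIRED_KEYS:
        if len(samples[key]) != demo_batch_size:
            raise ValueError(f"{key} must have {demo_batch_size} samples")


demo_batch_size = 64
rng = np.random.default_rng(0)
expert_samples = make_samples(demo_batch_size, obs_dim=4, rng=rng)
gen_samples = make_samples(demo_batch_size, obs_dim=4, rng=rng)
check_samples(expert_samples, demo_batch_size)
check_samples(gen_samples, demo_batch_size)

# One generator row per expert row: 2 * demo_batch_size rows in total.
disc_obs = np.concatenate([expert_samples["obs"], gen_samples["obs"]])
labels_expert_is_one = np.concatenate(
    [np.ones(demo_batch_size), np.zeros(demo_batch_size)]
)
print(disc_obs.shape, labels_expert_is_one.shape)
```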

train_gen(total_timesteps=None, learn_kwargs=None)

Trains the generator to maximize the discriminator loss.

After training ends, the generator replay buffer (used in discriminator training) is populated with self.disc_batch_size transitions.

Parameters

  • total_timesteps (Optional[int]) – The number of transitions to sample from self.venv_train during training. By default, self.gen_train_timesteps.

  • learn_kwargs (Optional[Mapping]) – kwargs for the Stable Baselines RLModel.learn() method.

Return type

None

venv: VecEnv

The original vectorized environment.

venv_train: VecEnv

Like self.venv, but wrapped with train reward unless in debug mode.

If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.

venv_wrapped: VecEnvWrapper
class imitation.algorithms.adversarial.common.AdversarialTrainer(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, demo_minibatch_size=None, n_disc_updates_per_round=2, log_dir='output/', disc_opt_cls=<class 'torch.optim.adam.Adam'>, disc_opt_kwargs=None, gen_train_timesteps=None, gen_replay_buffer_capacity=None, custom_logger=None, init_tensorboard=False, init_tensorboard_graph=False, debug_use_ground_truth=False, allow_variable_horizon=False)

Bases: DemonstrationAlgorithm[Transitions]

Base class for adversarial imitation learning algorithms like GAIL and AIRL.

__init__(*, demonstrations, demo_batch_size, venv, gen_algo, reward_net, demo_minibatch_size=None, n_disc_updates_per_round=2, log_dir='output/', disc_opt_cls=<class 'torch.optim.adam.Adam'>, disc_opt_kwargs=None, gen_train_timesteps=None, gen_replay_buffer_capacity=None, custom_logger=None, init_tensorboard=False, init_tensorboard_graph=False, debug_use_ground_truth=False, allow_variable_horizon=False)

Builds AdversarialTrainer.

Parameters

  • demonstrations (Union[Iterable[Trajectory], Iterable[Mapping[str, Union[ndarray, Tensor]]], TransitionsMinimal]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc).

  • demo_batch_size (int) – The number of samples in each batch of expert data. The discriminator batch size is twice this number because each discriminator batch contains a generator sample for every expert sample.

  • venv (VecEnv) – The vectorized environment to train in.

  • gen_algo (BaseAlgorithm) – The generator RL algorithm that is trained to maximize discriminator confusion. Environment and logger will be set to venv and custom_logger.

  • reward_net (RewardNet) – a Torch module that takes observation, action, and next-observation tensors as input and computes a reward signal.

  • demo_minibatch_size (Optional[int]) – size of minibatch to calculate gradients over. The gradients are accumulated until the entire batch is processed before making an optimization step. This is useful in GPU training to reduce memory usage, since fewer examples are loaded into memory at once, facilitating training with larger batch sizes, but is generally slower. Must be a factor of demo_batch_size. Optional, defaults to demo_batch_size.

  • n_disc_updates_per_round (int) – The number of discriminator updates after each round of generator updates in AdversarialTrainer.train().

  • log_dir (Union[str, bytes, PathLike]) – Directory to store TensorBoard logs, plots, etc. in.

  • disc_opt_cls (Type[Optimizer]) – The optimizer for discriminator training.

  • disc_opt_kwargs (Optional[Mapping]) – Keyword arguments passed to disc_opt_cls when constructing the discriminator optimizer.

  • gen_train_timesteps (Optional[int]) – The number of steps to train the generator policy for each iteration. If None, then defaults to the batch size (for on-policy) or number of environments (for off-policy).

  • gen_replay_buffer_capacity (Optional[int]) – The capacity of the generator replay buffer (the number of obs-action-obs samples from the generator that can be stored). By default this is equal to gen_train_timesteps, meaning that we sample only from the most recent batch of generator samples.

  • custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.

  • init_tensorboard (bool) – If True, makes various discriminator TensorBoard summaries.

  • init_tensorboard_graph (bool) – If both this and init_tensorboard are True, then write a TensorBoard graph summary to disk.

  • debug_use_ground_truth (bool) – If True, use the ground truth reward for self.venv_train. This disables the reward wrapping that would normally replace the environment reward with the learned reward. This is useful for sanity checking that the policy training is functional.

  • allow_variable_horizon (bool) – If False (default), algorithm will raise an exception if it detects trajectories of different length during training. If True, overrides this safety check. WARNING: variable horizon episodes leak information about the reward via termination condition, and can seriously confound evaluation. Read the imitation documentation on variable-horizon environments before overriding this.


Raises

ValueError – If the batch size is not a multiple of the minibatch size.
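The documented defaults and the ValueError can be summarized in a small helper. This is a hypothetical sketch (resolve_batch_settings is not an imitation function), and the generator-dependent default for gen_train_timesteps is passed in as an assumed number rather than derived from a real RL algorithm:

```python
from typing import Optional, Tuple


def resolve_batch_settings(
    demo_batch_size: int,
    demo_minibatch_size: Optional[int] = None,
    gen_train_timesteps: Optional[int] = None,
    gen_replay_buffer_capacity: Optional[int] = None,
    default_gen_train_timesteps: int = 2048,  # assumed stand-in for the algorithm-derived default
) -> Tuple[int, int, int]:
    """Hypothetical sketch of the documented defaults and the ValueError check."""
    # demo_minibatch_size defaults to demo_batch_size (no gradient accumulation).
    minibatch = demo_batch_size if demo_minibatch_size is None else demo_minibatch_size
    if demo_batch_size % minibatch != 0:
        raise ValueError("Batch size must be a multiple of minibatch size.")
    timesteps = (
        default_gen_train_timesteps if gen_train_timesteps is None else gen_train_timesteps
    )
    # By default the buffer only holds the most recent batch of generator samples.
    capacity = timesteps if gen_replay_buffer_capacity is None else gen_replay_buffer_capacity
    return minibatch, timesteps, capacity


print(resolve_batch_settings(1024))
print(resolve_batch_settings(1024, demo_minibatch_size=256, gen_replay_buffer_capacity=4096))
```

Gradients over a 1024-sample batch with demo_minibatch_size=256 would then be accumulated over four minibatches before each optimizer step, while a minibatch size such as 300 would be rejected.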

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.

abstract logits_expert_is_high(state, action, next_state, done, log_policy_act_prob=None)

Compute the discriminator’s logits for each state-action sample.

A high value corresponds to predicting expert, and a low value corresponds to predicting generator.

Parameters

  • state (Tensor) – state at time t, of shape (batch_size,) + state_shape.

  • action (Tensor) – action taken at time t, of shape (batch_size,) + action_shape.

  • next_state (Tensor) – state at time t+1, of shape (batch_size,) + state_shape.

  • done (Tensor) – binary episode completion flag after action at time t, of shape (batch_size,).

  • log_policy_act_prob (Optional[Tensor]) – log probability of generator policy taking action at time t.

Return type

Tensor

Returns

Discriminator logits of shape (batch_size,). A high output indicates an expert-like transition.
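Because log_policy_act_prob is optional in this abstract signature, subclasses may or may not use it. The sketch below, with NumPy arrays standing in for tensors, contrasts a subclass that returns the reward output directly as the logit with one that follows the AIRL formula \(r_{\theta}(s,a) - \log \pi(a|s)\). All class and function names are hypothetical, and the first variant is only an assumption about how a discriminator without a policy term could look:

```python
import abc
from typing import Optional

import numpy as np


class DiscriminatorSketch(abc.ABC):
    """Hypothetical stand-in for the logits_expert_is_high contract."""

    @abc.abstractmethod
    def logits_expert_is_high(
        self, state, action, next_state, done, log_policy_act_prob: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Return logits of shape (batch_size,); high means expert-like."""


class RewardAsLogits(DiscriminatorSketch):
    """Uses the reward output directly as the logit and ignores log pi."""

    def __init__(self, reward_fn):
        self.reward_fn = reward_fn

    def logits_expert_is_high(self, state, action, next_state, done, log_policy_act_prob=None):
        return self.reward_fn(state, action, next_state, done)


class RewardMinusLogPi(DiscriminatorSketch):
    """AIRL-style logits: reward minus log pi(a|s); log pi is required."""

    def __init__(self, reward_fn):
        self.reward_fn = reward_fn

    def logits_expert_is_high(self, state, action, next_state, done, log_policy_act_prob=None):
        if log_policy_act_prob is None:
            raise TypeError("log_policy_act_prob is required")
        return self.reward_fn(state, action, next_state, done) - log_policy_act_prob


def toy_reward(state, action, next_state, done):
    # Hypothetical reward: one value per sample, shape (batch_size,).
    return state.sum(axis=1) - action


batch_size, state_dim = 3, 2
state = np.ones((batch_size, state_dim))
next_state = np.zeros((batch_size, state_dim))
action = np.array([0.0, 1.0, 2.0])
done = np.array([False, False, True])
log_pi = np.log(np.full(batch_size, 0.5))

for disc in (RewardAsLogits(toy_reward), RewardMinusLogPi(toy_reward)):
    logits = disc.logits_expert_is_high(state, action, next_state, done, log_pi)
    print(type(disc).__name__, logits.shape)
```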

property policy: BasePolicy

Returns a policy imitating the demonstration data.

Return type

BasePolicy

abstract property reward_test: RewardNet

Reward used to train policy at “test” time after adversarial training.

Return type

RewardNet

abstract property reward_train: RewardNet

Reward used to train generator policy.

Return type

RewardNet


Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.


Parameters

demonstrations (Union[Iterable[Trajectory], Iterable[Mapping[str, Union[ndarray, Tensor]]], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, TransitionKind instance, or a Sequence of Trajectory objects.

Return type

None

train(total_timesteps, callback=None)

Alternates between training the generator and discriminator.

Every “round” consists of a call to train_gen(self.gen_train_timesteps), n_disc_updates_per_round calls to train_disc, and finally a call to callback(round).

Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.

Parameters

  • total_timesteps (int) – An upper bound on the number of transitions to sample from the environment during training.

  • callback (Optional[Callable[[int], None]]) – A function called at the end of every round which takes in a single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).

Return type

None

train_disc(*, expert_samples=None, gen_samples=None)

Perform a single discriminator update, optionally using provided samples.

Parameters

  • expert_samples (Optional[Mapping]) – Transition samples from the expert in dictionary form. If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.

  • gen_samples (Optional[Mapping]) – Transition samples from the generator policy in same dictionary form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then take len(expert_samples) samples from the generator replay buffer.

Return type

Mapping[str, float]


Returns

Statistics for the discriminator (e.g. loss, accuracy).

train_gen(total_timesteps=None, learn_kwargs=None)

Trains the generator to maximize the discriminator loss.

After training ends, the generator replay buffer (used in discriminator training) is populated with self.disc_batch_size transitions.

Parameters

  • total_timesteps (Optional[int]) – The number of transitions to sample from self.venv_train during training. By default, self.gen_train_timesteps.

  • learn_kwargs (Optional[Mapping]) – kwargs for the Stable Baselines RLModel.learn() method.

Return type

None

venv: VecEnv

The original vectorized environment.

venv_train: VecEnv

Like self.venv, but wrapped with train reward unless in debug mode.

If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.

venv_wrapped: VecEnvWrapper