Source code for imitation.algorithms.adversarial.common

"""Core code for adversarial imitation learning, shared between GAIL and AIRL."""
import abc
import dataclasses
import logging
from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload

import numpy as np
import torch as th
import torch.utils.tensorboard as thboard
import tqdm
from stable_baselines3.common import (
    base_class,
    distributions,
    on_policy_algorithm,
    policies,
    vec_env,
)
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F

from imitation.algorithms import base
from imitation.data import buffer, rollout, types, wrappers
from imitation.rewards import reward_nets, reward_wrapper
from imitation.util import logger, networks, util


def compute_train_stats(
    disc_logits_expert_is_high: th.Tensor,
    labels_expert_is_one: th.Tensor,
    disc_loss: th.Tensor,
) -> Mapping[str, float]:
    """Train statistics for GAIL/AIRL discriminator.

    Args:
        disc_logits_expert_is_high: discriminator logits produced by
            `AdversarialTrainer.logits_expert_is_high`.
        labels_expert_is_one: integer labels describing whether logit was for an
            expert (1) or generator (0) sample.
        disc_loss: final discriminator loss.

    Returns:
        A mapping from statistic names to float values.
    """
    with th.no_grad():
        # Logits of the discriminator output; >0 for expert samples, <0 for generator.
        bin_is_generated_pred = disc_logits_expert_is_high < 0
        # Binary label, so 1 is for expert, 0 is for generator.
        bin_is_generated_true = labels_expert_is_one == 0
        bin_is_expert_true = th.logical_not(bin_is_generated_true)
        int_is_generated_pred = bin_is_generated_pred.long()
        int_is_generated_true = bin_is_generated_true.long()
        n_generated = float(th.sum(int_is_generated_true))
        n_labels = float(len(labels_expert_is_one))
        n_expert = n_labels - n_generated
        pct_expert = n_expert / float(n_labels) if n_labels > 0 else float("NaN")
        n_expert_pred = int(n_labels - th.sum(int_is_generated_pred))
        if n_labels > 0:
            pct_expert_pred = n_expert_pred / float(n_labels)
        else:
            pct_expert_pred = float("NaN")
        correct_vec = th.eq(bin_is_generated_pred, bin_is_generated_true)
        acc = th.mean(correct_vec.float())

        _n_pred_expert = th.sum(th.logical_and(bin_is_expert_true, correct_vec))
        if n_expert < 1:
            expert_acc = float("NaN")
        else:
            # float() is defensive, since we cannot divide Torch tensors by
            # Python ints
            expert_acc = _n_pred_expert.item() / float(n_expert)

        _n_pred_gen = th.sum(th.logical_and(bin_is_generated_true, correct_vec))
        _n_gen_or_1 = max(1, n_generated)
        generated_acc = _n_pred_gen / float(_n_gen_or_1)

        label_dist = th.distributions.Bernoulli(logits=disc_logits_expert_is_high)
        entropy = th.mean(label_dist.entropy())

    return {
        "disc_loss": float(th.mean(disc_loss)),
        "disc_acc": float(acc),
        "disc_acc_expert": float(expert_acc),  # accuracy on just expert examples
        "disc_acc_gen": float(generated_acc),  # accuracy on just generated examples
        # entropy of the predicted label distribution, averaged equally across
        # both classes (if this drops then disc is very good or has given up)
        "disc_entropy": float(entropy),
        # true number of expert demos and predicted number of expert demos
        "disc_proportion_expert_true": float(pct_expert),
        "disc_proportion_expert_pred": float(pct_expert_pred),
        "n_expert": float(n_expert),
        "n_generated": float(n_generated),
    }

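# --- Illustrative sketch (added for exposition; not part of the library) ---
# `_example_disc_stats` is a hypothetical helper showing how
# `compute_train_stats` consumes discriminator outputs: positive logits predict
# "expert", and labels follow the expert=1 / generator=0 convention used
# throughout this module. The dummy values below are assumptions for
# illustration only.
def _example_disc_stats() -> Mapping[str, float]:
    logits = th.tensor([2.3, 0.4, -1.7, -0.2])  # two expert-like, two generator-like
    labels = th.tensor([1, 1, 0, 0])  # 1 = expert, 0 = generator
    loss = F.binary_cross_entropy_with_logits(logits, labels.float())
    return compute_train_stats(logits, labels, loss)
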
class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
    """Base class for adversarial imitation learning algorithms like GAIL and AIRL."""

    venv: vec_env.VecEnv
    """The original vectorized environment."""

    venv_train: vec_env.VecEnv
    """Like `self.venv`, but wrapped with train reward unless in debug mode.

    If `debug_use_ground_truth=True` was passed into the initializer then
    `self.venv_train` is the same as `self.venv`."""

    _demo_data_loader: Optional[Iterable[types.TransitionMapping]]
    _endless_expert_iterator: Optional[Iterator[types.TransitionMapping]]

    venv_wrapped: vec_env.VecEnvWrapper

    def __init__(
        self,
        *,
        demonstrations: base.AnyTransitions,
        demo_batch_size: int,
        venv: vec_env.VecEnv,
        gen_algo: base_class.BaseAlgorithm,
        reward_net: reward_nets.RewardNet,
        demo_minibatch_size: Optional[int] = None,
        n_disc_updates_per_round: int = 2,
        log_dir: types.AnyPath = "output/",
        disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
        disc_opt_kwargs: Optional[Mapping] = None,
        gen_train_timesteps: Optional[int] = None,
        gen_replay_buffer_capacity: Optional[int] = None,
        custom_logger: Optional[logger.HierarchicalLogger] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
        allow_variable_horizon: bool = False,
    ):
        """Builds AdversarialTrainer.

        Args:
            demonstrations: Demonstrations from an expert (optional). Transitions
                expressed directly as a `types.TransitionsMinimal` object, a sequence
                of trajectories, or an iterable of transition batches (mappings from
                keywords to arrays containing observations, etc).
            demo_batch_size: The number of samples in each batch of expert data. The
                discriminator batch size is twice this number because each
                discriminator batch contains a generator sample for every expert
                sample.
            venv: The vectorized environment to train in.
            gen_algo: The generator RL algorithm that is trained to maximize
                discriminator confusion. Environment and logger will be set to
                `venv` and `custom_logger`.
            reward_net: a Torch module that takes an observation, action and next
                observation tensors as input and computes a reward signal.
            demo_minibatch_size: size of minibatch to calculate gradients over.
                The gradients are accumulated until the entire batch is processed
                before making an optimization step. This is useful in GPU training
                to reduce memory usage, since fewer examples are loaded into memory
                at once, facilitating training with larger batch sizes, but is
                generally slower. Must be a factor of `demo_batch_size`.
                Optional, defaults to `demo_batch_size`.
            n_disc_updates_per_round: The number of discriminator updates after each
                round of generator updates in AdversarialTrainer.learn().
            log_dir: Directory to store TensorBoard logs, plots, etc. in.
            disc_opt_cls: The optimizer for discriminator training.
            disc_opt_kwargs: Parameters for discriminator training.
            gen_train_timesteps: The number of steps to train the generator policy for
                each iteration. If None, then defaults to the batch size (for
                on-policy) or number of environments (for off-policy).
            gen_replay_buffer_capacity: The capacity of the generator replay buffer
                (the number of obs-action-obs samples from the generator that can be
                stored). By default this is equal to `gen_train_timesteps`, meaning
                that we sample only from the most recent batch of generator samples.
            custom_logger: Where to log to; if None (default), creates a new logger.
            init_tensorboard: If True, makes various discriminator TensorBoard
                summaries.
            init_tensorboard_graph: If both this and `init_tensorboard` are True,
                then write a Tensorboard graph summary to disk.
            debug_use_ground_truth: If True, use the ground truth reward for
                `self.venv_train`. This disables the reward wrapping that would
                normally replace the environment reward with the learned reward.
                This is useful for sanity checking that the policy training is
                functional.
            allow_variable_horizon: If False (default), algorithm will raise an
                exception if it detects trajectories of different length during
                training. If True, overrides this safety check. WARNING: variable
                horizon episodes leak information about the reward via termination
                condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html
                before overriding this.

        Raises:
            ValueError: if the batch size is not a multiple of the minibatch size.
        """
        self.demo_batch_size = demo_batch_size
        self.demo_minibatch_size = demo_minibatch_size or demo_batch_size
        if self.demo_batch_size % self.demo_minibatch_size != 0:
            raise ValueError("Batch size must be a multiple of minibatch size.")
        self._demo_data_loader = None
        self._endless_expert_iterator = None
        super().__init__(
            demonstrations=demonstrations,
            custom_logger=custom_logger,
            allow_variable_horizon=allow_variable_horizon,
        )

        self._global_step = 0
        self._disc_step = 0
        self.n_disc_updates_per_round = n_disc_updates_per_round

        self.debug_use_ground_truth = debug_use_ground_truth
        self.venv = venv
        self.gen_algo = gen_algo
        self._reward_net = reward_net.to(gen_algo.device)
        self._log_dir = util.parse_path(log_dir)

        # Create graph for optimising/recording stats on discriminator
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs or {}
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._disc_opt = self._disc_opt_cls(
            self._reward_net.parameters(),
            **self._disc_opt_kwargs,
        )

        if self._init_tensorboard:
            logging.info(f"building summary directory at {self._log_dir}")
            summary_dir = self._log_dir / "summary"
            summary_dir.mkdir(parents=True, exist_ok=True)
            self._summary_writer = thboard.SummaryWriter(str(summary_dir))

        self.venv_buffering = wrappers.BufferingWrapper(self.venv)

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see rewards.
            self.venv_wrapped = self.venv_buffering
            self.gen_callback = None
        else:
            self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
                self.venv_buffering,
                reward_fn=self.reward_train.predict_processed,
            )
            self.gen_callback = self.venv_wrapped.make_log_callback()
        self.venv_train = self.venv_wrapped

        self.gen_algo.set_env(self.venv_train)
        self.gen_algo.set_logger(self.logger)

        if gen_train_timesteps is None:
            gen_algo_env = self.gen_algo.get_env()
            assert gen_algo_env is not None
            self.gen_train_timesteps = gen_algo_env.num_envs
            if isinstance(self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm):
                self.gen_train_timesteps *= self.gen_algo.n_steps
        else:
            self.gen_train_timesteps = gen_train_timesteps

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = self.gen_train_timesteps
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity,
            self.venv,
        )

    @property
    def policy(self) -> policies.BasePolicy:
        policy = self.gen_algo.policy
        assert policy is not None
        return policy

    @abc.abstractmethod
    def logits_expert_is_high(
        self,
        state: th.Tensor,
        action: th.Tensor,
        next_state: th.Tensor,
        done: th.Tensor,
        log_policy_act_prob: Optional[th.Tensor] = None,
    ) -> th.Tensor:
        """Compute the discriminator's logits for each state-action sample.

        A high value corresponds to predicting expert, and a low value corresponds to
        predicting generator.

        Args:
            state: state at time t, of shape `(batch_size,) + state_shape`.
            action: action taken at time t, of shape `(batch_size,) + action_shape`.
            next_state: state at time t+1, of shape `(batch_size,) + state_shape`.
            done: binary episode completion flag after action at time t,
                of shape `(batch_size,)`.
            log_policy_act_prob: log probability of generator policy taking
                `action` at time t.

        Returns:
            Discriminator logits of shape `(batch_size,)`. A high output indicates an
            expert-like transition.
        """  # noqa: DAR202

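    # Note (hedged summary of the concrete subclasses; see the gail/airl modules
    # for the authoritative definitions): GAIL implements `logits_expert_is_high`
    # as the raw output of its discriminator/reward network, while AIRL returns
    # the reward-net output minus `log_policy_act_prob` (i.e. f(s, a, s') -
    # log pi(a|s)) and therefore requires `log_policy_act_prob` to be provided.
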
    @property
    @abc.abstractmethod
    def reward_train(self) -> reward_nets.RewardNet:
        """Reward used to train generator policy."""

    @property
    @abc.abstractmethod
    def reward_test(self) -> reward_nets.RewardNet:
        """Reward used to train policy at "test" time after adversarial training."""

    def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
        self._demo_data_loader = base.make_data_loader(
            demonstrations,
            self.demo_batch_size,
        )
        self._endless_expert_iterator = util.endless_iter(self._demo_data_loader)

    def _next_expert_batch(self) -> Mapping:
        assert self._endless_expert_iterator is not None
        return next(self._endless_expert_iterator)

    def train_disc(
        self,
        *,
        expert_samples: Optional[Mapping] = None,
        gen_samples: Optional[Mapping] = None,
    ) -> Mapping[str, float]:
        """Perform a single discriminator update, optionally using provided samples.

        Args:
            expert_samples: Transition samples from the expert in dictionary form.
                If provided, must contain keys corresponding to every field of the
                `Transitions` dataclass except "infos". All corresponding values can
                be either NumPy arrays or Tensors. Extra keys are ignored. Must
                contain `self.demo_batch_size` samples. If this argument is not
                provided, then `self.demo_batch_size` expert samples from
                `self._demo_data_loader` are used by default.
            gen_samples: Transition samples from the generator policy in same
                dictionary form as `expert_samples`. If provided, must contain exactly
                `self.demo_batch_size` samples. If not provided, then take
                `len(expert_samples)` samples from the generator replay buffer.

        Returns:
            Statistics for discriminator (e.g. loss, accuracy).
        """
        with self.logger.accumulate_means("disc"):
            # optionally write TB summaries for collected ops
            write_summaries = self._init_tensorboard and self._global_step % 20 == 0

            # compute loss
            self._disc_opt.zero_grad()

            batch_iter = self._make_disc_train_batches(
                gen_samples=gen_samples,
                expert_samples=expert_samples,
            )

            for batch in batch_iter:
                disc_logits = self.logits_expert_is_high(
                    batch["state"],
                    batch["action"],
                    batch["next_state"],
                    batch["done"],
                    batch["log_policy_act_prob"],
                )
                loss = F.binary_cross_entropy_with_logits(
                    disc_logits,
                    batch["labels_expert_is_one"].float(),
                )

                # Renormalise the loss to be averaged over the whole
                # batch size instead of the minibatch size.
                assert len(batch["state"]) == 2 * self.demo_minibatch_size
                loss *= self.demo_minibatch_size / self.demo_batch_size
                loss.backward()

            # do gradient step
            self._disc_opt.step()
            self._disc_step += 1

            # compute/write stats and TensorBoard data
            with th.no_grad():
                train_stats = compute_train_stats(
                    disc_logits,
                    batch["labels_expert_is_one"],
                    loss,
                )
            self.logger.record("global_step", self._global_step)
            for k, v in train_stats.items():
                self.logger.record(k, v)
            self.logger.dump(self._disc_step)
            if write_summaries:
                self._summary_writer.add_histogram("disc_logits", disc_logits.detach())

        return train_stats

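    # Worked example of the minibatch rescaling in `train_disc` above (numbers
    # are illustrative): with demo_batch_size=1024 and demo_minibatch_size=256,
    # four backward passes are accumulated per optimizer step; each minibatch
    # loss (a mean over 2 * 256 samples) is scaled by 256 / 1024, so the summed
    # gradients equal those of the mean loss over the full 2 * 1024-sample
    # discriminator batch.
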
    def train_gen(
        self,
        total_timesteps: Optional[int] = None,
        learn_kwargs: Optional[Mapping] = None,
    ) -> None:
        """Trains the generator to maximize the discriminator loss.

        After the end of training, populates the generator replay buffer (used in
        discriminator training) with the newly collected transitions
        (`self.gen_train_timesteps` of them by default).

        Args:
            total_timesteps: The number of transitions to sample from
                `self.venv_train` during training. By default,
                `self.gen_train_timesteps`.
            learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()` method.
        """
        if total_timesteps is None:
            total_timesteps = self.gen_train_timesteps
        if learn_kwargs is None:
            learn_kwargs = {}

        with self.logger.accumulate_means("gen"):
            self.gen_algo.learn(
                total_timesteps=total_timesteps,
                reset_num_timesteps=False,
                callback=self.gen_callback,
                **learn_kwargs,
            )
            self._global_step += 1

        gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()
        self._check_fixed_horizon(ep_lens)
        gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
        self._gen_replay_buffer.store(gen_samples)

    def train(
        self,
        total_timesteps: int,
        callback: Optional[Callable[[int], None]] = None,
    ) -> None:
        """Alternates between training the generator and discriminator.

        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
        `self.n_disc_updates_per_round` calls to `train_disc`, and finally a call to
        `callback(round)`.

        Training ends once an additional "round" would cause the number of transitions
        sampled from the environment to exceed `total_timesteps`.

        Args:
            total_timesteps: An upper bound on the number of transitions to sample
                from the environment during training.
            callback: A function called at the end of every round which takes in a
                single argument, the round number. Round numbers are in
                `range(total_timesteps // self.gen_train_timesteps)`.
        """
        n_rounds = total_timesteps // self.gen_train_timesteps
        assert n_rounds >= 1, (
            "No updates (need at least "
            f"{self.gen_train_timesteps} timesteps, have only "
            f"total_timesteps={total_timesteps})!"
        )
        for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
            self.train_gen(self.gen_train_timesteps)
            for _ in range(self.n_disc_updates_per_round):
                with networks.training(self.reward_train):
                    # switch to training mode (affects dropout, normalization)
                    self.train_disc()
            if callback:
                callback(r)
            self.logger.dump(self._global_step)

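    # Round-count arithmetic for `train` above, with assumed example numbers:
    # if gen_train_timesteps=2048 (e.g. PPO with n_steps=512 across 4 parallel
    # environments) and total_timesteps=10_000, then n_rounds = 10_000 // 2048 = 4,
    # so training samples 4 * 2048 = 8192 environment transitions in total.
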
    @overload
    def _torchify_array(self, ndarray: np.ndarray) -> th.Tensor:
        ...

    @overload
    def _torchify_array(self, ndarray: None) -> None:
        ...

    def _torchify_array(self, ndarray: Optional[np.ndarray]) -> Optional[th.Tensor]:
        if ndarray is not None:
            return th.as_tensor(ndarray, device=self.reward_train.device)
        return None

    def _get_log_policy_act_prob(
        self,
        obs_th: th.Tensor,
        acts_th: th.Tensor,
    ) -> Optional[th.Tensor]:
        """Evaluates the given actions on the given observations.

        Args:
            obs_th: A batch of observations.
            acts_th: A batch of actions.

        Returns:
            A batch of log policy action probabilities.
        """
        if isinstance(self.policy, policies.ActorCriticPolicy):
            # policies.ActorCriticPolicy has a concrete implementation of
            # evaluate_actions to generate log_policy_act_prob given obs and actions.
            _, log_policy_act_prob_th, _ = self.policy.evaluate_actions(
                obs_th,
                acts_th,
            )
        elif isinstance(self.policy, sac_policies.SACPolicy):
            gen_algo_actor = self.policy.actor
            assert gen_algo_actor is not None
            # generate log_policy_act_prob from SAC actor.
            mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th)
            assert isinstance(
                gen_algo_actor.action_dist,
                distributions.SquashedDiagGaussianDistribution,
            )  # Note: this is just a hint to mypy
            distribution = gen_algo_actor.action_dist.proba_distribution(
                mean_actions,
                log_std,
            )
            # SAC applies a squashing function to bound the actions to a finite range;
            # `acts_th` needs to be scaled accordingly before computing log prob.
            # Scale actions only if the policy squashes outputs.
            assert self.policy.squash_output
            scaled_acts = self.policy.scale_action(acts_th.numpy(force=True))
            scaled_acts_th = th.as_tensor(scaled_acts, device=mean_actions.device)
            log_policy_act_prob_th = distribution.log_prob(scaled_acts_th)
        else:
            return None
        return log_policy_act_prob_th

    def _make_disc_train_batches(
        self,
        *,
        gen_samples: Optional[Mapping] = None,
        expert_samples: Optional[Mapping] = None,
    ) -> Iterator[Mapping[str, th.Tensor]]:
        """Build and return training minibatches for the next discriminator update.

        Args:
            gen_samples: Same as in `train_disc`.
            expert_samples: Same as in `train_disc`.

        Yields:
            The training minibatch: state, action, next state, dones, labels
            and policy log-probabilities.

        Raises:
            RuntimeError: Empty generator replay buffer.
            ValueError: `gen_samples` or `expert_samples` batch size is different from
                `self.demo_batch_size`.
        """
        batch_size = self.demo_batch_size

        if expert_samples is None:
            expert_samples = self._next_expert_batch()

        if gen_samples is None:
            if self._gen_replay_buffer.size() == 0:
                raise RuntimeError(
                    "No generator samples for training. "
                    "Call `train_gen()` first.",
                )
            gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
            gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)

        if not (len(gen_samples["obs"]) == len(expert_samples["obs"]) == batch_size):
            raise ValueError(
                "Need to have exactly `demo_batch_size` number of expert and "
                "generator samples, each. "
                f"(n_gen={len(gen_samples['obs'])} "
                f"n_expert={len(expert_samples['obs'])} "
                f"demo_batch_size={batch_size})",
            )

        # Guarantee that Mapping arguments are in mutable form.
        expert_samples = dict(expert_samples)
        gen_samples = dict(gen_samples)

        # Convert applicable Tensor values to NumPy.
        for field in dataclasses.fields(types.Transitions):
            k = field.name
            if k == "infos":
                continue
            for d in [gen_samples, expert_samples]:
                if isinstance(d[k], th.Tensor):
                    d[k] = d[k].detach().numpy()
        assert isinstance(gen_samples["obs"], np.ndarray)
        assert isinstance(expert_samples["obs"], np.ndarray)

        # Check dimensions.
        assert batch_size == len(expert_samples["acts"])
        assert batch_size == len(expert_samples["next_obs"])
        assert batch_size == len(gen_samples["acts"])
        assert batch_size == len(gen_samples["next_obs"])

        for start in range(0, batch_size, self.demo_minibatch_size):
            end = start + self.demo_minibatch_size
            # take minibatch slice (this creates views so no memory issues)
            expert_batch = {k: v[start:end] for k, v in expert_samples.items()}
            gen_batch = {k: v[start:end] for k, v in gen_samples.items()}

            # Concatenate rollouts, and label each row as expert or generator.
            obs = np.concatenate([expert_batch["obs"], gen_batch["obs"]])
            acts = np.concatenate([expert_batch["acts"], gen_batch["acts"]])
            next_obs = np.concatenate([expert_batch["next_obs"], gen_batch["next_obs"]])
            dones = np.concatenate([expert_batch["dones"], gen_batch["dones"]])
            # notice that the labels use the convention that expert samples are
            # labelled with 1 and generator samples with 0.
            labels_expert_is_one = np.concatenate(
                [
                    np.ones(self.demo_minibatch_size, dtype=int),
                    np.zeros(self.demo_minibatch_size, dtype=int),
                ],
            )

            # Calculate generator-policy log probabilities.
            with th.no_grad():
                obs_th = th.as_tensor(obs, device=self.gen_algo.device)
                acts_th = th.as_tensor(acts, device=self.gen_algo.device)
                log_policy_act_prob = self._get_log_policy_act_prob(obs_th, acts_th)
                if log_policy_act_prob is not None:
                    assert len(log_policy_act_prob) == 2 * self.demo_minibatch_size
                    log_policy_act_prob = log_policy_act_prob.reshape(
                        (2 * self.demo_minibatch_size,),
                    )
                del obs_th, acts_th  # unneeded

            obs_th, acts_th, next_obs_th, dones_th = self.reward_train.preprocess(
                obs,
                acts,
                next_obs,
                dones,
            )
            batch_dict = {
                "state": obs_th,
                "action": acts_th,
                "next_state": next_obs_th,
                "done": dones_th,
                "labels_expert_is_one": self._torchify_array(labels_expert_is_one),
                "log_policy_act_prob": log_policy_act_prob,
            }

            yield batch_dict

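
# --- Illustrative sketch (added for exposition; not part of the library) ---
# `_example_gail_training` is a hypothetical end-to-end driver for a concrete
# subclass (GAIL) of `AdversarialTrainer`. The environment, hyperparameters,
# and the use of an *untrained* PPO policy as a stand-in "expert" are
# assumptions made purely so the sketch runs; substitute real expert
# trajectories and tuned settings in practice. Assumes recent versions of
# `imitation` and `stable-baselines3`.
def _example_gail_training() -> None:
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    from imitation.algorithms.adversarial.gail import GAIL
    from imitation.rewards.reward_nets import BasicRewardNet

    rng = np.random.default_rng(0)
    venv = make_vec_env("CartPole-v1", n_envs=4)
    reward_net = BasicRewardNet(venv.observation_space, venv.action_space)

    # Stand-in "expert": an untrained policy, used only so this sketch is
    # self-contained. Real usage would load expert demonstrations instead.
    expert = PPO("MlpPolicy", venv)
    demos = rollout.rollout(
        expert,
        venv,
        rollout.make_sample_until(min_timesteps=2048),
        rng=rng,
        unwrap=False,  # no RolloutInfoWrapper is applied in this sketch
    )

    trainer = GAIL(
        demonstrations=demos,
        demo_batch_size=512,
        venv=venv,
        gen_algo=PPO("MlpPolicy", venv, n_steps=512),
        reward_net=reward_net,
        n_disc_updates_per_round=2,
        # CartPole-v1 episodes have variable length; see the warning in
        # AdversarialTrainer.__init__ before relying on this in real runs.
        allow_variable_horizon=True,
    )
    # Alternates `train_gen` and `train_disc` as implemented above.
    trainer.train(total_timesteps=8192)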