imitation.scripts.train_adversarial

Train GAIL or AIRL.

Functions

airl()

gail()

main_console()

save(trainer, save_path)

Save discriminator and generator.

train_adversarial(_run, show_config, ...)

Train an adversarial-network-based imitation learning algorithm.

imitation.scripts.train_adversarial.airl()
imitation.scripts.train_adversarial.gail()
imitation.scripts.train_adversarial.main_console()
imitation.scripts.train_adversarial.save(trainer, save_path)

Save discriminator and generator.

imitation.scripts.train_adversarial.train_adversarial(_run, show_config, algo_cls, algorithm_kwargs, total_timesteps, checkpoint_interval, agent_path)

Train an adversarial-network-based imitation learning algorithm.

Checkpoints:
  • AdversarialTrainer train and test RewardNets are saved to
    f"{log_dir}/checkpoints/{step}/reward_{train,test}.pt",
    where step is either the training round or “final”.

  • Generator policies are saved to f"{log_dir}/checkpoints/{step}/gen_policy/".
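
The checkpoint layout described above can be sketched as a small path helper. This is only an illustration: checkpoint_paths is a hypothetical name, not part of imitation, and the way the training round is formatted inside the directory name is an assumption (here it is simply str(step)).

```python
from pathlib import Path
from typing import Dict, Union


def checkpoint_paths(log_dir: Union[str, Path], step: Union[int, str]) -> Dict[str, Path]:
    """Hypothetical helper: build the documented checkpoint paths for one step.

    `step` is either a training round or the string "final". How the round
    number is rendered in the real directory name is an assumption here.
    """
    base = Path(log_dir) / "checkpoints" / str(step)
    return {
        "reward_train": base / "reward_train.pt",  # train RewardNet
        "reward_test": base / "reward_test.pt",    # test RewardNet
        "gen_policy": base / "gen_policy",         # generator policy directory
    }


paths = checkpoint_paths("output/run_0", "final")
for name, path in paths.items():
    print(name, path.as_posix())
```

A helper like this is mainly useful in post-processing scripts that need to locate the “final” reward networks and generator policy of a finished run.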

Parameters
  • show_config (bool) – Print the merged config before starting training. This is analogous to the print_config command, but shows the config after, rather than before, merging the algorithm_specific arguments.

  • algo_cls (Type[AdversarialTrainer]) – The adversarial imitation learning algorithm to use.

  • algorithm_kwargs (Mapping[str, Any]) – Keyword arguments for the GAIL or AIRL constructor.

  • total_timesteps (int) – The number of transitions to sample from the environment during training.

  • checkpoint_interval (int) – Save the discriminator and generator models every checkpoint_interval rounds and after training is complete. If 0, then only save weights after training is complete. If <0, then don’t save weights at all.

  • agent_path (Optional[str]) – Path to a directory containing a pre-trained agent. If provided, then the agent will be initialized using this stored policy (warm start). If not provided, then the agent will be initialized using a random policy.
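
The three regimes of checkpoint_interval (positive, zero, negative) can be summarized in a small predicate. This is a sketch, not the script's actual code: should_save is a hypothetical name, and interpreting “every checkpoint_interval rounds” as round_num % checkpoint_interval == 0 is an assumption.

```python
def should_save(round_num: int, checkpoint_interval: int, is_final: bool = False) -> bool:
    """Hypothetical predicate mirroring the documented checkpoint_interval rules."""
    if checkpoint_interval < 0:
        # Negative interval: never save weights, not even after training.
        return False
    if is_final:
        # Zero or positive interval: always save once training is complete.
        return True
    # Positive interval: save periodically (assumed to mean "round divisible by interval").
    # Zero interval: no intermediate checkpoints.
    return checkpoint_interval > 0 and round_num % checkpoint_interval == 0


for interval in (5, 0, -1):
    periodic = [r for r in range(1, 11) if should_save(r, interval)]
    final = should_save(10, interval, is_final=True)
    print(f"interval={interval}: periodic rounds={periodic}, final saved={final}")
```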

Return type

Mapping[str, Mapping[str, float]]

Returns

A dictionary with two keys. “imit_stats” gives the return value of rollout_stats() on rollouts from the test-reward-wrapped environment, using the final policy (remember that the ground-truth reward can be recovered from the “monitor_return” key). “expert_stats” gives the return value of rollout_stats() on the expert demonstrations.
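
As a rough illustration of how the returned mapping might be consumed, the sketch below pulls out the “monitor_return” entries from “imit_stats”. The helper name ground_truth_entries, the exact statistic names (such as “monitor_return_mean”), and all numeric values are placeholders chosen for the example, not output from a real run.

```python
from typing import Dict, Mapping


def ground_truth_entries(result: Mapping[str, Mapping[str, float]]) -> Dict[str, float]:
    """Hypothetical helper: keep the imit_stats entries that mention 'monitor_return'."""
    return {k: v for k, v in result["imit_stats"].items() if "monitor_return" in k}


# Placeholder data shaped like Mapping[str, Mapping[str, float]].
# Key names and numbers are assumptions for illustration only.
example_result = {
    "imit_stats": {"return_mean": 1.5, "monitor_return_mean": 42.0},
    "expert_stats": {"return_mean": 50.0},
}

print(ground_truth_entries(example_result))
print(sorted(example_result.keys()))
```

Comparing the “monitor_return” entries of “imit_stats” with the corresponding entries of “expert_stats” is one simple way to judge how close the learned policy is to the demonstrations.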