imitation.scripts.train_adversarial
Train GAIL or AIRL.
Functions
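| save(trainer, save_path) | Save discriminator and generator. |
| train_adversarial(_run, show_config, algo_cls, algorithm_kwargs, total_timesteps, checkpoint_interval, agent_path) | Train an adversarial-network-based imitation learning algorithm. |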
- imitation.scripts.train_adversarial.save(trainer, save_path)
Save discriminator and generator.
- imitation.scripts.train_adversarial.train_adversarial(_run, show_config, algo_cls, algorithm_kwargs, total_timesteps, checkpoint_interval, agent_path)
Train an adversarial-network-based imitation learning algorithm.
- Checkpoints:
  - AdversarialTrainer train and test RewardNets are saved to f"{log_dir}/checkpoints/{step}/reward_{train,test}.pt" (that is, reward_train.pt and reward_test.pt), where step is either the training round or “final”.
  - Generator policies are saved to f"{log_dir}/checkpoints/{step}/gen_policy/".
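The checkpoint layout above can be illustrated with a small path-building sketch. The helper `checkpoint_paths` and the example log directory are hypothetical and not part of imitation; the sketch also assumes a round number is rendered with plain `str(step)`, which may differ from how the script actually names intermediate step directories.

```python
from pathlib import Path
from typing import Dict, Union


def checkpoint_paths(log_dir: Union[str, Path], step: Union[int, str]) -> Dict[str, Path]:
    """Hypothetical helper listing the files documented for one checkpoint step.

    `step` is a training round or the string "final". Rendering the round as
    str(step) is an assumption made for illustration only.
    """
    step_dir = Path(log_dir) / "checkpoints" / str(step)
    return {
        # The pattern "reward_{train,test}.pt" stands for these two files.
        "reward_train": step_dir / "reward_train.pt",
        "reward_test": step_dir / "reward_test.pt",
        # The generator policy is stored as a directory, not a single file.
        "gen_policy": step_dir / "gen_policy",
    }


for name, path in checkpoint_paths("output/run1", "final").items():
    print(name, path.as_posix())
# reward_train output/run1/checkpoints/final/reward_train.pt
# reward_test output/run1/checkpoints/final/reward_test.pt
# gen_policy output/run1/checkpoints/final/gen_policy
```

Note that `{train,test}` in the documented pattern is shell-style brace notation for two files, not a valid f-string placeholder, which is why the sketch spells both names out.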
- Parameters
show_config (bool) – Print the merged config before starting training. This is analogous to the print_config command, but will show config after rather than before merging algorithm_specific arguments.
algo_cls (Type[AdversarialTrainer]) – The adversarial imitation learning algorithm to use.
algorithm_kwargs (Mapping[str, Any]) – Keyword arguments for the GAIL or AIRL constructor.
total_timesteps (int) – The number of transitions to sample from the environment during training.
checkpoint_interval (int) – Save the discriminator and generator models every checkpoint_interval rounds and after training is complete. If 0, then only save weights after training is complete. If <0, then don’t save weights at all.
agent_path (Optional[str]) – Path to a directory containing a pre-trained agent. If provided, then the agent will be initialized using this stored policy (warm start). If not provided, then the agent will be initialized using a random policy.
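The three checkpoint_interval cases can be made concrete with a small schedule function. `checkpoint_steps` is a hypothetical helper that only mirrors the rules stated for checkpoint_interval; it assumes rounds are numbered from 0 and that a checkpoint is taken whenever the round number is a multiple of checkpoint_interval, which is an assumption about the script rather than documented behavior.

```python
from typing import List, Union


def checkpoint_steps(n_rounds: int, checkpoint_interval: int) -> List[Union[int, str]]:
    """Return the checkpoint steps written for a run of n_rounds rounds.

    Hypothetical helper; assumes rounds are numbered 0..n_rounds-1.
    """
    if checkpoint_interval < 0:
        # Negative interval: never save weights.
        return []
    steps: List[Union[int, str]] = []
    if checkpoint_interval > 0:
        # Positive interval: save every checkpoint_interval rounds.
        steps.extend(r for r in range(n_rounds) if r % checkpoint_interval == 0)
    # Zero or positive interval: always save once more after training completes.
    steps.append("final")
    return steps


print(checkpoint_steps(10, 3))   # [0, 3, 6, 9, 'final']
print(checkpoint_steps(10, 0))   # ['final']
print(checkpoint_steps(10, -1))  # []
```

Treating 0 and negative values as special cases keeps a single integer parameter expressive enough for "periodic", "final only", and "never".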
- Return type
Mapping[str, Mapping[str, float]]
- Returns
A dictionary with two keys. “imit_stats” gives the return value of rollout_stats() on rollouts of the final policy in the test-reward-wrapped environment (remember that the ground-truth reward can be recovered from the “monitor_return” key). “expert_stats” gives the return value of rollout_stats() on the expert demonstrations.
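One way to consume the returned mapping is to compare the imitator's ground-truth return with the expert's. The sketch below assumes the nested dictionaries contain keys named "monitor_return_mean" (imitator, ground truth) and "return_mean" (expert); these key names and the helper `normalized_imitation_score` are assumptions for illustration, not guaranteed by this page.

```python
from typing import Mapping


def normalized_imitation_score(
    stats: Mapping[str, Mapping[str, float]],
    imit_key: str = "monitor_return_mean",
    expert_key: str = "return_mean",
) -> float:
    """Ratio of the imitator's mean ground-truth return to the expert's.

    Hypothetical helper. The default key names are assumptions. The imitator
    is read from a "monitor_return" statistic because its rollouts run in the
    test-reward-wrapped environment, where "return" is the learned reward.
    """
    imit = stats["imit_stats"][imit_key]
    expert = stats["expert_stats"][expert_key]
    if expert == 0:
        raise ValueError("expert mean return is zero; ratio is undefined")
    return imit / expert
```

A plain ratio is only meaningful when both returns are positive; for environments with negative rewards, reporting the two means side by side is safer.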