imitation.scripts.train_preference_comparisons

Train a reward model using preference comparisons.

Can be used as a CLI script, or the train_preference_comparisons function can be called directly.
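
For illustration, the sketch below shows both invocation styles. It assumes the Sacred experiment object is exposed as train_preference_comparisons_ex in imitation.scripts.config.train_preference_comparisons and that the default configuration supplies a usable environment; the exact module layout and config keys can differ between imitation releases, so treat this as a sketch rather than a verified recipe.

    # Sketch only: the experiment location and config keys below are assumptions
    # to verify against the installed version of imitation.

    # 1) As a CLI script, using Sacred's `with key=value` override syntax. The
    #    parameters documented below (total_timesteps, total_comparisons, ...)
    #    are the config keys:
    #
    #    python -m imitation.scripts.train_preference_comparisons with \
    #        total_timesteps=100000 total_comparisons=500

    # 2) Programmatically, via the Sacred experiment object:
    from imitation.scripts.config.train_preference_comparisons import (
        train_preference_comparisons_ex,
    )

    run = train_preference_comparisons_ex.run(
        config_updates=dict(
            total_timesteps=100_000,
            total_comparisons=500,
        ),
    )
    print(run.result)  # rollout statistics returned by train_preference_comparisons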

Functions

main_console()

save_checkpoint(trainer, save_path, ...)
    Save reward model and optionally policy.

save_model(agent_trainer, save_path)
    Save the model as model.zip.

train_preference_comparisons(...)
    Train a reward model using preference comparisons.

imitation.scripts.train_preference_comparisons.main_console()

imitation.scripts.train_preference_comparisons.save_checkpoint(trainer, save_path, allow_save_policy)

Save reward model and optionally policy.

imitation.scripts.train_preference_comparisons.save_model(agent_trainer, save_path)

Save the model as model.zip.

imitation.scripts.train_preference_comparisons.train_preference_comparisons(total_timesteps, total_comparisons, num_iterations, comparison_queue_size, fragment_length, transition_oversampling, initial_comparison_frac, exploration_frac, trajectory_path, trajectory_generator_kwargs, save_preferences, agent_path, preference_model_kwargs, reward_trainer_kwargs, gatherer_cls, gatherer_kwargs, active_selection, active_selection_oversampling, uncertainty_on, fragmenter_kwargs, allow_variable_horizon, checkpoint_interval, query_schedule, _rnd)

Train a reward model using preference comparisons.

Parameters
  • total_timesteps (int) – number of environment interaction steps

  • total_comparisons (int) – number of preferences to gather in total

  • num_iterations (int) – number of times to train the agent against the reward model and then train the reward model against newly gathered preferences.

  • comparison_queue_size (Optional[int]) – the maximum number of comparisons to keep in the queue for training the reward model. If None, the queue will grow without bound as new comparisons are added.

  • fragment_length (int) – number of timesteps in each fragment used to elicit preferences

  • transition_oversampling (float) – factor by which to oversample transitions before creating fragments. Since fragments are sampled with replacement, this is usually chosen > 1 to avoid having the same transition in too many fragments.

  • initial_comparison_frac (float) – fraction of total_comparisons that will be sampled before the rest of training begins (using the randomly initialized agent). This can be used to pretrain the reward model before the agent is trained on the learned reward.

  • exploration_frac (float) – fraction of trajectory samples that will be created using partially random actions, rather than the current policy. Might be helpful if the learned policy explores too little and gets stuck with a wrong reward.

  • trajectory_path (Optional[str]) – either None, in which case an agent will be trained and used to sample trajectories on the fly, or a path to a pickled sequence of TrajectoryWithRew to be trained on.

  • trajectory_generator_kwargs (Mapping[str, Any]) – kwargs to pass to the trajectory generator.

  • save_preferences (bool) – if True, store the final dataset of preferences to disk.

  • agent_path (Optional[str]) – if given, initialize the agent using this stored policy rather than randomly.

  • preference_model_kwargs (Mapping[str, Any]) – passed to PreferenceModel

  • reward_trainer_kwargs (Mapping[str, Any]) – passed to BasicRewardTrainer or EnsembleRewardTrainer

  • gatherer_cls (Type[PreferenceGatherer]) – type of PreferenceGatherer to use (defaults to SyntheticGatherer)

  • gatherer_kwargs (Mapping[str, Any]) – passed to the PreferenceGatherer specified by gatherer_cls

  • active_selection (bool) – if True, use the active-selection fragmenter instead of the random fragmenter

  • active_selection_oversampling (int) – factor by which to oversample random fragments from the base fragmenter for active selection. This is usually chosen > 1 so that the active-selection algorithm can pick the fragment pairs with the highest uncertainty; a value of 1 amounts to no active selection.

  • uncertainty_on (str) – passed to ActiveSelectionFragmenter

  • fragmenter_kwargs (Mapping[str, Any]) – passed to RandomFragmenter

  • allow_variable_horizon (bool) – If False (default), algorithm will raise an exception if it detects trajectories of different length during training. If True, overrides this safety check. WARNING: variable horizon episodes leak information about the reward via termination condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html before overriding this.

  • checkpoint_interval (int) – Save the reward model and policy models (if trajectory_generator contains a policy) every checkpoint_interval iterations and after training is complete. If 0, then only save weights after training is complete. If <0, then don’t save weights at all.

  • query_schedule (Union[str, Callable[[float], float]]) – either one of (“constant”, “hyperbolic”, “inverse_quadratic”), or a function indicating how the total number of preference queries should be allocated to each iteration. “hyperbolic” and “inverse_quadratic” apportion fewer queries to later iterations, when the policy is assumed to be better and more stable. A custom callable is sketched after this parameter list.

  • _rnd (Generator) – Random number generator provided by Sacred.
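
A custom query_schedule callable can be passed in place of a named schedule. The following minimal sketch assumes (this is not guaranteed by the type signature above) that the callable receives the fraction of training completed as a float in [0, 1] and that its outputs are normalized across iterations to split total_comparisons; check the PreferenceComparisons implementation in the installed imitation version.

    def linear_decay_schedule(progress: float) -> float:
        # Hypothetical schedule: allocate fewer queries as training progresses.
        # `progress` is assumed to be the fraction of training completed (0 to 1);
        # the returned weights are assumed to be normalized over all iterations.
        return 1.0 - 0.5 * progress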

Return type

Mapping[str, Any]

Returns

Rollout statistics from the trained policy.

Raises

ValueError – Inconsistency between config and deserialized policy normalization.
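
This script is a thin Sacred wrapper around imitation.algorithms.preference_comparisons, and most of the parameters above map directly onto that API. The sketch below loosely follows the library's preference-comparisons tutorial and shows roughly what the script assembles; the environment ("Pendulum-v1"), the PPO agent, and all hyperparameter values are illustrative assumptions, and constructor signatures (especially rng handling) differ between imitation releases, so verify against the installed version.

    import numpy as np
    from stable_baselines3 import PPO

    from imitation.algorithms import preference_comparisons
    from imitation.rewards.reward_nets import BasicRewardNet
    from imitation.util.networks import RunningNorm
    from imitation.util.util import make_vec_env

    rng = np.random.default_rng(0)
    venv = make_vec_env("Pendulum-v1", rng=rng)

    # The reward model trained from preferences.
    reward_net = BasicRewardNet(
        venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
    )

    # fragmenter_kwargs, gatherer_cls/gatherer_kwargs, preference_model_kwargs and
    # reward_trainer_kwargs in the script configure components like these.
    fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)
    gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
    preference_model = preference_comparisons.PreferenceModel(reward_net)
    reward_trainer = preference_comparisons.BasicRewardTrainer(
        preference_model=preference_model,
        loss=preference_comparisons.CrossEntropyRewardLoss(),
        epochs=3,
        rng=rng,
    )

    # The agent trained against the learned reward; exploration_frac mirrors the
    # script parameter of the same name.
    agent = PPO("MlpPolicy", venv, seed=0)
    trajectory_generator = preference_comparisons.AgentTrainer(
        algorithm=agent,
        reward_fn=reward_net,
        venv=venv,
        exploration_frac=0.05,
        rng=rng,
    )

    pref_comparisons = preference_comparisons.PreferenceComparisons(
        trajectory_generator,
        reward_net,
        num_iterations=5,
        fragmenter=fragmenter,
        preference_gatherer=gatherer,
        reward_trainer=reward_trainer,
        comparison_queue_size=None,
        fragment_length=100,
        transition_oversampling=1,
        initial_comparison_frac=0.1,
        allow_variable_horizon=False,
        query_schedule="hyperbolic",
    )

    # total_timesteps and total_comparisons have the same meaning as in the script.
    pref_comparisons.train(total_timesteps=10_000, total_comparisons=200)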