imitation.scripts.ingredients.rl

This ingredient provides a reinforcement learning algorithm from stable-baselines3.

The algorithm instance is either freshly constructed or loaded from a file.
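
For orientation, here is a minimal sketch of pulling this ingredient into a Sacred experiment. It assumes the module exposes an ingredient object named rl_ingredient, following the naming convention of imitation's other script ingredients; check the module source for the exact name.

    import sacred

    from imitation.scripts.ingredients.rl import rl_ingredient

    # The ingredient's config (rl_cls, batch_size, rl_kwargs, ...) becomes
    # available under the "rl" namespace of the experiment config.
    ex = sacred.Experiment("my_train_script", ingredients=[rl_ingredient])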

Functions

config_hook(config, command_name, logger)

Sets defaults equivalent to sb3.PPO default hyperparameters.

load_rl_algo_from_path(_seed, agent_path, ...)

Loads an RL algorithm from a file. Return type: BaseAlgorithm.

make_rl_algo(venv, rl_cls, batch_size, ...)

Instantiates a Stable Baselines3 RL algorithm.

imitation.scripts.ingredients.rl.config_hook(config, command_name, logger)[source]

Sets defaults equivalent to sb3.PPO default hyperparameters.

This hook is a no-op if command_name is “sqil” (used only in train_imitation), which has its own config hook.

Parameters
  • config – Sacred config dict.

  • command_name – Sacred command name.

  • logger – Sacred logger.

Returns

Updated Sacred config dict.

Return type

dict
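
The mechanics are easiest to see in a toy experiment. The sketch below is a simplified stand-in, not the library's actual hook: the hook name, config fields, and the single learning_rate default are illustrative assumptions.

    import sacred

    ex = sacred.Experiment("example")

    @ex.config
    def base_config():
        rl_kwargs = {}  # overridable from the command line

    @ex.config_hook
    def hook(config, command_name, logger):
        # Mirror the "sqil" special case: leave such commands untouched.
        if command_name == "sqil":
            return {}
        # Stand-in for "PPO defaults"; the real hook fills in sb3.PPO's
        # full default hyperparameters, keeping any user overrides.
        return {"rl_kwargs": {"learning_rate": 3e-4, **config["rl_kwargs"]}}

    @ex.main
    def main(rl_kwargs):
        print(rl_kwargs)  # -> {'learning_rate': 0.0003} by default

    if __name__ == "__main__":
        ex.run_commandline()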

imitation.scripts.ingredients.rl.load_rl_algo_from_path(_seed, agent_path, venv, rl_cls, rl_kwargs, relabel_reward_fn=None)[source]

Loads an RL algorithm from agent_path.

Return type

BaseAlgorithm
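
A hedged usage sketch follows. In normal use Sacred injects these arguments from the ingredient's config; calling the captured function directly requires supplying them all. The checkpoint path and environment are placeholders.

    import gymnasium as gym
    import stable_baselines3 as sb3
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.scripts.ingredients.rl import load_rl_algo_from_path

    venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
    algo = load_rl_algo_from_path(
        _seed=0,
        agent_path="output/checkpoints/final",  # placeholder path
        venv=venv,
        rl_cls=sb3.SAC,
        rl_kwargs={},
        relabel_reward_fn=None,  # optionally a learned reward for relabeling
    )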

imitation.scripts.ingredients.rl.make_rl_algo(venv, rl_cls, batch_size, rl_kwargs, policy, _seed, relabel_reward_fn=None)[source]

Instantiates a Stable Baselines3 RL algorithm.

Parameters
  • venv (VecEnv) – The vectorized environment to train on.

  • rl_cls (Type[BaseAlgorithm]) – Type of a Stable Baselines3 RL algorithm.

  • batch_size (int) – The batch size of the RL algorithm.

  • rl_kwargs (Mapping[str, Any]) – Keyword arguments for RL algorithm constructor.

  • policy (Mapping[str, Any]) – Configuration for the policy ingredient. The policy_cls and policy_kwargs components are required.

  • relabel_reward_fn (Optional[RewardFn]) – Reward function used for reward relabeling in replay or rollout buffers of RL algorithms.

Return type

BaseAlgorithm

Returns

The RL algorithm.

Raises
  • ValueError – gen_batch_size not divisible by venv.num_envs.

  • TypeError – rl_cls is neither OnPolicyAlgorithm nor OffPolicyAlgorithm.
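
To make the calling convention concrete, here is a hedged sketch of invoking make_rl_algo directly rather than through Sacred's injection. The environment, batch size, and policy configuration are illustrative; note that batch_size is divisible by venv.num_envs, as the ValueError above requires.

    import gymnasium as gym
    import stable_baselines3 as sb3
    from stable_baselines3.common.policies import ActorCriticPolicy
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.scripts.ingredients.rl import make_rl_algo

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    algo = make_rl_algo(
        venv=venv,
        rl_cls=sb3.PPO,  # an OnPolicyAlgorithm subclass
        batch_size=4096,  # divisible by venv.num_envs (4)
        rl_kwargs={"ent_coef": 0.01},
        policy={"policy_cls": ActorCriticPolicy, "policy_kwargs": {}},
        _seed=0,
    )
    algo.learn(total_timesteps=8192)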