imitation.scripts.ingredients.rl
This ingredient provides a reinforcement learning algorithm from stable-baselines3.
The algorithm instance is either freshly constructed or loaded from a file.
Functions
- config_hook(config, command_name, logger) – Sets defaults equivalent to sb3.PPO default hyperparameters.
- load_rl_algo_from_path(_seed, agent_path, venv, rl_cls, rl_kwargs, relabel_reward_fn=None) – Loads a Stable Baselines3 RL algorithm from a file.
- make_rl_algo(venv, rl_cls, batch_size, rl_kwargs, policy, _seed, relabel_reward_fn=None) – Instantiates a Stable Baselines3 RL algorithm.
- imitation.scripts.ingredients.rl.config_hook(config, command_name, logger)[source]
Sets defaults equivalent to sb3.PPO default hyperparameters.
This hook is a no-op if command_name is “sqil” (used only in train_imitation), which has its own config hook.
- Parameters
config – Sacred config dict.
command_name – Sacred command name.
logger – Sacred logger.
- Returns
Updated Sacred config dict.
- Return type
config
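The behavior described above can be sketched in plain Python. This is a hypothetical illustration of the config-hook pattern, not the library's actual defaults: PPO-style values are filled in only for keys the user left unset, and the hook does nothing for the "sqil" command, which has its own hook.

```python
# Hypothetical illustration of a Sacred-style config hook; the default
# values below are placeholders, not imitation's real PPO defaults.
PPO_STYLE_DEFAULTS = {
    "batch_size": 2048,               # placeholder value, for illustration only
    "rl_kwargs": {"ent_coef": 0.0},   # placeholder value, for illustration only
}

def config_hook(config, command_name, logger):
    """Return a dict of updates to merge into the Sacred config."""
    updates = {}
    if command_name == "sqil":  # SQIL supplies its own config hook
        return updates
    for key, value in PPO_STYLE_DEFAULTS.items():
        if config.get(key) is None:  # only fill values the user did not set
            updates[key] = value
    return updates

print(config_hook({"batch_size": None, "rl_kwargs": None}, "train_rl", None))
```

Sacred merges the returned dict back into the run's configuration, so only the keys present in the return value are changed.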
- imitation.scripts.ingredients.rl.load_rl_algo_from_path(_seed, agent_path, venv, rl_cls, rl_kwargs, relabel_reward_fn=None)[source]
Loads a Stable Baselines3 RL algorithm from agent_path.
- Return type
BaseAlgorithm
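A hypothetical sketch of what loading an agent from disk involves. SB3 algorithms expose a `load` classmethod that restores a model from a saved path; `FakeAlgo` below stands in for a real class like `stable_baselines3.PPO`, and the wiring of `relabel_reward_fn` into the buffers is omitted.

```python
# FakeAlgo is a stub standing in for an SB3 BaseAlgorithm subclass.
class FakeAlgo:
    @classmethod
    def load(cls, path, env=None, seed=None):
        algo = cls()
        algo.path, algo.env, algo.seed = path, env, seed
        return algo

def load_rl_algo_from_path(_seed, agent_path, venv, rl_cls, rl_kwargs):
    # assumption: delegate to the class's load method, reattaching the
    # vectorized environment and reseeding the restored model
    return rl_cls.load(agent_path, env=venv, seed=_seed, **rl_kwargs)

algo = load_rl_algo_from_path(0, "agent.zip", "dummy_venv", FakeAlgo, {})
```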
- imitation.scripts.ingredients.rl.make_rl_algo(venv, rl_cls, batch_size, rl_kwargs, policy, _seed, relabel_reward_fn=None)[source]
Instantiates a Stable Baselines3 RL algorithm.
- Parameters
venv (VecEnv) – The vectorized environment to train on.
rl_cls (Type[BaseAlgorithm]) – Type of a Stable Baselines3 RL algorithm.
batch_size (int) – The batch size of the RL algorithm.
rl_kwargs (Mapping[str, Any]) – Keyword arguments for the RL algorithm constructor.
policy (Mapping[str, Any]) – Configuration for the policy ingredient; the policy_cls and policy_kwargs components are used.
relabel_reward_fn (Optional[RewardFn]) – Reward function used for reward relabeling in the replay or rollout buffers of RL algorithms.
- Return type
BaseAlgorithm
- Returns
The RL algorithm.
- Raises
ValueError – if gen_batch_size is not divisible by venv.num_envs.
TypeError – if rl_cls is neither OnPolicyAlgorithm nor OffPolicyAlgorithm.
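The Raises section above reflects a dispatch on the algorithm family, which can be sketched as follows. This is a hypothetical illustration: on-policy algorithms take steps per environment (hence the divisibility check), off-policy ones take batch_size directly, and the two stub classes stand in for SB3's OnPolicyAlgorithm / OffPolicyAlgorithm base classes.

```python
# Stubs standing in for SB3's on-policy / off-policy base classes.
class OnPolicyAlgorithm: ...
class OffPolicyAlgorithm: ...

def rl_algo_kwargs(rl_cls, batch_size, num_envs, rl_kwargs):
    """Translate a total batch size into constructor kwargs (sketch)."""
    kwargs = dict(rl_kwargs)
    if issubclass(rl_cls, OnPolicyAlgorithm):
        if batch_size % num_envs != 0:
            raise ValueError("batch size not divisible by venv.num_envs")
        kwargs["n_steps"] = batch_size // num_envs  # steps per env per rollout
    elif issubclass(rl_cls, OffPolicyAlgorithm):
        kwargs["batch_size"] = batch_size
    else:
        raise TypeError(f"{rl_cls} is neither on-policy nor off-policy")
    return kwargs

print(rl_algo_kwargs(OnPolicyAlgorithm, 4096, 8, {"ent_coef": 0.0}))
```

The divisibility check exists because an on-policy rollout collects n_steps transitions from each of the num_envs environments, so the total batch must split evenly across environments.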