"""This ingredient provides a reinforcement learning algorithm from stable-baselines3.
The algorithm instance is either freshly constructed or loaded from a file.
"""
import logging
import warnings
from typing import Any, Dict, Mapping, Optional, Type
import sacred
import stable_baselines3 as sb3
from stable_baselines3.common import (
    base_class,
    buffers,
    off_policy_algorithm,
    on_policy_algorithm,
    vec_env,
)
from imitation.policies import serialize
from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.rewards.reward_function import RewardFn
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients.policy import policy_ingredient
rl_ingredient = sacred.Ingredient(
    "rl",
    ingredients=[policy_ingredient, logging_ingredient.logging_ingredient],
)
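# Illustrative sketch of how an experiment might consume this ingredient
# (`my_experiment` and `main` are hypothetical names, not defined here):
#
#     ex = sacred.Experiment("my_experiment", ingredients=[rl_ingredient])
#
#     @ex.main
#     def main():
#         ...  # captured functions below are then called with config injected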
logger = logging.getLogger(__name__)
@rl_ingredient.config
def config():
    rl_cls = None
    batch_size = None
    rl_kwargs = dict()
    locals()  # quieten flake8
@rl_ingredient.config_hook
def config_hook(config, command_name, logger):
    """Sets defaults equivalent to sb3.PPO default hyperparameters.

    This hook is a no-op if command_name is "sqil" (used only in train_imitation),
    which has its own config hook.

    Args:
        config: Sacred config dict.
        command_name: Sacred command name.
        logger: Sacred logger.

    Returns:
        config: Updated Sacred config dict.
    """
    del logger
    res = {}
    if config["rl"]["rl_cls"] is None and command_name != "sqil":
        res["rl_cls"] = sb3.PPO
        res["batch_size"] = 2048  # rl_kwargs["n_steps"] = batch_size // venv.num_envs
        res["rl_kwargs"] = dict(
            learning_rate=3e-4,
            batch_size=64,
            n_epochs=10,
            ent_coef=0.0,
        )
    return res
@rl_ingredient.named_config
def fast():
    batch_size = 2
    rl_kwargs = dict(
        # SB3 RL seems to need a batch size of 2; otherwise it runs into numeric
        # issues when computing the multinomial distribution during predict().
        batch_size=2,
        # Setting n_epochs=1 speeds things up a lot.
        n_epochs=1,
    )
    locals()  # quieten flake8
@rl_ingredient.named_config
def sac():
    # For recommended SAC hyperparams in each environment, see:
    # https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/sac.yml
    rl_cls = sb3.SAC
    warnings.warn(
        "SAC currently only supports continuous action spaces. "
        "Consider adding a discrete version as mentioned here: "
        "https://github.com/DLR-RM/stable-baselines3/issues/505",
        category=RuntimeWarning,
    )
    # Default HPs are as follows:
    batch_size = 256  # batch size for RL algorithm
    rl_kwargs = dict(batch_size=None)  # make sure to set batch size to None
    locals()  # quieten flake8
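# Illustrative usage sketch: named configs on an ingredient are selected with
# Sacred's `with` syntax, scoped by the ingredient name. Assuming a training
# script such as imitation.scripts.train_rl, this might look like:
#
#     python -m imitation.scripts.train_rl with rl.sac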
def _maybe_add_relabel_buffer(
    rl_kwargs: Dict[str, Any],
    relabel_reward_fn: Optional[RewardFn] = None,
) -> Dict[str, Any]:
    """Use ReplayBufferRewardWrapper in rl_kwargs if relabel_reward_fn is not None."""
    rl_kwargs = dict(rl_kwargs)
    if relabel_reward_fn:
        _buffer_kwargs = dict(reward_fn=relabel_reward_fn)
        # Preserve any user-specified replay buffer class so the wrapper can
        # delegate to it; default to the standard SB3 ReplayBuffer.
        _buffer_kwargs["replay_buffer_class"] = rl_kwargs.get(
            "replay_buffer_class",
            buffers.ReplayBuffer,
        )
        rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper
        if "replay_buffer_kwargs" in rl_kwargs:
            _buffer_kwargs.update(rl_kwargs["replay_buffer_kwargs"])
        rl_kwargs["replay_buffer_kwargs"] = _buffer_kwargs
    return rl_kwargs
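# Illustrative sketch of the transformation performed above (`my_reward_fn` is a
# hypothetical RewardFn; the input kwargs are just an example):
#
#     out = _maybe_add_relabel_buffer({"batch_size": 256}, my_reward_fn)
#     # out == {
#     #     "batch_size": 256,
#     #     "replay_buffer_class": ReplayBufferRewardWrapper,
#     #     "replay_buffer_kwargs": {
#     #         "reward_fn": my_reward_fn,
#     #         "replay_buffer_class": buffers.ReplayBuffer,
#     #     },
#     # }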
@rl_ingredient.capture
def make_rl_algo(
    venv: vec_env.VecEnv,
    rl_cls: Type[base_class.BaseAlgorithm],
    batch_size: int,
    rl_kwargs: Mapping[str, Any],
    policy: Mapping[str, Any],
    _seed: int,
    relabel_reward_fn: Optional[RewardFn] = None,
) -> base_class.BaseAlgorithm:
    """Instantiates a Stable Baselines3 RL algorithm.

    Args:
        venv: The vectorized environment to train on.
        rl_cls: Type of a Stable Baselines3 RL algorithm.
        batch_size: The batch size of the RL algorithm.
        rl_kwargs: Keyword arguments for the RL algorithm constructor.
        policy: Configuration for the policy ingredient. We need the
            policy_cls and policy_kwargs components.
        relabel_reward_fn: Reward function used for reward relabeling
            in replay or rollout buffers of RL algorithms.

    Returns:
        The RL algorithm.

    Raises:
        ValueError: `batch_size` is not evenly divisible by `venv.num_envs`.
        TypeError: `rl_cls` is neither `OnPolicyAlgorithm` nor `OffPolicyAlgorithm`.
    """
    if batch_size % venv.num_envs != 0:
        raise ValueError(
            f"num_envs={venv.num_envs} must evenly divide batch_size={batch_size}.",
        )
    rl_kwargs = dict(rl_kwargs)
    # TODO: this is a hack and an indicator that the rl ingredient should be refactored
    if rl_cls == sb3.SAC:
        del rl_kwargs["n_epochs"]

    # If on-policy, collect `batch_size` many timesteps each update.
    # If off-policy, train on `batch_size` many timesteps each update.
    # These are different notions of batches, but this seems the closest
    # possible translation, and I would expect the appropriate hyperparameter
    # to be similar between them.
    if issubclass(rl_cls, on_policy_algorithm.OnPolicyAlgorithm):
        assert "n_steps" not in rl_kwargs, (
            "set 'n_steps' at top-level using 'batch_size'. "
            "n_steps = batch_size // num_vec"
        )
        rl_kwargs["n_steps"] = batch_size // venv.num_envs
    elif issubclass(rl_cls, off_policy_algorithm.OffPolicyAlgorithm):
        if rl_kwargs.get("batch_size") is not None:
            raise ValueError("set 'batch_size' at top-level")
        rl_kwargs["batch_size"] = batch_size
        rl_kwargs = _maybe_add_relabel_buffer(
            rl_kwargs=rl_kwargs,
            relabel_reward_fn=relabel_reward_fn,
        )
    else:
        raise TypeError(f"Unsupported RL algorithm '{rl_cls}'")
    rl_algo = rl_cls(
        policy=policy["policy_cls"],
        # Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass.
        # In particular, policy_kwargs["use_sde"] may be changed in rl_cls.__init__()
        # for certain algorithms, such as Soft Actor Critic. See:
        # https://github.com/DLR-RM/stable-baselines3/blob/30772aa9f53a4cf61571ee90046cdc454c1b11d7/sb3/common/off_policy_algorithm.py#L145
        policy_kwargs=dict(policy["policy_kwargs"]),
        env=venv,
        seed=_seed,
        **rl_kwargs,
    )
    logger.info(f"RL algorithm: {type(rl_algo)}")
    logger.info(f"Policy network summary:\n {rl_algo.policy}")
    return rl_algo
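# Illustrative sketch: inside a command of an experiment that includes
# `rl_ingredient`, Sacred injects rl_cls, batch_size, rl_kwargs, policy and
# _seed, so a caller typically only supplies the environment (`venv` below is
# a hypothetical pre-built VecEnv):
#
#     rl_algo = make_rl_algo(venv=venv)
#     rl_algo.learn(total_timesteps=10_000)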
@rl_ingredient.capture
def load_rl_algo_from_path(
    _seed: int,
    agent_path: str,
    venv: vec_env.VecEnv,
    rl_cls: Type[base_class.BaseAlgorithm],
    rl_kwargs: Mapping[str, Any],
    relabel_reward_fn: Optional[RewardFn] = None,
) -> base_class.BaseAlgorithm:
    """Loads a Stable Baselines3 RL algorithm from a saved model.

    Args:
        agent_path: Path to the serialized agent to warm-start from.
        venv: The vectorized environment to attach to the loaded algorithm.
        rl_cls: Type of the Stable Baselines3 RL algorithm to load.
        rl_kwargs: Keyword arguments passed through to the model loader.
        relabel_reward_fn: Reward function used for reward relabeling
            in replay or rollout buffers of RL algorithms.

    Returns:
        The loaded RL algorithm.
    """
    rl_kwargs = dict(rl_kwargs)
    # TODO: this is a hack and an indicator that the rl ingredient should be refactored
    if rl_cls == sb3.SAC:
        del rl_kwargs["n_epochs"]
    if issubclass(rl_cls, off_policy_algorithm.OffPolicyAlgorithm):
        rl_kwargs = _maybe_add_relabel_buffer(
            rl_kwargs=rl_kwargs,
            relabel_reward_fn=relabel_reward_fn,
        )
    agent = serialize.load_stable_baselines_model(
        cls=rl_cls,
        path=agent_path,
        venv=venv,
        seed=_seed,
        **rl_kwargs,
    )
    logger.info(f"Warm starting agent from '{agent_path}'")
    logger.info(f"RL algorithm: {type(agent)}")
    logger.info(f"Policy network summary:\n {agent.policy}")
    return agent