# Source code for imitation.scripts.ingredients.sqil
"""This ingredient provides a SQIL algorithm instance."""
import sacred
from stable_baselines3 import dqn as dqn_algorithm
from imitation.policies import base
from imitation.scripts.ingredients import policy, rl
# Sacred ingredient registered under the name "sqil"; it declares the RL and
# policy ingredients as its sub-ingredients.
sqil_ingredient = sacred.Ingredient(
    "sqil",
    ingredients=[
        rl.rl_ingredient,
        policy.policy_ingredient,
    ],
)
@sqil_ingredient.config
def config():
    """Default configuration values for the ``sqil`` ingredient."""
    total_timesteps = 3e5
    train_kwargs = {
        "log_interval": 4,  # Number of updates between Tensorboard/stdout logs
        "progress_bar": True,
    }
    locals()  # quieten flake8 unused variable warning
@rl.rl_ingredient.config_hook
def override_rl_cls(config, command_name, logger):
    """Select DQN as the RL class for the ``sqil`` command when none is set.

    Args:
        config: Configuration mapping; ``config["rl"]["rl_cls"]`` is read when
            the command is ``"sqil"``.
        command_name: Name of the command; compared against ``"sqil"``.
        logger: Unused; accepted as part of the hook signature.

    Returns:
        ``{"rl_cls": dqn_algorithm.DQN}`` if the command is ``"sqil"`` and the
        configured ``rl_cls`` is ``None``; otherwise an empty dict.
    """
    # NOTE(review): the comment previously here spoke of removing arguments
    # added by the rl ingredient; this hook only fills in ``rl_cls`` -- confirm
    # nothing else was intended.
    del logger  # not needed by this hook
    res = {}
    # The command name is checked first so ``config["rl"]`` is only read for
    # the ``sqil`` command.
    if command_name == "sqil" and config["rl"]["rl_cls"] is None:
        res["rl_cls"] = dqn_algorithm.DQN
    return res
@policy.policy_ingredient.config_hook
def override_policy_cls(config, command_name, logger):  # noqa
    """Replace ``FeedForward32Policy`` with ``"MlpPolicy"`` for ``sqil``.

    Args:
        config: Configuration mapping; ``config["policy"]["policy_cls"]`` is
            read when the command is ``"sqil"``.
        command_name: Name of the command; compared against ``"sqil"``.
        logger: Unused; accepted as part of the hook signature.

    Returns:
        ``{"policy_cls": "MlpPolicy"}`` if the command is ``"sqil"`` and the
        configured policy class equals ``base.FeedForward32Policy``; otherwise
        an empty dict.
    """
    del logger  # not needed by this hook
    res = {}
    # The command name is checked first so ``config["policy"]`` is only read
    # for the ``sqil`` command.
    if (
        command_name == "sqil"
        and config["policy"]["policy_cls"] == base.FeedForward32Policy
    ):
        res["policy_cls"] = "MlpPolicy"
    return res