# Source code for imitation.scripts.ingredients.bc

"""This ingredient provides a BC algorithm instance.

The instance's policy is either loaded from disk or constructed from scratch.
"""
import warnings
from typing import Optional, Sequence

import sacred
import torch as th
from stable_baselines3.common import vec_env

from imitation.algorithms import bc
from imitation.data import types
from imitation.scripts.ingredients import policy

# Sacred ingredient named "bc". It declares the policy ingredient as a
# sub-ingredient because make_or_load_policy delegates to policy.make_policy
# when no serialized policy path is configured.
bc_ingredient = sacred.Ingredient("bc", ingredients=[policy.policy_ingredient])


@bc_ingredient.config
def config():
    """Default configuration of the ``bc`` ingredient.

    Sacred collects the local variables of this config scope as the
    ingredient's configuration entries. ``batch_size``, ``l2_weight``,
    ``optimizer_cls`` and ``optimizer_kwargs`` are captured by ``make_bc``;
    ``agent_path`` is captured by ``make_or_load_policy``.
    """
    batch_size = 32  # Batch size forwarded to bc.BC
    l2_weight = 3e-5  # L2 regularization weight
    optimizer_cls = th.optim.Adam  # Optimizer class forwarded to bc.BC
    # Keyword arguments forwarded to bc.BC for constructing the optimizer
    optimizer_kwargs = dict(
        lr=4e-4,
    )
    # NOTE(review): train_kwargs is not read by the functions in this module;
    # presumably consumed by the training script -- confirm against callers.
    train_kwargs = dict(
        n_epochs=None,  # Number of BC epochs per DAgger training round
        n_batches=None,  # Number of BC batches per DAgger training round
        log_interval=500,  # Number of updates between Tensorboard/stdout logs
    )
    agent_path = None  # Path to serialized policy. If None, a new policy is created.

    locals()  # quieten flake8 unused variable warning


@bc_ingredient.capture
def make_bc(
    venv: vec_env.VecEnv,
    expert_trajs: Sequence[types.Trajectory],
    custom_logger,
    batch_size: int,
    l2_weight: float,
    optimizer_cls,
    optimizer_kwargs,
    _rnd,
) -> bc.BC:
    """Construct a BC trainer for ``venv`` from expert demonstrations.

    Args:
        venv: Vectorized environment whose observation and action spaces
            define the BC problem.
        expert_trajs: Expert trajectories used as BC demonstrations.
        custom_logger: Logger forwarded unchanged to ``bc.BC``.
        batch_size: Training batch size (captured from the ingredient config).
        l2_weight: L2 regularization weight (captured from the config).
        optimizer_cls: Optimizer class (captured from the config).
        optimizer_kwargs: Optimizer keyword arguments (captured from the config).
        _rnd: Random state forwarded as ``rng``; presumably injected by
            Sacred's capture mechanism -- confirm it matches the type
            ``bc.BC`` expects.

    Returns:
        A ``bc.BC`` instance whose policy comes from ``make_or_load_policy``.
    """
    obs_space = venv.observation_space
    act_space = venv.action_space
    learner_policy = make_or_load_policy(venv)

    # Hyperparameters taken straight from the ingredient configuration.
    hyperparams = dict(
        batch_size=batch_size,
        l2_weight=l2_weight,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
    )
    return bc.BC(
        observation_space=obs_space,
        action_space=act_space,
        policy=learner_policy,
        demonstrations=expert_trajs,
        custom_logger=custom_logger,
        rng=_rnd,
        **hyperparams,
    )
@bc_ingredient.capture
def make_or_load_policy(venv: vec_env.VecEnv, agent_path: Optional[str]):
    """Makes a policy or loads a policy from a path if provided.

    Args:
        venv: Vectorized environment we will be imitating demos from.
        agent_path: Path to serialized policy. If provided, then load the policy
            from this path. Otherwise, make a new policy. Specify only if
            policy.policy_cls and policy.policy_kwargs are not specified, since
            they are ignored when a path is given.

    Returns:
        A Stable Baselines3 policy: the result of ``policy.make_policy(venv)``
        when ``agent_path`` is None, otherwise
        ``bc.reconstruct_policy(agent_path)``.
    """
    if agent_path is None:
        # The freshly built policy must be returned; discarding it would hand
        # callers such as make_bc a None policy.
        return policy.make_policy(venv)

    # A serialized policy takes precedence over the policy ingredient's config,
    # so warn that those settings have no effect here.
    warnings.warn(
        "When agent_path is specified, policy.policy_cls and policy.policy_kwargs "
        "are ignored.",
        RuntimeWarning,
    )
    return bc.reconstruct_policy(agent_path)