# Source code for imitation.scripts.ingredients.policy
"""This ingredient provides a newly constructed stable-baselines3 policy."""
import logging
from typing import Any, Mapping, Type
import sacred
from stable_baselines3.common import policies, utils, vec_env
import imitation.util.networks
from imitation.policies import base
from imitation.scripts.ingredients import logging as logging_ingredient
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Sacred ingredient registered under the name "policy"; it lists the logging
# ingredient as its only sub-ingredient.
policy_ingredient = sacred.Ingredient(
    "policy", ingredients=[logging_ingredient.logging_ingredient]
)
# Default configuration of the policy ingredient: `base.FeedForward32Policy`
# as the policy class, with no extra constructor keyword arguments.
# NOTE(review): sacred presumably turns the local names assigned in the body
# into config entries, so they must not be renamed -- confirm against the
# sacred config-scope documentation.
@policy_ingredient.config
def config():
    # Training
    policy_cls = base.FeedForward32Policy
    policy_kwargs = {}
    locals()  # quieten flake8
# Named config "sac": sets `policy_cls` to `base.SAC1024Policy`.
# The assignment looks unused to flake8 (hence the F841 suppression).
@policy_ingredient.named_config
def sac():
    policy_cls = base.SAC1024Policy  # noqa: F841
# Policy constructor kwargs that select `base.NormalizeFeaturesExtractor` as
# the features extractor and hand it `RunningNorm` as its normalize class.
NORMALIZE_RUNNING_POLICY_KWARGS = dict(
    features_extractor_class=base.NormalizeFeaturesExtractor,
    features_extractor_kwargs=dict(
        normalize_class=imitation.util.networks.RunningNorm,
    ),
)
# Named config "normalize_running": sets `policy_kwargs` to the module-level
# NORMALIZE_RUNNING_POLICY_KWARGS mapping.
# The assignment looks unused to flake8 (hence the F841 suppression).
@policy_ingredient.named_config
def normalize_running():
    policy_kwargs = NORMALIZE_RUNNING_POLICY_KWARGS  # noqa: F841
# Default config for CNN Policies
# Named config "cnn_policy": sets `policy_cls` to Stable Baselines3's
# `policies.ActorCriticCnnPolicy`.
@policy_ingredient.named_config
def cnn_policy():
    policy_cls = policies.ActorCriticCnnPolicy  # noqa: F841
@policy_ingredient.capture
def make_policy(
    venv: vec_env.VecEnv,
    policy_cls: Type[policies.BasePolicy],
    policy_kwargs: Mapping[str, Any],
) -> policies.BasePolicy:
    """Makes policy.

    Args:
        venv: Vectorized environment we will be imitating demos from. Its
            observation and action spaces are handed to the policy constructor
            when `policy_cls` is a subclass of `ActorCriticPolicy`.
        policy_cls: Type of a Stable Baselines3 policy architecture.
        policy_kwargs: Keyword arguments for policy constructor. The mapping
            is copied before use, so the object passed in is never modified.
            For `ActorCriticPolicy` subclasses, any `observation_space`,
            `action_space` or `lr_schedule` entries are overwritten.

    Returns:
        A Stable Baselines3 policy.
    """
    # Copy first: the kwargs may be extended below and the input mapping
    # (possibly shared config) must stay untouched.
    policy_kwargs = dict(policy_kwargs)
    if issubclass(policy_cls, policies.ActorCriticPolicy):
        policy_kwargs.update(
            {
                "observation_space": venv.observation_space,
                "action_space": venv.action_space,
                # parameter mandatory for ActorCriticPolicy, but not used by BC
                "lr_schedule": utils.get_schedule_fn(1),
            },
        )
    policy: policies.BasePolicy = policy_cls(**policy_kwargs)
    logger.info(f"Policy network summary:\n {policy}")
    return policy