Source code for imitation.scripts.ingredients.expert

"""This ingredient provides an expert policy.

The expert policy is either loaded from disk or from the HuggingFace Model Hub or is
a test policy (e.g., random or zero).
The supported policy types are:

- :code:`ppo` and :code:`sac`: A policy trained with SB3.
    Needs a `path` in the `loader_kwargs`.
- :code:`<algo>-huggingface` (algo can be `ppo` or `sac`):
    A policy trained with SB3 and uploaded to the HuggingFace Model Hub.
    Will load the model from the repo :code:`<organization>/<algo>-<env_name>`.
    You can set the organization with the `organization` key in :code:`loader_kwargs`.
    The default is `HumanCompatibleAI`.
- :code:`random`: A policy that takes random actions.
- :code:`zero`: A policy that takes zero actions.
"""
import sacred

from imitation.policies import serialize
from imitation.scripts.ingredients import environment

expert_ingredient = sacred.Ingredient(
    "expert",
    ingredients=[environment.environment_ingredient],
)


@expert_ingredient.config
def config():
    # [ppo, sac, random, zero, ppo-huggingface, sac-huggingface] or your own.
    policy_type = "ppo-huggingface"
    # See imitation.policies.serialize.load_policy for options.
    loader_kwargs = dict()
    locals()  # quieten flake8


[docs]@expert_ingredient.config_hook def config_hook(config, command_name, logger): e_config = config["expert"] if "huggingface" in e_config["policy_type"]: # Set the default loader_kwargs for huggingface policies. if "organization" not in e_config["loader_kwargs"]: e_config["loader_kwargs"]["organization"] = "HumanCompatibleAI" # Note: unfortunately we need to pass the venv **and** its name to the # huggingface policy loader since there is no way to get the name from the venv. # The name is needed to deduce the repo id and load the correct huggingface # model. e_config["loader_kwargs"]["env_name"] = config["environment"]["gym_id"] # Note: this only serves the purpose to indicated that you need to specify the # path for local policies. It makes the config more explicit. if ( e_config["policy_type"] in ("ppo", "sac") and "path" not in e_config["loader_kwargs"] ): # pragma: no cover e_config["loader_kwargs"]["path"] = None return e_config
[docs]@expert_ingredient.capture def get_expert_policy(venv, policy_type, loader_kwargs): return serialize.load_policy(policy_type, venv, **loader_kwargs)