# Source code for imitation.policies.serialize

"""Load serialized policies of different types."""

# FIXME(sam): it seems like this module could mostly be replaced with a few
# torch.load() and torch.save() calls

import logging
import pathlib
from typing import Callable, Type, TypeVar

import huggingface_sb3 as hfsb3
from stable_baselines3.common import base_class, callbacks, policies, vec_env

from imitation.policies import base
from imitation.util import registry, util

Algorithm = TypeVar("Algorithm", bound=base_class.BaseAlgorithm)

# Note: a VecEnv will always be passed first and then any kwargs. `Callable`
# alone cannot express "one positional VecEnv followed by arbitrary keyword
# arguments", so the parameter list is left as `...` below.
# TODO(juan): this can be fixed using `typing.ParamSpec` / `typing.Concatenate`.
PolicyLoaderFn = Callable[..., policies.BasePolicy]
"""A policy loader function that takes a VecEnv before any other custom arguments and
returns a stable_baselines3 base policy."""

policy_registry: registry.Registry[PolicyLoaderFn] = registry.Registry()
"""Registry of policy loading functions. Add your own here if desired."""

def load_stable_baselines_model(
    cls: Type[Algorithm],
    path: str,
    venv: vec_env.VecEnv,
    **kwargs,
) -> Algorithm:
    """Helper method to load RL models from Stable Baselines.

    Args:
        cls: Stable Baselines RL algorithm.
        path: Path to zip file containing saved model data or to a folder
            containing a `model.zip` file.
        venv: Environment to train on.
        kwargs: Passed through to `cls.load`.

    Raises:
        FileNotFoundError: If `path` is a directory that does not contain a
            `model.zip` file.
        FileExistsError: If the directory holding the model file contains a
            `vec_normalize.pkl` file (unsupported).

    Returns:
        The deserialized RL algorithm.
    """
    logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'")
    path_obj = util.parse_path(path)

    if path_obj.is_dir():
        # A directory is expected to hold the checkpoint as `model.zip`
        # (the name `save_stable_model` writes by default).
        path_obj = path_obj / "model.zip"
        if not path_obj.exists():
            raise FileNotFoundError(
                f"Expected '{path}' to be a directory containing a 'model.zip' file.",
            )
    # A non-directory path is handed to `cls.load` unchanged, so the
    # library's own path handling applies (no existence check here).

    # SOMEDAY(adam): added 2022-01, can probably remove this check in 2023
    vec_normalize_path = path_obj.parent / "vec_normalize.pkl"
    if vec_normalize_path.exists():
        raise FileExistsError(
            "Outdated policy format: we do not support restoring normalization "
            f"statistics from '{vec_normalize_path}'",
        )

    return cls.load(path_obj, env=venv, **kwargs)
def _load_stable_baselines_from_file(
    cls: Type[base_class.BaseAlgorithm],
) -> PolicyLoaderFn:
    """Creates a policy loading function to read a policy from a file.

    Args:
        cls: The RL algorithm, e.g. `stable_baselines3.PPO`.

    Returns:
        A function loading policies trained via cls.
    """

    def f(venv: vec_env.VecEnv, path: str) -> policies.BasePolicy:
        """Loads the policy saved at `path`, for use with environment `venv`.

        Args:
            venv: Environment passed through to the model loader.
            path: Path to a zip file or to a folder containing `model.zip`.

        Returns:
            The `policy` attribute of the loaded model.
        """
        model = load_stable_baselines_model(cls, path, venv)
        return getattr(model, "policy")

    return f


def _load_stable_baselines_from_huggingface(
    algo_name: str,
    cls: Type[base_class.BaseAlgorithm],
) -> PolicyLoaderFn:
    """Creates a policy loading function to load from Hugging Face.

    Args:
        algo_name: The name of the algorithm, e.g. `ppo`.
        cls: The RL algorithm, e.g. `stable_baselines3.PPO`.

    Returns:
        A function loading policies trained via cls.
    """

    def f(
        venv: vec_env.VecEnv,
        env_name: str,
        organization: str = "HumanCompatibleAI",
    ) -> policies.BasePolicy:
        """Loads the `env_name` policy published by `organization` on Hugging Face.

        Args:
            venv: Environment passed through to the model loader.
            env_name: Environment name; combined with `algo_name` (via
                huggingface_sb3) to build the model and repository names.
            organization: Hub organization owning the model repository.

        Returns:
            The `policy` attribute of the loaded model.
        """
        model_name = hfsb3.ModelName(algo_name, hfsb3.EnvironmentName(env_name))
        repo_id = hfsb3.ModelRepoId(organization, model_name)
        # `load_from_hub` presumably downloads the checkpoint and returns its
        # local path -- confirm against huggingface_sb3 docs.
        filename = hfsb3.load_from_hub(repo_id, model_name.filename)
        model = load_stable_baselines_model(cls, filename, venv)
        return getattr(model, "policy")

    return f


policy_registry.register(
    "random",
    value=registry.build_loader_fn_require_space(base.RandomPolicy),
)
policy_registry.register(
    "zero",
    value=registry.build_loader_fn_require_space(base.ZeroPolicy),
)


def _add_stable_baselines_policies_from_file(classes) -> None:
    """Registers a file-based loader for each algorithm in `classes`.

    Args:
        classes: Mapping from registry key (e.g. `ppo`) to an importable
            `module:attr` class path (e.g. `stable_baselines3:PPO`).
    """
    for k, cls_name in classes.items():
        cls = registry.load_attr(cls_name)
        fn = _load_stable_baselines_from_file(cls)
        policy_registry.register(k, value=fn)


def _add_stable_baselines_policies_from_huggingface(classes) -> None:
    """Registers a `<key>-huggingface` loader for each algorithm in `classes`.

    Args:
        classes: Mapping from registry key (e.g. `ppo`, also used as the
            Hugging Face algorithm name) to an importable `module:attr` class path.
    """
    for k, cls_name in classes.items():
        cls = registry.load_attr(cls_name)
        fn = _load_stable_baselines_from_huggingface(k, cls)
        policy_registry.register(f"{k}-huggingface", value=fn)


STABLE_BASELINES_CLASSES = {
    "ppo": "stable_baselines3:PPO",
    "sac": "stable_baselines3:SAC",
}


_add_stable_baselines_policies_from_file(STABLE_BASELINES_CLASSES)
_add_stable_baselines_policies_from_huggingface(STABLE_BASELINES_CLASSES)
def load_policy(
    policy_type: str,
    venv: vec_env.VecEnv,
    **kwargs,
) -> policies.BasePolicy:
    """Load serialized policy.

    Note on the kwargs:

    - `zero` and `random` policy take no kwargs
    - `ppo` and `sac` policies take a `path` argument with a path to a zip file or
      to a folder containing a `model.zip` file.
    - `ppo-huggingface` and `sac-huggingface` policies take an `env_name` and
      optional `organization` argument.

    Args:
        policy_type: A key in `policy_registry`, e.g. `ppo`.
        venv: An environment that the policy is to be used with.
        **kwargs: Additional arguments to pass to the policy loader.

    Returns:
        The deserialized policy.
    """
    agent_loader = policy_registry.get(policy_type)
    # Every registered loader takes the VecEnv positionally first (see
    # `PolicyLoaderFn`); loader-specific options follow as keywords.
    return agent_loader(venv, **kwargs)
def save_stable_model(
    output_dir: pathlib.Path,
    model: base_class.BaseAlgorithm,
    filename: str = "model.zip",
) -> None:
    """Serialize Stable Baselines model.

    Load later with `load_policy(..., path=output_dir)`, which looks for a
    `model.zip` file when given a directory.

    Args:
        output_dir: Path to the save directory; created (with parents) if missing.
        model: The stable baselines model.
        filename: The filename of the model. Keep the default `model.zip` if
            the checkpoint should be loadable by passing `output_dir` as `path`.
    """
    # Save each model in new directory in case we want to add metadata or other
    # information in future. (E.g. we used to save `VecNormalize` statistics here,
    # although that is no longer necessary.)
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save(output_dir / filename)
    logging.info(f"Saved policy to {output_dir}")
class SavePolicyCallback(callbacks.EventCallback):
    """Checkpoints the training model through `save_stable_model` on every trigger.

    Each invocation writes into a new subdirectory of `policy_dir`, named after
    the current `num_timesteps` zero-padded to 12 digits. Combine with
    `callbacks.EveryNTimesteps` or another event-based trigger to control how
    often this fires.
    """

    def __init__(
        self,
        policy_dir: pathlib.Path,
        *args,
        **kwargs,
    ):
        """Builds SavePolicyCallback.

        Args:
            policy_dir: Root directory under which checkpoint subdirectories go.
            *args: Forwarded unchanged to `callbacks.EventCallback`.
            **kwargs: Forwarded unchanged to `callbacks.EventCallback`.
        """
        super().__init__(*args, **kwargs)
        self.policy_dir = policy_dir

    def _on_step(self) -> bool:
        assert self.model is not None
        checkpoint_name = f"{self.num_timesteps:012d}"
        save_stable_model(self.policy_dir / checkpoint_name, self.model)
        # Always True: this callback never asks SB3 to stop training early
        # (SB3 treats a False return from `_on_step` as a stop request).
        return True