imitation.policies.serialize

Load serialized policies of different types.

Module Attributes

PolicyLoaderFn

A policy loader function that takes a VecEnv before any other custom arguments and returns a stable_baselines3 base policy.

policy_registry

Registry of policy loading functions.

Functions

load_policy(policy_type, venv, **kwargs)

Load serialized policy.

load_stable_baselines_model(cls, path, venv, ...)

Helper method to load RL models from Stable Baselines.

save_stable_model(output_dir, model[, filename])

Serialize Stable Baselines model.

Classes

SavePolicyCallback(policy_dir, *args, **kwargs)

Saves the policy using save_stable_model each time it is called.

imitation.policies.serialize.PolicyLoaderFn

A policy loader function that takes a VecEnv before any other custom arguments and returns a stable_baselines3 base policy.

alias of Callable[[…], BasePolicy]

class imitation.policies.serialize.SavePolicyCallback(policy_dir, *args, **kwargs)

Bases: EventCallback

Saves the policy using save_stable_model each time it is called.

Should be used in conjunction with callbacks.EveryNTimesteps or another event-based trigger.

__init__(policy_dir, *args, **kwargs)

Builds SavePolicyCallback.

Parameters
  • policy_dir (Path) – Directory to save checkpoints.

  • *args – Passed through to callbacks.EventCallback.

  • **kwargs – Passed through to callbacks.EventCallback.
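The pairing with an event-based trigger might look like the following minimal sketch. It assumes that stable_baselines3 and imitation are installed; make_checkpoint_callback and its arguments are hypothetical names chosen for illustration and are not part of the library.

```python
from pathlib import Path


def make_checkpoint_callback(policy_dir: Path, n_steps: int):
    """Wrap SavePolicyCallback in an EveryNTimesteps trigger (hypothetical helper)."""
    # Imported lazily so this sketch can be defined without the RL stack installed.
    from stable_baselines3.common.callbacks import EveryNTimesteps

    from imitation.policies.serialize import SavePolicyCallback

    # SavePolicyCallback saves the policy each time it is triggered;
    # EveryNTimesteps triggers it about once every `n_steps` timesteps.
    save_callback = SavePolicyCallback(policy_dir)
    return EveryNTimesteps(n_steps=n_steps, callback=save_callback)
```

The returned callback would typically be passed as the callback argument of a Stable Baselines3 algorithm's learn method.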

model: base_class.BaseAlgorithm

imitation.policies.serialize.load_policy(policy_type, venv, **kwargs)

Load serialized policy.

Note on the kwargs:

  • zero and random policies take no kwargs.

  • ppo and sac policies take a path argument with a path to a zip file or to a folder containing a model.zip file.

  • ppo-huggingface and sac-huggingface policies take an env_name and optional organization argument.

Parameters
  • policy_type (str) – A key in policy_registry, e.g. ppo.

  • venv (VecEnv) – An environment that the policy is to be used with.

  • **kwargs – Additional arguments to pass to the policy loader.

Return type

BasePolicy

Returns

The deserialized policy.
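The note on kwargs can be summarized with a small helper. This is a sketch only: loader_kwargs is a hypothetical function, not part of imitation, and it assumes that path and env_name are required for their respective policy types.

```python
from typing import Any, Dict, Optional


def loader_kwargs(
    policy_type: str,
    *,
    path: Optional[str] = None,
    env_name: Optional[str] = None,
    organization: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the kwargs for load_policy per the note above (hypothetical helper)."""
    if policy_type in ("zero", "random"):
        # These policies take no kwargs.
        return {}
    if policy_type in ("ppo", "sac"):
        # Assumption: the path argument is mandatory for these types.
        if path is None:
            raise ValueError(f"{policy_type!r} needs a path argument")
        return {"path": path}
    if policy_type in ("ppo-huggingface", "sac-huggingface"):
        # Assumption: env_name is mandatory; organization is optional.
        if env_name is None:
            raise ValueError(f"{policy_type!r} needs an env_name argument")
        kwargs: Dict[str, Any] = {"env_name": env_name}
        if organization is not None:
            kwargs["organization"] = organization
        return kwargs
    raise KeyError(f"policy type not covered by this sketch: {policy_type!r}")
```

With a VecEnv named venv in hand, the result could be forwarded as load_policy(policy_type, venv, **loader_kwargs(policy_type, path=...)).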

imitation.policies.serialize.load_stable_baselines_model(cls, path, venv, **kwargs)

Helper method to load RL models from Stable Baselines.

Parameters
  • cls (Type[TypeVar(Algorithm, bound= BaseAlgorithm)]) – Stable Baselines RL algorithm.

  • path (str) – Path to zip file containing saved model data or to a folder containing a model.zip file.

  • venv (VecEnv) – Environment to train on.

  • **kwargs – Passed through to cls.load.

Raises
  • FileNotFoundError – If path is a directory that does not contain a model.zip file.

  • FileExistsError – If path contains a vec_normalize.pkl file (unsupported).

Return type

TypeVar(Algorithm, bound= BaseAlgorithm)

Returns

The deserialized RL algorithm.
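The path handling and error conditions described above can be illustrated with a standard-library sketch. resolve_model_path is a hypothetical name, and the sketch is not the library's implementation; in particular, it assumes that vec_normalize.pkl is looked for in the directory that holds the model file.

```python
from pathlib import Path


def resolve_model_path(path: str) -> Path:
    """Resolve `path` to a model zip, mirroring the documented errors (sketch)."""
    model_path = Path(path)
    if model_path.is_dir():
        # A folder is expected to contain a model.zip file.
        model_path = model_path / "model.zip"
        if not model_path.exists():
            raise FileNotFoundError(
                f"Expected {path!r} to be a directory containing model.zip"
            )
    # Assumption: normalization statistics would sit next to the model file.
    if (model_path.parent / "vec_normalize.pkl").exists():
        raise FileExistsError("vec_normalize.pkl found; this format is unsupported")
    return model_path
```

The resolved path is what would then be handed to cls.load together with venv and any extra kwargs.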

imitation.policies.serialize.policy_registry: Registry[Callable[[...], BasePolicy]] = <imitation.util.registry.Registry object>

Registry of policy loading functions. Add your own here if desired.
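To show the shape of a loader and how a registry dispatches to it, here is a toy sketch. toy_registry, register, load, and constant_policy_loader are all hypothetical stand-ins: the real Registry class may expose a different interface, and a real loader must return a BasePolicy rather than a plain function.

```python
from typing import Any, Callable, Dict

# A loader takes the VecEnv first, then any custom keyword arguments.
LoaderFn = Callable[..., Any]

# Hypothetical stand-in for policy_registry: a plain dict from key to loader.
toy_registry: Dict[str, LoaderFn] = {}


def register(key: str, loader: LoaderFn) -> None:
    """Add a loader under `key` (toy equivalent of registering a policy type)."""
    toy_registry[key] = loader


def load(policy_type: str, venv: Any, **kwargs: Any) -> Any:
    """Look up the loader for `policy_type` and call it with venv and kwargs."""
    return toy_registry[policy_type](venv, **kwargs)


def constant_policy_loader(venv: Any, *, action: int = 0) -> Callable[[Any], int]:
    """Hypothetical loader: returns a 'policy' that always picks `action`."""
    return lambda observation: action


register("constant", constant_policy_loader)
# No real environment is needed for this toy loader, so venv is None here.
policy = load("constant", venv=None, action=3)
```

The dispatch in load follows the same pattern as load_policy(policy_type, venv, **kwargs): the key selects the loader, and everything after venv is forwarded to it.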

imitation.policies.serialize.save_stable_model(output_dir, model, filename='model.zip')

Serialize Stable Baselines model.

Load later with load_policy(…, path=output_dir).

Parameters
  • output_dir (Path) – Path to the save directory.

  • model (BaseAlgorithm) – The stable baselines model.

  • filename (str) – The filename of the model.

Return type

None
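The on-disk layout implied by these parameters can be sketched without Stable Baselines. This is an illustration under assumptions, not the library's code: StubModel and save_model_sketch are hypothetical, and the sketch assumes the model is written to output_dir / filename after the directory is created. Unlike the documented function, which returns None, the sketch returns the target path so it can be inspected.

```python
import tempfile
from pathlib import Path


class StubModel:
    """Hypothetical stand-in for a BaseAlgorithm; only `save` is mimicked."""

    def save(self, path: Path) -> None:
        Path(path).write_bytes(b"stub model data")


def save_model_sketch(
    output_dir: Path, model: StubModel, filename: str = "model.zip"
) -> Path:
    """Write `model` to output_dir / filename, creating the directory if needed."""
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / filename
    model.save(target)
    return target


with tempfile.TemporaryDirectory() as tmp:
    saved = save_model_sketch(Path(tmp) / "checkpoint", StubModel())
    saved_name = saved.name
    saved_exists = saved.exists()
```

With the default filename, the save directory is a folder containing a model.zip file, which is one of the two path forms that the ppo and sac loaders accept.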