imitation.policies.serialize
Load serialized policies of different types.
Module Attributes
PolicyLoaderFn | A policy loader function that takes a VecEnv before any other custom arguments and returns a stable_baselines3 base policy.
policy_registry | Registry of policy loading functions.
Functions
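load_policy(policy_type, venv, **kwargs) | Load serialized policy.
load_stable_baselines_model(cls, path, venv, **kwargs) | Helper method to load RL models from Stable Baselines.
save_stable_model(output_dir, model, filename='model.zip') | Serialize Stable Baselines model.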
Classes
SavePolicyCallback(policy_dir, *args, **kwargs) | Saves the policy using save_stable_model each time it is called.
- imitation.policies.serialize.PolicyLoaderFn
A policy loader function that takes a VecEnv before any other custom arguments and returns a stable_baselines3 base policy.
alias of Callable[[...], BasePolicy]
- class imitation.policies.serialize.SavePolicyCallback(policy_dir, *args, **kwargs)
Bases: EventCallback
Saves the policy using save_stable_model each time it is called.
Should be used in conjunction with callbacks.EveryNTimesteps or another event-based trigger.
- __init__(policy_dir, *args, **kwargs)
Builds SavePolicyCallback.
- Parameters
policy_dir (Path) – Directory to save checkpoints.
*args – Passed through to callbacks.EventCallback.
**kwargs – Passed through to callbacks.EventCallback.
- model: base_class.BaseAlgorithm
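Each time its trigger fires, the callback saves the current model. The sketch below imitates that pattern in plain Python so it runs without stable_baselines3. `FakeModel`, `CallbackSketch`, `on_event`, and the zero-padded per-timestep subdirectory naming are hypothetical choices made for illustration; they are not the library's API or its exact checkpoint layout.

```python
import pathlib
import tempfile


class FakeModel:
    """Hypothetical stand-in for a stable_baselines3 BaseAlgorithm."""

    def save(self, path):
        # A real algorithm writes a zip archive; this just writes placeholder bytes.
        pathlib.Path(path).write_bytes(b"weights")


class CallbackSketch:
    """Hypothetical sketch of a callback that saves the model whenever it fires."""

    def __init__(self, policy_dir, model):
        self.policy_dir = pathlib.Path(policy_dir)
        self.model = model
        self.num_timesteps = 0

    def on_event(self):
        # Assumption: one subdirectory per trigger, named after the current timestep.
        output_dir = self.policy_dir / f"{self.num_timesteps:012d}"
        output_dir.mkdir(parents=True, exist_ok=True)
        self.model.save(output_dir / "model.zip")
        return True  # True tells the training loop to keep going


# Simulate an EveryNTimesteps-style trigger firing at steps 1000 and 2000.
with tempfile.TemporaryDirectory() as tmp:
    cb = CallbackSketch(tmp, FakeModel())
    for step in (1000, 2000):
        cb.num_timesteps = step
        cb.on_event()
    saved = sorted(p.parent.name for p in pathlib.Path(tmp).glob("*/model.zip"))
    print(saved)  # ['000000001000', '000000002000']
```

In real training code the callback would be wrapped in an event trigger such as callbacks.EveryNTimesteps, as the class description recommends, rather than called by hand.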
- imitation.policies.serialize.load_policy(policy_type, venv, **kwargs)
Load serialized policy.
Note on the kwargs:
zero and random policies take no kwargs.
ppo and sac policies take a path argument: a path to a zip file or to a folder containing a model.zip file.
ppo-huggingface and sac-huggingface policies take an env_name argument and an optional organization argument.
- Parameters
policy_type (str) – A key in policy_registry, e.g. ppo.
venv (VecEnv) – An environment that the policy is to be used with.
**kwargs – Additional arguments to pass to the policy loader.
- Return type
BasePolicy
- Returns
The deserialized policy.
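The kwargs note above amounts to a lookup-then-call dispatch: find the loader registered under policy_type, then call it with venv first and the remaining kwargs. A minimal sketch under that assumption, using a plain dict and hypothetical names (`StubPolicy`, `load_zero`, `load_ppo`, `REGISTRY`, `load_policy_sketch`) in place of imitation's registry and real stable_baselines3 policies:

```python
from typing import Any, Callable, Dict


class StubPolicy:
    """Hypothetical stand-in for a stable_baselines3 BasePolicy."""

    def __init__(self, kind, **info):
        self.kind = kind
        self.info = info


# Hypothetical loaders: each takes the environment first, then its own kwargs.
def load_zero(venv):
    return StubPolicy("zero")


def load_ppo(venv, path):
    return StubPolicy("ppo", path=path)


REGISTRY: Dict[str, Callable[..., StubPolicy]] = {
    "zero": load_zero,
    "ppo": load_ppo,
}


def load_policy_sketch(policy_type: str, venv: Any, **kwargs) -> StubPolicy:
    loader = REGISTRY[policy_type]  # unknown types raise KeyError in this sketch
    # venv is always passed first; everything else goes to the specific loader.
    return loader(venv, **kwargs)


policy = load_policy_sketch("ppo", venv=None, path="output/model.zip")
print(policy.kind, policy.info)  # ppo {'path': 'output/model.zip'}

# zero takes no kwargs, so passing path is rejected by the loader's signature.
try:
    load_policy_sketch("zero", venv=None, path="unused")
    zero_rejects_path = False
except TypeError:
    zero_rejects_path = True
```

Because each loader declares exactly the kwargs it accepts, a mismatched argument fails loudly at call time instead of being silently ignored.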
- imitation.policies.serialize.load_stable_baselines_model(cls, path, venv, **kwargs)
Helper method to load RL models from Stable Baselines.
- Parameters
cls (Type[TypeVar(Algorithm, bound=BaseAlgorithm)]) – Stable Baselines RL algorithm.
path (str) – Path to a zip file containing saved model data or to a folder containing a model.zip file.
venv (VecEnv) – Environment to train on.
kwargs – Passed through to cls.load.
- Raises
FileNotFoundError – If path is a directory that does not contain a model.zip file.
FileExistsError – If path contains a vec_normalize.pkl file (unsupported).
- Return type
TypeVar(Algorithm, bound=BaseAlgorithm)
- Returns
The deserialized RL algorithm.
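The Raises entries describe how path is interpreted before anything is loaded. The helper below is a hypothetical sketch of just that resolution step (`resolve_model_path` is not part of this module), and it assumes the vec_normalize.pkl check looks in the directory holding the model file. Handing the resolved path to cls.load together with venv and kwargs requires stable_baselines3 and is left out.

```python
import pathlib
import tempfile


def resolve_model_path(path):
    """Hypothetical helper mirroring the documented path rules and Raises."""
    path_obj = pathlib.Path(path)
    if path_obj.is_dir():
        # A directory is accepted only if it contains model.zip.
        path_obj = path_obj / "model.zip"
        if not path_obj.exists():
            raise FileNotFoundError(f"Expected {path_obj} to exist.")
    # Saved normalization statistics next to the model are unsupported.
    if (path_obj.parent / "vec_normalize.pkl").exists():
        raise FileExistsError("vec_normalize.pkl is not supported.")
    return path_obj


missing_raised = normalize_raised = False
with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    try:
        resolve_model_path(root)  # empty directory: no model.zip yet
    except FileNotFoundError:
        missing_raised = True
    (root / "model.zip").write_bytes(b"")
    resolved_name = resolve_model_path(root).name  # directory resolves to model.zip
    (root / "vec_normalize.pkl").write_bytes(b"")
    try:
        resolve_model_path(root)
    except FileExistsError:
        normalize_raised = True
```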
- imitation.policies.serialize.policy_registry: Registry[Callable[[...], BasePolicy]] = <imitation.util.registry.Registry object>
Registry of policy loading functions. Add your own here if desired.
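Adding your own loader means registering a callable under a new key. This page does not document the Registry methods, so the sketch uses a hypothetical `MiniRegistry` class with `register` and `get` methods, plus a hypothetical `load_constant` loader that follows the PolicyLoaderFn convention of taking the environment first:

```python
from typing import Callable, Dict, Generic, List, TypeVar

T = TypeVar("T")


class MiniRegistry(Generic[T]):
    """Hypothetical minimal registry mapping string keys to values."""

    def __init__(self):
        self._entries: Dict[str, T] = {}

    def register(self, key: str, value: T) -> None:
        # Refuse silent overwrites so two plugins cannot clash unnoticed.
        if key in self._entries:
            raise KeyError(f"{key!r} is already registered")
        self._entries[key] = value

    def get(self, key: str) -> T:
        return self._entries[key]

    def keys(self) -> List[str]:
        return sorted(self._entries)


registry: MiniRegistry[Callable[..., str]] = MiniRegistry()


def load_constant(venv, action="noop"):
    # Hypothetical custom loader: venv first, then its own keyword arguments.
    return f"constant policy ({action})"


registry.register("constant", load_constant)
print(registry.get("constant")(None, action="left"))  # constant policy (left)
```

Rejecting duplicate keys is a design choice of this sketch; it keeps a custom loader from silently replacing a built-in one.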
- imitation.policies.serialize.save_stable_model(output_dir, model, filename='model.zip')
Serialize Stable Baselines model.
Load later with load_policy(…, path=output_dir).
- Parameters
output_dir (Path) – Path to the save directory.
model (BaseAlgorithm) – The stable baselines model.
filename (str) – The filename of the model.
- Return type
None
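Saving amounts to writing the model to output_dir/filename. A sketch assuming the model exposes a save(path) method as stable_baselines3 algorithms do; `FakeModel` and `save_model_sketch` are hypothetical names, and creating missing parent directories is an assumption of the sketch rather than documented behavior:

```python
import os
import tempfile


class FakeModel:
    """Hypothetical stand-in exposing a save(path) method like BaseAlgorithm."""

    def save(self, path):
        with open(path, "wb") as f:
            f.write(b"weights")


def save_model_sketch(output_dir, model, filename="model.zip"):
    # Create the directory (and any parents) if needed, then write the model file.
    os.makedirs(output_dir, exist_ok=True)
    model.save(os.path.join(output_dir, filename))


with tempfile.TemporaryDirectory() as tmp:
    out = os.path.join(tmp, "run1", "policies")
    save_model_sketch(out, FakeModel())
    contents = os.listdir(out)
    print(contents)  # ['model.zip']
```

Keeping the default filename as model.zip means the saved directory matches what the directory form of path expects when loading.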