imitation.rewards.serialize

Load serialized reward functions of different types.

Functions

load_reward(reward_type, reward_path, venv, ...)

Load serialized reward.

load_zero(path, venv)

Return type: RewardFn

Classes

ValidateRewardFn(reward_fn)

Wrap reward function to add sanity check.

class imitation.rewards.serialize.ValidateRewardFn(reward_fn)

Bases: RewardFn

Wrap reward function to add sanity check.

Checks that the length of the reward vector is equal to the batch size of the input.

__init__(reward_fn)

Builds the reward validator.

Parameters

reward_fn (RewardFn) – The base reward function to wrap.
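The batch-size check described for ValidateRewardFn can be sketched as a small callable wrapper. This is a minimal illustration, not the library class: ValidatingRewardSketch and RewardFnLike are hypothetical names, and the call signature (obs, act, next_obs, dones) returning one reward per transition is an assumption about what RewardFn expects.

```python
from typing import Callable

import numpy as np

# Assumed call signature: (obs, act, next_obs, dones) -> one reward per transition.
RewardFnLike = Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]


class ValidatingRewardSketch:
    """Hypothetical stand-in for ValidateRewardFn; not the library class."""

    def __init__(self, reward_fn: RewardFnLike) -> None:
        # Keep a reference to the base reward function being wrapped.
        self.reward_fn = reward_fn

    def __call__(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        dones: np.ndarray,
    ) -> np.ndarray:
        rew = self.reward_fn(obs, act, next_obs, dones)
        # Sanity check: exactly one scalar reward per transition in the batch.
        expected = (len(obs),)
        if np.shape(rew) != expected:
            raise ValueError(f"reward shape {np.shape(rew)} != expected {expected}")
        return rew
```

Raising ValueError rather than using a bare assert keeps the check active when Python runs with -O. This page does not say which mechanism the library itself uses.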

imitation.rewards.serialize.load_reward(reward_type, reward_path, venv, **kwargs)

Load serialized reward.

Parameters
  • reward_type (str) – A key in reward_registry. Valid types include zero, RewardNet_normalized, RewardNet_shaped, RewardNet_unshaped, RewardNet_unnormalized, RewardNet_std_added.

  • reward_path (str) – A path specifying the serialized reward; how it is interpreted depends on reward_type.

  • venv (VecEnv) – An environment that the policy is to be used with.

  • **kwargs – Additional keyword arguments passed through to the reward loader selected by reward_type.

Return type

RewardFn

Returns

The deserialized reward.
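Because reward_type is a key in reward_registry, load_reward amounts to a lookup followed by a call. The sketch below imitates that dispatch with a plain dict. REGISTRY, register, load_constant and load_reward_sketch are hypothetical names, and the convention that each loader is called as loader(reward_path, venv, **kwargs) is an assumption. The "constant" loader exists only to show how extra keyword arguments reach the selected loader.

```python
from typing import Any, Callable, Dict

import numpy as np

RewardFnLike = Callable[..., np.ndarray]
# Assumed loader convention: loader(reward_path, venv, **kwargs) -> reward function.
Loader = Callable[..., RewardFnLike]

# Hypothetical registry keyed by reward_type strings.
REGISTRY: Dict[str, Loader] = {}


def register(key: str) -> Callable[[Loader], Loader]:
    """Decorator that stores a loader under `key`."""
    def decorator(loader: Loader) -> Loader:
        REGISTRY[key] = loader
        return loader
    return decorator


@register("constant")
def load_constant(path: str, venv: Any, value: float = 0.0) -> RewardFnLike:
    """Hypothetical loader: ignores path and venv, uses the `value` keyword."""
    del path, venv

    def reward_fn(obs, act, next_obs, dones):
        # One reward per transition, all equal to `value`.
        return np.full(len(obs), value, dtype=float)

    return reward_fn


def load_reward_sketch(
    reward_type: str, reward_path: str, venv: Any, **kwargs: Any
) -> RewardFnLike:
    """Look up the loader for `reward_type` and forward path, venv and kwargs."""
    if reward_type not in REGISTRY:
        raise KeyError(f"unknown reward_type {reward_type!r}; known: {sorted(REGISTRY)}")
    return REGISTRY[reward_type](reward_path, venv, **kwargs)


# Usage: the `value` keyword flows through load_reward_sketch to load_constant.
reward_fn = load_reward_sketch("constant", "unused/path", venv=None, value=2.5)
rewards = reward_fn(
    np.zeros((3, 2)), np.zeros((3, 1)), np.zeros((3, 2)), np.zeros(3, dtype=bool)
)
```

A registry lets new reward types be added without editing the loading function itself. The string key is the only thing a caller needs to choose.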

imitation.rewards.serialize.load_zero(path, venv)

Return type

RewardFn
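load_zero is documented here only by its return type. Its name, together with the zero key listed under reward_type, suggests that it yields a reward of zero for every transition. The sketch below assumes exactly that behavior. load_zero_sketch is a hypothetical name, and the reward-function signature is the same assumed (obs, act, next_obs, dones) form used elsewhere.

```python
from typing import Any, Callable

import numpy as np


def load_zero_sketch(path: str, venv: Any) -> Callable[..., np.ndarray]:
    """Hypothetical stand-in for load_zero; not the library code."""
    del path, venv  # assumed unused: a constant-zero reward needs no file or env

    def reward_fn(
        obs: np.ndarray, act: np.ndarray, next_obs: np.ndarray, dones: np.ndarray
    ) -> np.ndarray:
        # One zero per transition, so a batch-size check on the output passes.
        return np.zeros(len(obs))

    return reward_fn
```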