imitation.policies.base

Custom policy classes and convenience methods.

Classes

FeedForward32Policy(*args, **kwargs)

A feed forward policy network with two hidden layers of 32 units.

NonTrainablePolicy(observation_space, ...)

Abstract class for non-trainable (e.g. hard-coded or interactive) policies.

NormalizeFeaturesExtractor(observation_space)

Feature extractor that flattens then normalizes input.

RandomPolicy(observation_space, action_space)

Returns random actions.

SAC1024Policy(*args, **kwargs)

Actor and critic (value) networks, each with two hidden layers of 1024 units.

ZeroPolicy(observation_space, action_space)

Returns a constant zero action.

class imitation.policies.base.FeedForward32Policy(*args, **kwargs)

Bases: ActorCriticPolicy

A feed forward policy network with two hidden layers of 32 units.

This matches the IRL policies in the original AIRL paper.

Note: This differs from stable_baselines3 ActorCriticPolicy in two ways: it uses 32 rather than 64 units per hidden layer, and its policy and value networks share weights everywhere except the final layer, where each has its own linear head.
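To make the weight-sharing arrangement concrete, here is a minimal pure-Python sketch of a network with two shared 32-unit hidden layers feeding separate policy and value heads. It is not the imitation or stable_baselines3 implementation: the class SharedTrunkPolicySketch, the helpers init_linear and linear, and the uniform initialization are hypothetical, the tanh activation is assumed because it is the stable_baselines3 default for ActorCriticPolicy, and a discrete action space is assumed so that the policy head outputs one logit per action.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]
Layer = Tuple[Matrix, Vector]


def init_linear(rng: random.Random, n_in: int, n_out: int) -> Layer:
    """Return (weights, bias): n_out rows of n_in small random weights, zero bias."""
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def linear(x: Vector, layer: Layer) -> Vector:
    """Compute W @ x + b for one fully connected layer."""
    weights, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


class SharedTrunkPolicySketch:
    """Hypothetical: two shared tanh layers, then separate policy and value heads."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 32, seed: int = 0):
        rng = random.Random(seed)
        # Shared trunk: both the policy and the value function use these layers.
        self.trunk = [init_linear(rng, obs_dim, hidden), init_linear(rng, hidden, hidden)]
        # Only the final linear layers differ between policy and value.
        self.policy_head = init_linear(rng, hidden, n_actions)
        self.value_head = init_linear(rng, hidden, 1)

    def forward(self, obs: Vector) -> Tuple[Vector, float]:
        h = obs
        for layer in self.trunk:
            h = [math.tanh(v) for v in linear(h, layer)]
        logits = linear(h, self.policy_head)  # one logit per discrete action
        value = linear(h, self.value_head)[0]  # scalar state-value estimate
        return logits, value
```

Calling SharedTrunkPolicySketch(obs_dim=4, n_actions=2).forward([0.1, -0.2, 0.3, 0.0]) returns two logits and one value, both computed from the same 32-unit hidden representation.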

__init__(*args, **kwargs)

Builds FeedForward32Policy; arguments passed to ActorCriticPolicy.

features_extractor: BaseFeaturesExtractor

class imitation.policies.base.NonTrainablePolicy(observation_space, action_space)

Bases: BasePolicy, ABC

Abstract class for non-trainable (e.g. hard-coded or interactive) policies.

__init__(observation_space, action_space)

Builds NonTrainablePolicy with the specified observation and action spaces.
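The non-trainable pattern can be sketched without any deep-learning framework: a subclass only decides how to act on a single observation, and the base class applies that decision across a batch. The classes NonTrainablePolicySketch and ThresholdPolicy and the method names choose_action and predict are hypothetical and do not reproduce the imitation API.

```python
import abc
from typing import Any, List, Sequence


class NonTrainablePolicySketch(abc.ABC):
    """Hypothetical stand-in for a policy with no learnable parameters."""

    def __init__(self, observation_space: Any, action_space: Any):
        self.observation_space = observation_space
        self.action_space = action_space

    @abc.abstractmethod
    def choose_action(self, obs: Any) -> Any:
        """Return an action for a single observation."""

    def predict(self, batch_obs: Sequence[Any]) -> List[Any]:
        # No gradients or optimizer involved: just apply the rule per observation.
        return [self.choose_action(obs) for obs in batch_obs]


class ThresholdPolicy(NonTrainablePolicySketch):
    """Toy hard-coded controller: action 1 if the first feature is positive."""

    def choose_action(self, obs: Sequence[float]) -> int:
        return 1 if obs[0] > 0 else 0
```

For example, ThresholdPolicy(None, None).predict([[0.5, 1.0], [-0.3, 2.0]]) returns [1, 0], while instantiating NonTrainablePolicySketch directly raises TypeError because choose_action is abstract.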

features_extractor: BaseFeaturesExtractor
forward(*args)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, callers should invoke the Module instance itself (e.g. module(x)) rather than calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

class imitation.policies.base.NormalizeFeaturesExtractor(observation_space, normalize_class=<class 'imitation.util.networks.RunningNorm'>)

Bases: FlattenExtractor

Feature extractor that flattens then normalizes input.

__init__(observation_space, normalize_class=<class 'imitation.util.networks.RunningNorm'>)

Builds NormalizeFeaturesExtractor.

Parameters
  • observation_space (Space) – The space observations lie in.

  • normalize_class (Type[Module]) – The class to use to normalize observations (after they are flattened). This can be any Module that preserves the shape, e.g. nn.BatchNorm* or nn.LayerNorm.
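A minimal sketch of the flatten-then-normalize idea, assuming the normalizer keeps a running per-feature mean and variance and is constructed with the flattened feature count as its only argument (as nn.BatchNorm1d and nn.LayerNorm can be). RunningNormSketch, NormalizeFeaturesSketch and flatten are hypothetical plain-Python stand-ins; the exact update rule used by imitation.util.networks.RunningNorm may differ.

```python
import math
from typing import List, Sequence


def flatten(obs) -> List[float]:
    """Recursively flatten nested lists/tuples into a flat list of floats."""
    if isinstance(obs, (list, tuple)):
        out: List[float] = []
        for item in obs:
            out.extend(flatten(item))
        return out
    return [float(obs)]


class RunningNormSketch:
    """Hypothetical per-feature normalizer using running mean and variance."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        self.count = 0
        self.mean = [0.0] * num_features
        self.var = [1.0] * num_features
        self.eps = eps

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        n = len(batch)
        total = self.count + n
        for j in range(len(self.mean)):
            col = [row[j] for row in batch]
            b_mean = sum(col) / n
            b_var = sum((v - b_mean) ** 2 for v in col) / n
            delta = b_mean - self.mean[j]
            # Merge old and batch statistics (parallel-variance formula).
            m2 = self.var[j] * self.count + b_var * n + delta ** 2 * self.count * n / total
            self.mean[j] += delta * n / total
            self.var[j] = m2 / total
        self.count = total

    def __call__(self, batch, training: bool = True):
        if training:
            self.update(batch)
        return [
            [(v - m) / math.sqrt(s + self.eps) for v, m, s in zip(row, self.mean, self.var)]
            for row in batch
        ]


class NormalizeFeaturesSketch:
    """Flatten each observation, then normalize the flattened features."""

    def __init__(self, features_dim: int, normalize_class=RunningNormSketch):
        self.normalize = normalize_class(features_dim)

    def forward(self, batch_obs):
        return self.normalize([flatten(o) for o in batch_obs])
```

Feeding the 2x2 observations [[1, 2], [3, 4]] and [[3, 4], [5, 6]] through NormalizeFeaturesSketch(4) yields rows close to [-1, -1, -1, -1] and [1, 1, 1, 1], since each flattened feature then has mean 0 and unit variance over the batch.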

forward(observations)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, callers should invoke the Module instance itself (e.g. module(x)) rather than calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

Return type

Tensor

training: bool

class imitation.policies.base.RandomPolicy(observation_space, action_space)

Bases: NonTrainablePolicy

Returns random actions.
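A sketch of a random policy over a bounded continuous action space, assuming actions are drawn independently and uniformly in each dimension. Gym spaces provide a sample() method for this purpose; BoxSketch and RandomPolicySketch below are hypothetical plain-Python stand-ins, and seeding a private random.Random is a choice made here for reproducibility, not a documented feature of RandomPolicy.

```python
import random
from dataclasses import dataclass
from typing import Any, List


@dataclass
class BoxSketch:
    """Hypothetical bounded continuous action space."""

    low: List[float]
    high: List[float]

    def sample(self, rng: random.Random) -> List[float]:
        # Draw each dimension independently and uniformly within its bounds.
        return [rng.uniform(lo, hi) for lo, hi in zip(self.low, self.high)]


class RandomPolicySketch:
    def __init__(self, action_space: BoxSketch, seed: int = 0):
        self.action_space = action_space
        self.rng = random.Random(seed)

    def choose_action(self, obs: Any) -> List[float]:
        # The observation is ignored entirely.
        return self.action_space.sample(self.rng)
```

Two RandomPolicySketch instances built with the same seed produce the same action sequence, which makes random baselines repeatable across runs.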

features_extractor: BaseFeaturesExtractor
optimizer: th.optim.Optimizer
training: bool

class imitation.policies.base.SAC1024Policy(*args, **kwargs)

Bases: SACPolicy

Actor and critic (value) networks, each with two hidden layers of 1024 units.

This matches the implementation of SAC policies in the PEBBLE paper. See https://arxiv.org/pdf/2106.05091.pdf and https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml.

Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units in each layer instead of the default value of 256.
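The effect of widening the hidden layers can be checked with a little arithmetic. The sketch below counts the weights and biases of a fully connected network. The dimensions (a 17-dimensional observation and a 6-dimensional action) are hypothetical, and it assumes the actor outputs a mean and a log standard deviation per action dimension while each Q-network maps a concatenated observation and action to a single value.

```python
from typing import Sequence


def mlp_param_count(in_dim: int, hidden: Sequence[int], out_dim: int) -> int:
    """Count weights plus biases of a fully connected network."""
    dims = [in_dim, *hidden, out_dim]
    # Each layer a -> b has an a x b weight matrix and b biases.
    return sum(a * b + b for a, b in zip(dims[:-1], dims[1:]))


obs_dim, act_dim = 17, 6  # hypothetical environment sizes

for width in (256, 1024):
    actor = mlp_param_count(obs_dim, [width, width], 2 * act_dim)
    critic = mlp_param_count(obs_dim + act_dim, [width, width], 1)  # one Q-network
    print(width, actor, critic)
# 256 73484 72193
# 1024 1080332 1075201
```

The hidden-to-hidden layer dominates the count, so moving from 256 to 1024 units per layer makes each network roughly 15 times larger in this example.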

__init__(*args, **kwargs)

Builds SAC1024Policy; arguments passed to SACPolicy.

actor: Actor
critic: ContinuousCritic
critic_target: ContinuousCritic

class imitation.policies.base.ZeroPolicy(observation_space, action_space)

Bases: NonTrainablePolicy

Returns a constant zero action.
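A constant zero policy needs only the shape of the action space. The sketch below is a hypothetical plain-Python stand-in (ZeroPolicySketch and zeros are not part of imitation) that ignores its observation and returns nested lists of 0.0 matching an assumed action shape.

```python
from typing import Any, Tuple


def zeros(shape: Tuple[int, ...]) -> Any:
    """Nested lists of 0.0 with the given shape; a scalar 0.0 for shape ()."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


class ZeroPolicySketch:
    def __init__(self, action_shape: Tuple[int, ...]):
        self.action_shape = action_shape

    def choose_action(self, obs: Any) -> Any:
        # Ignore the observation; always return an all-zero action.
        return zeros(self.action_shape)
```

Such a policy is a convenient do-nothing baseline; note that if an action space's bounds exclude 0, the zero action lies outside that space.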

__init__(observation_space, action_space)

Builds ZeroPolicy with the specified observation and action spaces.

features_extractor: BaseFeaturesExtractor
optimizer: th.optim.Optimizer
training: bool