imitation.util.networks#

Helper methods to build and run neural networks.

Functions

build_cnn(in_channels, hid_channels[, ...])

Constructs a Torch CNN.

build_mlp(in_size, hid_sizes[, out_size, ...])

Constructs a Torch MLP.

training_mode(m[, mode])

Temporarily switch module m to specified training mode.

Classes

BaseNorm(num_features[, eps])

Base class for layers that try to normalize the input to mean 0 and variance 1.

EMANorm(num_features[, decay, eps])

Similar to RunningNorm but uses an exponential weighting.

RunningNorm(num_features[, eps])

Normalizes input to mean 0 and standard deviation 1 using a running average.

SqueezeLayer(*args, **kwargs)

Torch module that squeezes a B*1 tensor down into a size-B vector.

class imitation.util.networks.BaseNorm(num_features, eps=1e-05)[source]#

Bases: Module, ABC

Base class for layers that try to normalize the input to mean 0 and variance 1.

Similar to BatchNorm, LayerNorm, etc. but whereas they only use statistics from the current batch at train time, we use statistics from all batches.

__init__(num_features, eps=1e-05)[source]#

Builds BaseNorm.

Parameters
  • num_features (int) – Number of features; the length of the non-batch dimension.

  • eps (float) – Small constant for numerical stability. Inputs are rescaled by 1 / sqrt(estimated_variance + eps).

count: Tensor#
forward(x)[source]#

Updates statistics if in training mode. Returns normalized x.

Return type

Tensor

reset_running_stats()[source]#

Resets running stats to defaults, yielding the identity transformation.

Return type

None

running_mean: Tensor#
running_var: Tensor#
abstract update_stats(batch)[source]#

Update self.running_mean, self.running_var and self.count.

Return type

None
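
Subclasses only need to implement update_stats. Below is a minimal sketch of a hypothetical subclass, assuming the buffers documented above (running_mean, running_var, count) are initialized by BaseNorm.__init__; the equal-weight merge formula is illustrative, not the library's own update rule:

    import torch as th
    from imitation.util import networks

    class SimpleRunningNorm(networks.BaseNorm):
        """Hypothetical subclass: equal-weight running moments over all batches."""

        def update_stats(self, batch: th.Tensor) -> None:
            batch_mean = batch.mean(dim=0)
            batch_var = batch.var(dim=0, unbiased=False)
            batch_count = batch.shape[0]

            delta = batch_mean - self.running_mean
            total = self.count + batch_count
            # Merge old (mean, var, count) with the batch statistics.
            new_var = (
                self.count * self.running_var
                + batch_count * batch_var
                + delta**2 * self.count * batch_count / total
            ) / total
            self.running_mean = self.running_mean + delta * batch_count / total
            self.running_var = new_var
            self.count += batch_count

    norm = SimpleRunningNorm(num_features=3)
    norm.train()
    _ = norm(th.randn(16, 3))        # forward() updates stats, then normalizes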

class imitation.util.networks.EMANorm(num_features, decay=0.99, eps=1e-05)[source]#

Bases: BaseNorm

Similar to RunningNorm but uses an exponential weighting.

__init__(num_features, decay=0.99, eps=1e-05)[source]#

Builds EMANorm.

Parameters
  • num_features (int) – Number of features; the length of the non-batch dim.

  • decay (float) – how quickly the weight on past samples decays over time.

  • eps (float) – small constant for numerical stability.

Raises

ValueError – if decay is out of range.

inv_learning_rate: Tensor#
num_batches: IntTensor#
reset_running_stats()[source]#

Reset the running stats of the normalization layer.

update_stats(batch)[source]#

Update self.running_mean and self.running_var in batch mode.

Reference: Algorithm 3 from https://github.com/HumanCompatibleAI/imitation/files/9456540/Incremental_batch_EMA_and_EMV.pdf

Parameters

batch (Tensor) – A batch of data to use to update the running mean and variance.

Return type

None
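
A short usage sketch with synthetic data; per BaseNorm.forward above, statistics only update while the module is in training mode:

    import torch as th
    from imitation.util import networks

    norm = networks.EMANorm(num_features=10, decay=0.99)

    norm.train()                     # forward() updates the EMA statistics
    for _ in range(5):
        batch = th.randn(32, 10) * 3.0 + 1.0
        _ = norm(batch)

    norm.eval()                      # statistics frozen; only normalization applied
    normalized = norm(th.randn(32, 10))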

class imitation.util.networks.RunningNorm(num_features, eps=1e-05)[source]#

Bases: BaseNorm

Normalizes input to mean 0 and standard deviation 1 using a running average.

Similar to BatchNorm, LayerNorm, etc. but whereas they only use statistics from the current batch at train time, we use statistics from all batches.

This should replicate the common practice in RL of normalizing environment observations, such as using VecNormalize in Stable Baselines. Note that the behavior of this class differs slightly from VecNormalize: e.g., it works with the current reward instead of a return estimate, and it subtracts the mean reward, whereas VecNormalize only rescales it.

count: Tensor#
running_mean: Tensor#
running_var: Tensor#
training: bool#
update_stats(batch)[source]#

Update self.running_mean, self.running_var and self.count.

Uses Chan et al. (1979), “Updating Formulae and a Pairwise Algorithm for Computing Sample Variances”, to update the running moments in a numerically stable fashion.

Parameters

batch (Tensor) – A batch of data to use to update the running mean and variance.

Return type

None
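
A sketch of the observation-normalization use described above, with random tensors standing in for real rollout observations:

    import torch as th
    from imitation.util import networks

    obs_norm = networks.RunningNorm(num_features=4)

    obs_norm.train()                 # accumulate mean/var/count over all batches
    for _ in range(100):
        obs_batch = th.randn(64, 4) * 5.0 + 2.0   # placeholder observations
        _ = obs_norm(obs_batch)

    obs_norm.eval()                  # freeze statistics for evaluation
    normalized_obs = obs_norm(th.randn(64, 4))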

class imitation.util.networks.SqueezeLayer(*args, **kwargs)[source]#

Bases: Module

Torch module that squeezes a B*1 tensor down into a size-B vector.

forward(x)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#
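
For example, a quick check of the documented shape contract (batch size 32 is arbitrary):

    import torch as th
    from imitation.util import networks

    layer = networks.SqueezeLayer()
    out = layer(th.zeros(32, 1))     # (B, 1) -> (B,)
    assert out.shape == (32,)
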
imitation.util.networks.build_cnn(in_channels, hid_channels, out_size=1, name=None, activation=<class 'torch.nn.modules.activation.ReLU'>, kernel_size=3, stride=1, padding='same', dropout_prob=0.0, squeeze_output=False)[source]#

Constructs a Torch CNN.

Parameters
  • in_channels (int) – number of channels of individual inputs; input to the CNN will have shape (batch_size, in_channels, in_height, in_width).

  • hid_channels (Iterable[int]) – number of channels of hidden layers. If this is an empty iterable, then we build a linear function approximator.

  • out_size (int) – size of output vector.

  • name (Optional[str]) – Name to use as a prefix for the layer IDs.

  • activation (Type[Module]) – activation to apply after hidden layers.

  • kernel_size (int) – size of convolutional kernels.

  • stride (int) – stride of convolutional kernels.

  • padding (Union[int, str]) – padding of convolutional kernels.

  • dropout_prob (float) – Dropout probability to use after each hidden layer. If 0, no dropout layers are added to the network.

  • squeeze_output (bool) – if out_size=1, then squeeze_output=True ensures that CNN output is of size (B,) instead of (B,1).

Returns

a CNN mapping from inputs of size (batch_size, in_channels, in_height, in_width) to (batch_size, out_size), unless out_size=1 and squeeze_output=True, in which case the output is of size (batch_size,).

Return type

nn.Module

Raises

ValueError – if squeeze_output was supplied with out_size!=1.
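
A hedged usage sketch; the channel counts, image size, and variable names are arbitrary placeholders, and only the documented signature and return shape are assumed:

    import torch as th
    from imitation.util import networks

    # Two hidden conv layers over 3-channel 32x32 images, scalar output per image.
    cnn = networks.build_cnn(
        in_channels=3,
        hid_channels=[16, 32],
        out_size=1,
        squeeze_output=True,
    )

    images = th.randn(8, 3, 32, 32)
    scores = cnn(images)             # shape (8,): out_size=1 and squeeze_output=True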

imitation.util.networks.build_mlp(in_size, hid_sizes, out_size=1, name=None, activation=<class 'torch.nn.modules.activation.ReLU'>, dropout_prob=0.0, squeeze_output=False, flatten_input=False, normalize_input_layer=None)[source]#

Constructs a Torch MLP.

Parameters
  • in_size (int) – size of individual input vectors; input to the MLP will be of shape (batch_size, in_size).

  • hid_sizes (Iterable[int]) – sizes of hidden layers. If this is an empty iterable, then we build a linear function approximator.

  • out_size (int) – size of output vector.

  • name (Optional[str]) – Name to use as a prefix for the layer IDs.

  • activation (Type[Module]) – activation to apply after hidden layers.

  • dropout_prob (float) – Dropout probability to use after each hidden layer. If 0, no dropout layers are added to the network.

  • squeeze_output (bool) – if out_size=1, then squeeze_output=True ensures that MLP output is of size (B,) instead of (B,1).

  • flatten_input (bool) – should input be flattened along axes 1, 2, 3, …? Useful if you want to, e.g., process small image inputs with an MLP.

  • normalize_input_layer (Optional[Type[Module]]) – if specified, module to use to normalize inputs; e.g. nn.BatchNorm or RunningNorm.

Returns

an MLP mapping from inputs of size (batch_size, in_size) to (batch_size, out_size), unless out_size=1 and squeeze_output=True, in which case the output is of size (batch_size,).

Return type

nn.Module

Raises

ValueError – if squeeze_output was supplied with out_size!=1.
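
A hedged usage sketch; the sizes and variable names are arbitrary placeholders, and only the documented signature and return shape are assumed:

    import torch as th
    from imitation.util import networks

    # Scalar head over 8-dimensional inputs, with running input normalization.
    mlp = networks.build_mlp(
        in_size=8,
        hid_sizes=[64, 64],
        out_size=1,
        normalize_input_layer=networks.RunningNorm,
        squeeze_output=True,
    )

    values = mlp(th.randn(16, 8))    # shape (16,): out_size=1 and squeeze_output=True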

imitation.util.networks.evaluating(m: Module, *, mode: bool = False)#

Temporarily switch module m to specified training mode.

Parameters
  • m – The module to switch the mode of.

  • mode – whether to set training mode (True) or evaluation (False).

Yields

The module m.

imitation.util.networks.training(m: Module, *, mode: bool = True)#

Temporarily switch module m to specified training mode.

Parameters
  • m – The module to switch the mode of.

  • mode – whether to set training mode (True) or evaluation (False).

Yields

The module m.

imitation.util.networks.training_mode(m, mode=False)[source]#

Temporarily switch module m to specified training mode.

Parameters
  • m (Module) – The module to switch the mode of.

  • mode (bool) – whether to set training mode (True) or evaluation (False).

Yields

The module m.
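
training_mode is used as a context manager (it yields the module m); evaluating and training above appear to be the same context manager with the mode default pre-set to False and True respectively. A minimal sketch using a dropout network so the mode actually matters:

    import torch as th
    from imitation.util import networks

    net = networks.build_mlp(in_size=4, hid_sizes=[32], dropout_prob=0.5)
    net.train()

    with networks.training_mode(net, mode=False):   # dropout disabled inside the block
        preds = net(th.randn(10, 4))

    assert net.training              # original training mode restored on exit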