imitation.regularization.regularizers#

Implements the regularizer base class and some standard regularizers.

Classes

LossRegularizer(optimizer, initial_lambda, ...) – Abstract base class for regularizers that add a loss term to the loss function.

LpRegularizer(optimizer, initial_lambda, ...) – Applies Lp regularization to a loss function.

Regularizer(optimizer, initial_lambda, ...) – Abstract class for creating regularizers with a common interface.

RegularizerFactory(*args, **kwargs) – Protocol for functions that create regularizers.

WeightDecayRegularizer(optimizer, ...[, ...]) – Applies weight decay to a loss function.

WeightRegularizer(optimizer, initial_lambda, ...) – Abstract base class for regularizers that regularize the weights of a network.

class imitation.regularization.regularizers.LossRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#

Bases: Regularizer[Union[Tensor, float]]

Abstract base class for regularizers that add a loss term to the loss function.

Requires the user to implement the _loss_penalty method.
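For illustration, a hypothetical subclass might look like the sketch below. It assumes the _loss_penalty hook receives the un-regularized loss and returns the term that regularize_and_backward adds to it before calling backward(); the hook's exact signature is not documented in this section, so treat it as an assumption.

import torch as th

from imitation.regularization.regularizers import LossRegularizer


class SquaredLossRegularizer(LossRegularizer):
    """Illustrative only: penalizes large losses by adding lambda_ * loss**2."""

    def _loss_penalty(self, loss: th.Tensor) -> th.Tensor:
        # Assumed hook contract: return the penalty term; the base class adds it
        # to `loss` and performs the backward pass in regularize_and_backward().
        return self.lambda_ * loss ** 2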

lambda_: float#
lambda_updater: Optional[LambdaUpdater]#
logger: HierarchicalLogger#
optimizer: Optimizer#
regularize_and_backward(loss)[source]#

Add the regularization term to the loss and compute gradients.

Parameters

loss (Tensor) – The loss to regularize.

Return type

Union[Tensor, float]

Returns

The regularized loss.

val_split: Optional[float]#
class imitation.regularization.regularizers.LpRegularizer(optimizer, initial_lambda, lambda_updater, logger, p, val_split=None)[source]#

Bases: LossRegularizer

Applies Lp regularization to a loss function.

__init__(optimizer, initial_lambda, lambda_updater, logger, p, val_split=None)[source]#

Initialize the regularizer.

p: int#
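A minimal usage sketch, assuming a standard PyTorch model and optimizer; imitation.util.logger.configure is assumed here as the way to obtain a HierarchicalLogger:

import torch as th

from imitation.regularization.regularizers import LpRegularizer
from imitation.util import logger as imit_logger

model = th.nn.Linear(10, 2)
optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
hier_logger = imit_logger.configure()  # assumed helper returning a HierarchicalLogger

# With no lambda updater, initial_lambda must be non-zero and val_split must stay unset.
regularizer = LpRegularizer(
    optimizer=optimizer,
    initial_lambda=1e-2,
    lambda_updater=None,
    logger=hier_logger,
    p=2,
)

inputs, targets = th.randn(32, 10), th.randn(32, 2)
loss = th.nn.functional.mse_loss(model(inputs), targets)

optimizer.zero_grad()
# Adds the L2 penalty and calls backward() internally; do not call loss.backward() again.
regularized_loss = regularizer.regularize_and_backward(loss)
optimizer.step()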
class imitation.regularization.regularizers.Regularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#

Bases: ABC, Generic[R]

Abstract class for creating regularizers with a common interface.

__init__(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#

Initialize the regularizer.

Parameters
  • optimizer (Optimizer) – The optimizer to which the regularizer is attached.

  • initial_lambda (float) – The initial value of the regularization parameter.

  • lambda_updater (Optional[LambdaUpdater]) – A callable that takes the current lambda together with the training and validation losses, and returns the new lambda (a hand-rolled sketch follows the Raises list below).

  • logger (HierarchicalLogger) – The logger to which the regularizer will log its parameters.

  • val_split (Optional[float]) – The fraction of the training data to use as validation data for the lambda updater. Can be None if no lambda updater is provided.

Raises
  • ValueError – if no lambda updater (lambda_updater) is provided and the initial regularization strength (initial_lambda) is zero.

  • ValueError – if a validation split (val_split) is provided but it’s not a float in the (0, 1) interval.

  • ValueError – if a lambda updater is provided but no validation split is provided.

  • ValueError – if a validation split is set, but no lambda updater is provided.
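As referenced above, the lambda updater is simply a callable mapping the current lambda and the two losses to a new lambda. Ready-made updaters live elsewhere in imitation.regularization; the hand-rolled version below is purely illustrative and assumes the LambdaUpdater call signature is (lambda_, train_loss, val_loss):

def scale_on_overfit(lambda_: float, train_loss: float, val_loss: float) -> float:
    """Hypothetical updater: strengthen regularization when the validation loss
    drifts well above the training loss, otherwise relax it slightly."""
    if float(val_loss) > 1.5 * float(train_loss):
        return lambda_ * 2.0
    return lambda_ * 0.95

Passing such a callable as lambda_updater also requires supplying a val_split in the (0, 1) interval, per the Raises conditions above.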

classmethod create(initial_lambda, lambda_updater=None, val_split=0.0, **kwargs)[source]#

Create a regularizer.

Return type

RegularizerFactory[TypeVar(Self, bound=Regularizer)]

lambda_: float#
lambda_updater: Optional[LambdaUpdater]#
logger: HierarchicalLogger#
optimizer: Optimizer#
abstract regularize_and_backward(loss)[source]#

Abstract method for performing the regularization step.

The return type is generic; each concrete implementation must document the meaning of the value it returns.

This step also calls loss.backward() for the user. This is because the regularizer may require the backward pass to run before or after the regularization step. Leaving this to the user would force their code to depend on the particular regularizer algorithm used, which is error-prone.

Parameters

loss (Tensor) – The loss to regularize.

Return type

TypeVar(R)

update_params(train_loss, val_loss)[source]#

Update the regularization parameter.

This method calls the lambda_updater to update the regularization parameter and assigns the new value to self.lambda_. It then logs the new value using the provided logger.

Parameters
  • train_loss (Union[Tensor, float]) – The loss on the training set.

  • val_loss (Union[Tensor, float]) – The loss on the validation set.

Return type

None

val_split: Optional[float]#
class imitation.regularization.regularizers.RegularizerFactory(*args, **kwargs)[source]#

Bases: Protocol[T_Regularizer_co]

Protocol for functions that create regularizers.

The regularizer factory is meant to be used as a way to create a regularizer in two steps. First, the end-user creates a regularizer factory by calling the .create() method of a regularizer class. This allows specifying all the configuration relevant to the regularization algorithm. Then, the network algorithm finishes setting up the optimizer and logger, and calls the regularizer factory to create the regularizer.

This two-step process separates the configuration of the regularization algorithm from additional “operational” parameters. This is useful because it solves two problems:

  1. The end-user does not have access to the optimizer and logger when configuring the regularization algorithm.

  2. Validation of the configuration is done outside the network constructor.

It also allows re-using the same regularizer factory for multiple networks.
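A sketch of this two-step pattern using LpRegularizer with a lambda updater; the keyword arguments accepted by the factory call (optimizer, logger), the imitation.util.logger.configure helper, and the updater itself are assumptions here:

import torch as th

from imitation.regularization.regularizers import LpRegularizer
from imitation.util import logger as imit_logger


def grow_on_overfit(lambda_: float, train_loss: float, val_loss: float) -> float:
    # Hypothetical LambdaUpdater: increase lambda when validation loss exceeds training loss.
    return lambda_ * 2.0 if float(val_loss) > float(train_loss) else lambda_

# Step 1: the end-user configures the regularization algorithm up front.
factory = LpRegularizer.create(
    initial_lambda=1e-2,
    lambda_updater=grow_on_overfit,
    val_split=0.2,
    p=2,
)

# Step 2: later, the network code supplies the "operational" parameters.
model = th.nn.Linear(4, 1)
optimizer = th.optim.Adam(model.parameters())
hier_logger = imit_logger.configure()  # assumed helper returning a HierarchicalLogger
regularizer = factory(optimizer=optimizer, logger=hier_logger)

# After each epoch, update_params() feeds the train and validation losses to the
# lambda updater and logs the new lambda_.
regularizer.update_params(0.5, 0.7)

# The same factory could be called again to build a regularizer for another network.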

__init__(*args, **kwargs)#
class imitation.regularization.regularizers.WeightDecayRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#

Bases: WeightRegularizer

Applies weight decay to a loss function.

lambda_: float#
lambda_updater: Optional[LambdaUpdater]#
logger: HierarchicalLogger#
optimizer: Optimizer#
val_split: Optional[float]#
class imitation.regularization.regularizers.WeightRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#

Bases: Regularizer

Abstract base class for regularizers that regularize the weights of a network.

Requires the user to implement the _weight_penalty method.
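By analogy with the concrete WeightDecayRegularizer, a hypothetical subclass could look like the sketch below. It assumes _weight_penalty receives a parameter tensor together with its optimizer parameter group and returns the adjustment added to that weight after the backward pass; this signature is an assumption, not something documented above.

import torch as th

from imitation.regularization.regularizers import WeightRegularizer


class SignDecayRegularizer(WeightRegularizer):
    """Illustrative only: shrinks each weight by a fixed step along its sign (L1-style decay)."""

    def _weight_penalty(self, weight: th.Tensor, group: dict) -> th.Tensor:
        # Assumed hook contract: the returned tensor is added to `weight` by
        # regularize_and_backward() once loss.backward() has run.
        return -self.lambda_ * group["lr"] * th.sign(weight.data)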

lambda_: float#
lambda_updater: Optional[LambdaUpdater]#
logger: HierarchicalLogger#
optimizer: Optimizer#
regularize_and_backward(loss)[source]#

Regularize the weights of the network, and call loss.backward().

Return type

None

val_split: Optional[float]#