imitation.regularization.regularizers#
Implements the regularizer base class and some standard regularizers.
Classes
- LossRegularizer – Abstract base class for regularizers that add a loss term to the loss function.
- LpRegularizer – Applies Lp regularization to a loss function.
- Regularizer – Abstract class for creating regularizers with a common interface.
- RegularizerFactory – Protocol for functions that create regularizers.
- WeightDecayRegularizer – Applies weight decay to a loss function.
- WeightRegularizer – Abstract base class for regularizers that regularize the weights of a network.
- class imitation.regularization.regularizers.LossRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#
Bases: Regularizer[Union[Tensor, float]]
Abstract base class for regularizers that add a loss term to the loss function.
Requires the user to implement the _loss_penalty method.
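The shape of this pattern can be sketched in plain Python (this is an illustrative stand-in, not the library's implementation; the class names and the squared-loss penalty are hypothetical, and the real method also calls loss.backward() on a torch Tensor):

```python
class SketchLossRegularizer:
    """Simplified stand-in for LossRegularizer; lambda_ scales the penalty."""

    def __init__(self, initial_lambda):
        self.lambda_ = initial_lambda

    def _loss_penalty(self, loss):
        # Subclasses supply the penalty term.
        raise NotImplementedError

    def regularize_and_backward(self, loss):
        # Add the scaled penalty to the loss; the real implementation
        # also computes gradients via loss.backward().
        return loss + self.lambda_ * self._loss_penalty(loss)


class SquaredLossPenalty(SketchLossRegularizer):
    """Hypothetical subclass: penalizes the squared loss itself."""

    def _loss_penalty(self, loss):
        return loss ** 2


reg = SquaredLossPenalty(initial_lambda=0.1)
print(reg.regularize_and_backward(2.0))  # 2.0 + 0.1 * 4.0 = 2.4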
- lambda_: float#
- lambda_updater: Optional[LambdaUpdater]#
- logger: HierarchicalLogger#
- optimizer: Optimizer#
- regularize_and_backward(loss)[source]#
Add the regularization term to the loss and compute gradients.
- Parameters
loss (Tensor) – The loss to regularize.
- Return type
Union[Tensor, float]
- Returns
The regularized loss.
- val_split: Optional[float]#
- class imitation.regularization.regularizers.LpRegularizer(optimizer, initial_lambda, lambda_updater, logger, p, val_split=None)[source]#
Bases: LossRegularizer
Applies Lp regularization to a loss function.
- __init__(optimizer, initial_lambda, lambda_updater, logger, p, val_split=None)[source]#
Initialize the regularizer.
- p: int#
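The penalty this class adds is, roughly, the sum of |w|^p over the network weights, scaled by lambda_. A plain-Python illustration of that quantity (the library's exact reduction over parameter tensors may differ):

```python
def lp_penalty(weights, p):
    """Sum of |w|**p over all weights (illustrative helper)."""
    return sum(abs(w) ** p for w in weights)


weights = [0.5, -1.0, 2.0]
print(lp_penalty(weights, p=1))  # L1: 0.5 + 1.0 + 2.0 = 3.5
print(lp_penalty(weights, p=2))  # L2: 0.25 + 1.0 + 4.0 = 5.25
```

With p=1 this gives L1 (lasso-style) regularization, with p=2 the squared L2 norm familiar from ridge regression.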
- class imitation.regularization.regularizers.Regularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#
Bases: ABC, Generic[R]
Abstract class for creating regularizers with a common interface.
- __init__(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#
Initialize the regularizer.
- Parameters
optimizer (Optimizer) – The optimizer to which the regularizer is attached.
initial_lambda (float) – The initial value of the regularization parameter.
lambda_updater (Optional[LambdaUpdater]) – A callable object that takes in the current lambda and the train and val loss, and returns the new lambda.
logger (HierarchicalLogger) – The logger to which the regularizer will log its parameters.
val_split (Optional[float]) – The fraction of the training data to use as validation data for the lambda updater. Can be None if no lambda updater is provided.
- Raises
ValueError – if no lambda updater (lambda_updater) is provided and the initial regularization strength (initial_lambda) is zero.
ValueError – if a validation split (val_split) is provided but it is not a float in the (0, 1) interval.
ValueError – if a lambda updater is provided but no validation split is provided.
ValueError – if a validation split is set, but no lambda updater is provided.
- classmethod create(initial_lambda, lambda_updater=None, val_split=0.0, **kwargs)[source]#
Create a regularizer.
- Return type
RegularizerFactory[TypeVar(Self, bound=Regularizer)]
- lambda_: float#
- lambda_updater: Optional[LambdaUpdater]#
- logger: HierarchicalLogger#
- optimizer: Optimizer#
- abstract regularize_and_backward(loss)[source]#
Abstract method for performing the regularization step.
The return type is a generic and the specific implementation must describe the meaning of the return type.
This step will also call loss.backward() for the user. This is because the regularizer may require loss.backward() to be called before or after the regularization step. Leaving this choice to the user would force their implementation to depend on the regularization algorithm used, which is prone to errors.
- Parameters
loss (Tensor) – The loss to regularize.
- Return type
TypeVar(R)
- update_params(train_loss, val_loss)[source]#
Update the regularization parameter.
This method calls the lambda_updater to update the regularization parameter, assigns the new value to self.lambda_, and then logs the new value using the provided logger.
- Parameters
train_loss (Union[Tensor, float]) – The loss on the training set.
val_loss (Union[Tensor, float]) – The loss on the validation set.
- Return type
None
- val_split: Optional[float]#
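A LambdaUpdater, as described above, is any callable taking the current lambda plus the train and val loss and returning the new lambda. A hypothetical updater in plain Python (the rule, names, and thresholds are illustrative, not the library's):

```python
def interval_scaler(lambda_, train_loss, val_loss, factor=2.0, tol=0.1):
    """Toy LambdaUpdater: scale lambda by how val loss compares to train loss."""
    ratio = val_loss / train_loss
    if ratio > 1.0 + tol:    # val loss drifting above train loss: regularize harder
        return lambda_ * factor
    if ratio < 1.0 - tol:    # val loss well below train loss: regularize less
        return lambda_ / factor
    return lambda_           # within tolerance: keep lambda unchanged


print(interval_scaler(0.5, train_loss=1.0, val_loss=1.5))  # 1.0
print(interval_scaler(0.5, train_loss=1.0, val_loss=0.5))  # 0.25
```

update_params would invoke such a callable once per update step and log the resulting lambda.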
- class imitation.regularization.regularizers.RegularizerFactory(*args, **kwargs)[source]#
Bases: Protocol[T_Regularizer_co]
Protocol for functions that create regularizers.
The regularizer factory is meant to be used as a way to create a regularizer in two steps. First, the end-user creates a regularizer factory by calling the .create() method of a regularizer class. This allows specifying all the relevant configuration to the regularization algorithm. Then, the network algorithm finishes setting up the optimizer and logger, and calls the regularizer factory to create the regularizer.
This two-step process separates the configuration of the regularization algorithm from additional “operational” parameters. This is useful because it solves two problems:
- The end-user does not have access to the optimizer and logger when configuring the regularization algorithm.
- Validation of the configuration is done outside the network constructor.
It also allows re-using the same regularizer factory for multiple networks.
- __init__(*args, **kwargs)#
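The two-step creation pattern described above can be sketched in plain Python (an illustrative stand-in, not the library's implementation; the class and parameter names are hypothetical):

```python
class SketchRegularizer:
    """Simplified stand-in for a Regularizer subclass."""

    def __init__(self, optimizer, logger, initial_lambda):
        self.optimizer = optimizer
        self.logger = logger
        self.lambda_ = initial_lambda

    @classmethod
    def create(cls, initial_lambda):
        # Step 1: the end-user fixes the algorithm configuration,
        # receiving a factory instead of a regularizer.
        def factory(optimizer, logger):
            # Step 2: the network code later supplies the
            # "operational" objects and builds the regularizer.
            return cls(optimizer=optimizer, logger=logger,
                       initial_lambda=initial_lambda)
        return factory


factory = SketchRegularizer.create(initial_lambda=0.01)
reg = factory(optimizer="opt", logger="log")  # factory is reusable across networks
print(reg.lambda_)  # 0.01
```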
- class imitation.regularization.regularizers.WeightDecayRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#
Bases: WeightRegularizer
Applies weight decay to a loss function.
- lambda_: float#
- lambda_updater: Optional[LambdaUpdater]#
- logger: HierarchicalLogger#
- optimizer: Optimizer#
- val_split: Optional[float]#
- class imitation.regularization.regularizers.WeightRegularizer(optimizer, initial_lambda, lambda_updater, logger, val_split=None)[source]#
Bases: Regularizer
Abstract base class for regularizers that regularize the weights of a network.
Requires the user to implement the _weight_penalty method.
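Unlike LossRegularizer, this family adjusts the weights themselves rather than adding a term to the loss. A plain-Python sketch of the weight-decay variant (illustrative only; the real class operates on optimizer parameter groups and also calls loss.backward()):

```python
class SketchWeightDecay:
    """Simplified stand-in for WeightDecayRegularizer."""

    def __init__(self, initial_lambda):
        self.lambda_ = initial_lambda

    def _weight_penalty(self, weight):
        # Decay penalty proportional to the weight itself.
        return self.lambda_ * weight

    def regularize_and_backward(self, weights):
        # Shrink each weight toward zero in place; the real method
        # also computes gradients via loss.backward().
        for i, w in enumerate(weights):
            weights[i] = w - self._weight_penalty(w)


reg = SketchWeightDecay(initial_lambda=0.1)
weights = [1.0, -2.0]
reg.regularize_and_backward(weights)
print(weights)  # [0.9, -1.8]
```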
- lambda_: float#
- lambda_updater: Optional[LambdaUpdater]#
- logger: HierarchicalLogger#
- optimizer: Optimizer#
- regularize_and_backward(loss)[source]#
Regularize the weights of the network, and call loss.backward().
- Return type
None
- val_split: Optional[float]#