imitation.algorithms.base

Module of base classes and helper methods for imitation learning algorithms.

Functions

make_data_loader(transitions, batch_size[, ...])

Converts demonstration data to Torch data loader.

Classes

BaseImitationAlgorithm(*[, custom_logger, ...])

Base class for all imitation learning algorithms.

DemonstrationAlgorithm(*, demonstrations[, ...])

An algorithm that learns from demonstration: BC, IRL, etc.

class imitation.algorithms.base.BaseImitationAlgorithm(*, custom_logger=None, allow_variable_horizon=False)

Bases: ABC

Base class for all imitation learning algorithms.

__init__(*, custom_logger=None, allow_variable_horizon=False)

Creates an imitation learning algorithm.

Parameters
  • custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.

  • allow_variable_horizon (bool) – If False (default), the algorithm raises an exception if it detects trajectories of different lengths during training. If True, this safety check is skipped. WARNING: variable-horizon episodes leak information about the reward via the termination condition and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html before overriding this.
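The keyword-only constructor and the "create a new logger if custom_logger is None" default can be mirrored in a small stand-alone sketch. StandInLogger, SketchAlgorithm and NoOpAlgorithm are hypothetical names used only for illustration; the real class uses HierarchicalLogger, which is not reproduced here.

```python
import abc
from typing import List, Optional, Tuple


class StandInLogger:
    """Hypothetical stand-in for HierarchicalLogger (illustration only)."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, object]] = []

    def record(self, key: str, value: object) -> None:
        self.records.append((key, value))


class SketchAlgorithm(abc.ABC):
    """Hypothetical base class mirroring the documented constructor semantics."""

    def __init__(
        self,
        *,  # every argument is keyword-only, as in the documented signature
        custom_logger: Optional[StandInLogger] = None,
        allow_variable_horizon: bool = False,
    ) -> None:
        # None (the default) means "create a fresh logger for this instance".
        self._logger = custom_logger if custom_logger is not None else StandInLogger()
        self.allow_variable_horizon = allow_variable_horizon

    @property
    def logger(self) -> StandInLogger:
        return self._logger


class NoOpAlgorithm(SketchAlgorithm):
    """Concrete subclass with no extra behaviour."""


shared = StandInLogger()
a = NoOpAlgorithm()                      # gets its own new logger
b = NoOpAlgorithm(custom_logger=shared)  # writes to the caller's logger
```

Passing a logger lets several algorithms write to one place; passing the arguments positionally (for example NoOpAlgorithm(shared)) raises TypeError because of the bare `*`.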

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.
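A minimal sketch of the kind of safety check that allow_variable_horizon disables, assuming the check compares trajectory lengths. check_fixed_horizon is a hypothetical helper; the library's own check, its exception type and its message may differ.

```python
from typing import Iterable


def check_fixed_horizon(lengths: Iterable[int], allow_variable_horizon: bool) -> None:
    """Raise if trajectories have different lengths, unless explicitly allowed.

    Hypothetical helper illustrating the documented behaviour; not the
    library's implementation.
    """
    if allow_variable_horizon:
        return  # the caller has opted out of the safety check
    distinct = set(lengths)
    if len(distinct) > 1:
        raise ValueError(
            f"Trajectories of different lengths detected: {sorted(distinct)}. "
            "Variable horizons can leak reward information via termination."
        )


check_fixed_horizon([200, 200, 200], allow_variable_horizon=False)  # passes
check_fixed_horizon([200, 57], allow_variable_horizon=True)         # passes: check disabled
```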

property logger: HierarchicalLogger

The logger this algorithm writes to (a new one is created if custom_logger was None).

Return type

HierarchicalLogger

class imitation.algorithms.base.DemonstrationAlgorithm(*, demonstrations, custom_logger=None, allow_variable_horizon=False)

Bases: BaseImitationAlgorithm, Generic[TransitionKind]

An algorithm that learns from demonstration: BC, IRL, etc.

__init__(*, demonstrations, custom_logger=None, allow_variable_horizon=False)

Creates an algorithm that learns from demonstrations.

Parameters
  • demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations from an expert (optional). May be given as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.). If None, demonstrations can be supplied later with set_demonstrations.

  • custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.

  • allow_variable_horizon (bool) – If False (default), the algorithm raises an exception if it detects trajectories of different lengths during training. If True, this safety check is skipped. WARNING: variable-horizon episodes leak information about the reward via the termination condition and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html before overriding this.

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.

abstract property policy: BasePolicy

Returns a policy imitating the demonstration data.

Return type

BasePolicy
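Because policy is an abstract property, a subclass cannot be instantiated until it provides one. The sketch below shows that standard abc pattern with hypothetical names (SketchDemoAlgorithm, ConstantPolicyAlgorithm) and a plain callable standing in for BasePolicy.

```python
import abc
from typing import Callable


class SketchDemoAlgorithm(abc.ABC):
    """Hypothetical stand-in for an algorithm with an abstract policy property."""

    @property
    @abc.abstractmethod
    def policy(self) -> Callable[[int], int]:
        """Return a policy imitating the demonstration data."""


class ConstantPolicyAlgorithm(SketchDemoAlgorithm):
    """Toy subclass whose 'policy' always returns action 0."""

    @property
    def policy(self) -> Callable[[int], int]:
        return lambda obs: 0


algo = ConstantPolicyAlgorithm()
print(algo.policy(5))  # 0
```

Calling SketchDemoAlgorithm() directly raises TypeError, which is how the abstract property forces every concrete algorithm to expose a policy.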

abstract set_demonstrations(demonstrations)

Sets the demonstration data.

Changing the demonstration data on demand can be useful for interactive algorithms such as DAgger.

Parameters

demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a sequence of Trajectory objects.

Return type

None
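Interactive algorithms such as DAgger gather new demonstrations each round and hand the aggregated dataset back to the learner. A minimal sketch of that pattern, assuming demonstrations are plain (observation, action) pairs; AggregatingAlgorithm and run_rounds are hypothetical and far simpler than the library's DAgger trainer.

```python
from typing import Dict, List, Sequence, Tuple

Transition = Tuple[int, int]  # hypothetical (observation, action) pair


class AggregatingAlgorithm:
    """Toy learner whose demonstrations can be replaced on demand."""

    def __init__(self) -> None:
        self._demos: List[Transition] = []
        self._table: Dict[int, int] = {}

    def set_demonstrations(self, demonstrations: Sequence[Transition]) -> None:
        # Replace (not extend) the training data, then "refit" a lookup table.
        self._demos = list(demonstrations)
        self._table = {obs: act for obs, act in self._demos}

    def act(self, obs: int) -> int:
        # Fall back to action 0 for observations never demonstrated.
        return self._table.get(obs, 0)


def run_rounds(algo: AggregatingAlgorithm, rounds: Sequence[Sequence[Transition]]) -> None:
    dataset: List[Transition] = []
    for new_demos in rounds:
        dataset.extend(new_demos)         # aggregate across rounds, DAgger-style
        algo.set_demonstrations(dataset)  # swap in the enlarged dataset


algo = AggregatingAlgorithm()
run_rounds(algo, [[(0, 1), (1, 2)], [(2, 3)]])
print(algo.act(2), algo.act(9))  # 3 0
```

set_demonstrations copies its input, so later mutation of the caller's list does not silently change the learner's data.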

imitation.algorithms.base.make_data_loader(transitions, batch_size, data_loader_kwargs=None)

Converts demonstration data to Torch data loader.

Parameters
  • transitions (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Transitions given as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.).

  • batch_size (int) – The size of the batches to create. If transitions is already an iterable of transition batches, they are not re-batched; their size must already equal batch_size.

  • data_loader_kwargs (Optional[Mapping[str, Any]]) – Additional keyword arguments passed to th_data.DataLoader (torch.utils.data.DataLoader).
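For flat transitions, batching amounts to grouping timesteps into chunks of batch_size and collating each field. The stand-alone sketch below uses lists of dicts instead of Torch tensors; batch_transitions is hypothetical, and dropping an incomplete final batch is an assumption of the sketch. With a real DataLoader, options such as drop_last or shuffle are the kind of argument data_loader_kwargs forwards.

```python
from typing import Dict, List, Mapping, Sequence


def batch_transitions(
    transitions: Sequence[Mapping[str, object]],
    batch_size: int,
    drop_last: bool = True,
) -> List[Dict[str, list]]:
    """Group per-timestep dicts into batches of batch_size (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batches = []
    for start in range(0, len(transitions), batch_size):
        chunk = transitions[start:start + batch_size]
        if len(chunk) < batch_size and drop_last:
            break  # discard the incomplete final batch
        # Collate: one list per key, e.g. {"obs": [...], "acts": [...]}.
        batches.append({key: [t[key] for t in chunk] for key in chunk[0]})
    return batches


steps = [{"obs": i, "acts": i % 2} for i in range(5)]
batches = batch_transitions(steps, batch_size=2)
print(len(batches), batches[0])  # 2 {'obs': [0, 1], 'acts': [0, 1]}
```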

Return type

Iterable[TransitionMapping]

Returns

An iterable of transition batches.

Raises
  • ValueError – if transitions is an iterable over transition batches whose batch size is not equal to batch_size; or if transitions is a types.TransitionsMinimal object or a sequence of trajectories with fewer total timesteps than batch_size.

  • TypeError – if transitions is an unsupported type.
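The error conditions above can be illustrated with a stand-alone validator. FlatTransitions and Trajectory are hypothetical stand-ins for types.TransitionsMinimal and types.Trajectory that keep only actions, and validate is a sketch of the documented conditions, not the library's code (which may check different fields or report different messages).

```python
import collections.abc as cabc
from dataclasses import dataclass
from typing import List


@dataclass
class FlatTransitions:
    """Hypothetical stand-in for types.TransitionsMinimal (actions only)."""
    acts: List[int]


@dataclass
class Trajectory:
    """Hypothetical stand-in for types.Trajectory (actions only)."""
    acts: List[int]


def validate(transitions: object, batch_size: int) -> None:
    """Mirror the documented ValueError/TypeError conditions (sketch only)."""
    if isinstance(transitions, FlatTransitions):
        total = len(transitions.acts)
    elif isinstance(transitions, cabc.Sequence) and transitions and all(
        isinstance(t, Trajectory) for t in transitions
    ):
        total = sum(len(t.acts) for t in transitions)
    elif isinstance(transitions, cabc.Iterable):
        # Already batched: every batch must have exactly batch_size entries.
        # Note: this consumes one-shot iterators, which is fine for a sketch.
        for batch in transitions:
            if not isinstance(batch, cabc.Mapping):
                raise TypeError(f"unsupported batch type: {type(batch).__name__}")
            if len(batch["acts"]) != batch_size:
                raise ValueError(
                    f"batch of size {len(batch['acts'])} != batch_size={batch_size}"
                )
        return
    else:
        raise TypeError(f"unsupported type: {type(transitions).__name__}")
    # Flat transitions or trajectories: need at least one full batch.
    if total < batch_size:
        raise ValueError(f"{total} timesteps is fewer than batch_size={batch_size}")


validate(FlatTransitions(acts=[0] * 10), batch_size=4)  # OK: 10 >= 4
validate([{"acts": [0, 1, 2, 3]}], batch_size=4)        # OK: batch size matches
```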