imitation.algorithms.base
Module of base classes and helper methods for imitation learning algorithms.
Functions
make_data_loader | Converts demonstration data to Torch data loader.
Classes
BaseImitationAlgorithm | Base class for all imitation learning algorithms.
DemonstrationAlgorithm | An algorithm that learns from demonstration: BC, IRL, etc.
- class imitation.algorithms.base.BaseImitationAlgorithm(*, custom_logger=None, allow_variable_horizon=False)
Bases: ABC
Base class for all imitation learning algorithms.
- __init__(*, custom_logger=None, allow_variable_horizon=False)
Creates an imitation learning algorithm.
- Parameters
custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
allow_variable_horizon (bool) – If False (default), the algorithm raises an exception if it detects trajectories of different lengths during training. If True, this safety check is overridden. WARNING: variable-horizon episodes leak information about the reward via the termination condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html before overriding this.
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
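The safety check described above can be sketched without the library. The block below is a minimal illustration, not the library's own code: `HorizonChecker` and its `check` method are hypothetical names, and the choice of `ValueError` is an assumption, since the text only says that an exception is raised.

```python
from typing import Iterable, Optional


class HorizonChecker:
    """Hypothetical stand-in for the fixed-horizon check; not the library's code."""

    def __init__(self, allow_variable_horizon: bool = False) -> None:
        self.allow_variable_horizon = allow_variable_horizon
        self._horizon: Optional[int] = None  # episode length seen so far, if any

    def check(self, episode_lengths: Iterable[int]) -> None:
        """Raise ValueError (an assumed exception type) if lengths disagree."""
        if self.allow_variable_horizon:
            return  # the caller explicitly overrode the safety check
        seen = set(episode_lengths)
        if self._horizon is not None:
            seen.add(self._horizon)  # compare against earlier episodes too
        if len(seen) > 1:
            raise ValueError(f"Episodes of different length detected: {sorted(seen)}")
        if seen:
            self._horizon = seen.pop()


strict = HorizonChecker()
strict.check([100, 100])  # consistent horizon: accepted
try:
    strict.check([100, 73])  # a shorter episode appears: rejected
    rejected = False
except ValueError:
    rejected = True

lenient = HorizonChecker(allow_variable_horizon=True)
lenient.check([100, 73])  # override enabled: no exception
```

Remembering the first observed horizon lets the check catch a mismatch that only shows up in a later batch of episodes, which is the situation "during training" suggests.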
- property logger: HierarchicalLogger
- Return type
HierarchicalLogger
- class imitation.algorithms.base.DemonstrationAlgorithm(*, demonstrations, custom_logger=None, allow_variable_horizon=False)
Bases: BaseImitationAlgorithm, Generic[TransitionKind]
An algorithm that learns from demonstration: BC, IRL, etc.
- __init__(*, demonstrations, custom_logger=None, allow_variable_horizon=False)
Creates an algorithm that learns from demonstrations.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.).
custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
allow_variable_horizon (bool) – If False (default), the algorithm raises an exception if it detects trajectories of different lengths during training. If True, this safety check is overridden. WARNING: variable-horizon episodes leak information about the reward via the termination condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html before overriding this.
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
- abstract property policy: BasePolicy
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- abstract set_demonstrations(demonstrations)
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
None
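The abstract interface above (a `policy` property plus `set_demonstrations`) can be illustrated with a self-contained toy. The sketch below deliberately does not import `imitation`: `DemoAlgorithmSketch` and `MajorityActionAlgorithm` are hypothetical stand-ins, batches are plain dictionaries of lists, and the "policy" is an ordinary function rather than a `BasePolicy`. It only shows the shape of the contract, in particular that demonstrations can be replaced on demand.

```python
from abc import ABC, abstractmethod
from collections import Counter
from typing import Callable, Dict, Hashable, Iterable, List, Optional

# Stand-in for a transition batch with "obs" and "acts" keys (an assumption).
Batch = Dict[str, List[Hashable]]


class DemoAlgorithmSketch(ABC):
    """Hypothetical mirror of the documented interface; not the library class."""

    def __init__(self, *, demonstrations: Optional[Iterable[Batch]] = None) -> None:
        # Demonstrations are optional at construction and may be set later.
        if demonstrations is not None:
            self.set_demonstrations(demonstrations)

    @property
    @abstractmethod
    def policy(self) -> Callable[[Hashable], Hashable]:
        """Returns a policy imitating the demonstration data."""

    @abstractmethod
    def set_demonstrations(self, demonstrations: Iterable[Batch]) -> None:
        """Sets the demonstration data."""


class MajorityActionAlgorithm(DemoAlgorithmSketch):
    """Toy learner: for each observation, copy the most frequent expert action."""

    def __init__(self, *, demonstrations: Optional[Iterable[Batch]] = None) -> None:
        self._counts: Dict[Hashable, Counter] = {}
        super().__init__(demonstrations=demonstrations)

    def set_demonstrations(self, demonstrations: Iterable[Batch]) -> None:
        self._counts = {}  # replace, rather than extend, the previous data
        for batch in demonstrations:
            for obs, act in zip(batch["obs"], batch["acts"]):
                self._counts.setdefault(obs, Counter())[act] += 1

    @property
    def policy(self) -> Callable[[Hashable], Hashable]:
        table = {obs: c.most_common(1)[0][0] for obs, c in self._counts.items()}
        return lambda obs: table[obs]


algo = MajorityActionAlgorithm(
    demonstrations=[{"obs": [0, 0, 1], "acts": ["left", "left", "right"]}]
)
first_choice = algo.policy(0)

# Swap in new data on demand, as an interactive algorithm might after each round.
algo.set_demonstrations([{"obs": [0], "acts": ["right"]}])
second_choice = algo.policy(0)
```

Replacing the stored data inside `set_demonstrations`, rather than appending to it, is a design choice of this toy; a real subclass is free to aggregate datasets before passing them in.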
- imitation.algorithms.base.make_data_loader(transitions, batch_size, data_loader_kwargs=None)
Converts demonstration data to Torch data loader.
- Parameters
transitions (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.).
batch_size (int) – The size of the batch to create. Does not change the batch size if transitions is already an iterable of transition batches.
data_loader_kwargs (Optional[Mapping[str, Any]]) – Arguments to pass to th_data.DataLoader.
- Return type
Iterable[TransitionMapping]
- Returns
An iterable of transition batches.
- Raises
ValueError – if transitions is an iterable over transition batches whose batch size is not equal to batch_size; or if transitions is a transitions object or a sequence of trajectories whose total number of timesteps is less than batch_size.
TypeError – if transitions is an unsupported type.
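The error conditions listed above can be sketched in plain Python. This is a simplified stand-in, not the library function: `make_batches_sketch` is a hypothetical name, flat transitions are modelled as a single dictionary of lists, pre-batched data as a list of such dictionaries, and dropping the final partial batch is a choice made for this sketch only.

```python
from collections.abc import Mapping
from typing import Any, Dict, List

Batch = Dict[str, List[Any]]  # stand-in for a transition batch (an assumption)


def make_batches_sketch(transitions: Any, batch_size: int) -> List[Batch]:
    """Return batches of exactly `batch_size` items, mirroring the documented errors."""
    if isinstance(transitions, Mapping):
        # Flat transitions: one mapping holding every timestep.
        total = len(transitions["obs"])
        if total < batch_size:
            raise ValueError(f"{total} timesteps is fewer than batch_size={batch_size}")
        # Drop any final partial batch so that all batches share one size.
        return [
            {key: values[start:start + batch_size] for key, values in transitions.items()}
            for start in range(0, total - batch_size + 1, batch_size)
        ]
    if isinstance(transitions, (list, tuple)):
        # Already batched: pass through unchanged, but verify each batch's size.
        for batch in transitions:
            if len(batch["obs"]) != batch_size:
                raise ValueError(
                    f"Expected batch size {batch_size}, got {len(batch['obs'])}"
                )
        return list(transitions)
    raise TypeError(f"Unsupported transitions type: {type(transitions).__name__}")


def error_name(transitions: Any, batch_size: int):
    """Helper for the demo: name of the exception raised, or None."""
    try:
        make_batches_sketch(transitions, batch_size)
    except (ValueError, TypeError) as exc:
        return type(exc).__name__
    return None


flat = {"obs": [0, 1, 2, 3, 4], "acts": [1, 0, 1, 0, 1]}
batches = make_batches_sketch(flat, batch_size=2)  # two full batches; item 4 is dropped

too_few = error_name(flat, batch_size=8)  # fewer timesteps than batch_size
wrong_size = error_name([{"obs": [0], "acts": [1]}], batch_size=2)  # batch of size 1
bad_type = error_name(42, batch_size=2)  # not a supported container
```

Note that pre-batched input is never re-batched here, matching the statement that `batch_size` does not change the batch size of an existing iterable of transition batches.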