# Source code for imitation.algorithms.base

"""Module of base classes and helper methods for imitation learning algorithms."""

import abc
from typing import (
    Any,
    Generic,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    TypeVar,
    Union,
    cast,
)

import torch.utils.data as th_data
from stable_baselines3.common import policies

from imitation.data import rollout, types
from imitation.util import logger as imit_logger
from imitation.util import util


class BaseImitationAlgorithm(abc.ABC):
    """Common ancestor of every imitation learning algorithm.

    Bundles three pieces of shared behaviour: a logger, the fixed-horizon
    safety check, and pickling support that leaves the logger out of the
    serialized state.
    """

    _logger: imit_logger.HierarchicalLogger
    """Destination for statistics and natural language log messages."""

    allow_variable_horizon: bool
    """Whether variable horizon trajectories are tolerated; if False, they raise."""

    _horizon: Optional[int]
    """Trajectory horizon observed so far, or None before any trajectory is seen."""

    def __init__(
        self,
        *,
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
        allow_variable_horizon: bool = False,
    ):
        """Sets up the logger and the horizon bookkeeping.

        Args:
            custom_logger: Logger to use. When omitted (the default), a fresh
                one is created with `imit_logger.configure()`.
            allow_variable_horizon: With False (the default) the algorithm raises
                as soon as it sees trajectories of differing length during
                training. True disables that safety check. WARNING: variable
                horizon episodes leak information about the reward via
                termination condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html
                before overriding this.
        """
        # Deliberately a truthiness test (not `is None`): any falsy value
        # falls back to a newly configured logger.
        if custom_logger:
            self._logger = custom_logger
        else:
            self._logger = imit_logger.configure()
        self._horizon = None
        self.allow_variable_horizon = allow_variable_horizon

        if not allow_variable_horizon:
            return
        # The safety check is off: make sure the user knows the risks.
        self.logger.warn(
            "Running with `allow_variable_horizon` set to True. "
            "Some algorithms are biased towards shorter or longer "
            "episodes, which may significantly confound results. "
            "Additionally, even unbiased algorithms can exploit "
            "the information leak from the termination condition, "
            "producing spuriously high performance. See "
            "https://imitation.readthedocs.io/en/latest/getting-started/"
            "variable-horizon.html for more information.",
        )

    @property
    def logger(self) -> imit_logger.HierarchicalLogger:
        """The logger this algorithm currently writes to."""
        return self._logger

    @logger.setter
    def logger(self, value: imit_logger.HierarchicalLogger) -> None:
        self._logger = value

    def _check_fixed_horizon(self, horizons: Iterable[int]) -> None:
        """Verifies that all episode lengths agree, now and across earlier calls.

        Algorithms that are safe with variable horizon episodes (e.g. behavioral
        cloning) should simply never call this method.

        Args:
            horizons: Episode lengths to check; any iterable of ints.

        Raises:
            ValueError: The lengths in `horizons` differ from one another, or
                from the length remembered from previous calls to this method.
        """
        if self.allow_variable_horizon:
            # The user opted out of the check explicitly.
            return

        # Pool the new lengths with the one remembered from earlier calls.
        observed = {*horizons}
        if self._horizon is not None:
            observed.add(self._horizon)

        if not observed:
            # Nothing seen yet, so nothing to compare or remember.
            return
        if len(observed) == 1:
            (self._horizon,) = observed
            return

        raise ValueError(
            f"Episodes of different length detected: {observed}. "
            "Variable horizon environments are discouraged -- "
            "termination conditions leak information about reward. See "
            "https://imitation.readthedocs.io/en/latest/getting-started/"
            "variable-horizon.html for more information. "
            "If you are SURE you want to run imitation on a "
            "variable horizon task, then please pass in the flag: "
            "`allow_variable_horizon=True`.",
        )

    def __getstate__(self):
        """Returns a copy of the instance dict without the logger entry."""
        picklable = dict(self.__dict__)
        # The logger depends on open files and therefore cannot be pickled.
        picklable.pop("_logger")
        return picklable

    def __setstate__(self, state):
        """Restores attributes from `state` and attaches a logger.

        Uses the logger stored in `state` if there is a truthy one; otherwise
        configures a new logger. Callers wanting a different logger should
        assign to `self.logger` afterwards.
        """
        self.__dict__.update(state)
        stored_logger = state.get("_logger")
        if stored_logger:
            self.logger = stored_logger
        else:
            self.logger = imit_logger.configure()
# Type variable for transition containers; bounded so that it can only be bound
# to `types.TransitionsMinimal` or one of its subclasses. It is the generic
# parameter of `DemonstrationAlgorithm`.
TransitionKind = TypeVar("TransitionKind", bound=types.TransitionsMinimal)

# Every form in which demonstration data may be supplied: an iterable of
# trajectories, an iterable of transition batches (mappings), or a
# `types.TransitionsMinimal` object.
AnyTransitions = Union[
    Iterable[types.Trajectory],
    Iterable[types.TransitionMapping],
    types.TransitionsMinimal,
]
class DemonstrationAlgorithm(BaseImitationAlgorithm, Generic[TransitionKind]):
    """Base for algorithms that learn from expert demonstrations (BC, IRL, etc.)."""

    def __init__(
        self,
        *,
        demonstrations: Optional[AnyTransitions],
        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
        allow_variable_horizon: bool = False,
    ):
        """Initializes the base algorithm and registers any demonstrations given.

        Args:
            demonstrations: Expert demonstrations, or None to supply none at
                construction time. Accepted forms are a
                `types.TransitionsMinimal` object, a sequence of trajectories,
                or an iterable of transition batches (mappings from keywords to
                arrays containing observations, etc).
            custom_logger: Logger to use. When omitted (the default), a new
                logger is created.
            allow_variable_horizon: With False (the default) the algorithm raises
                as soon as it sees trajectories of differing length during
                training. True disables that safety check. WARNING: variable
                horizon episodes leak information about the reward via
                termination condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html
                before overriding this.
        """
        super().__init__(
            allow_variable_horizon=allow_variable_horizon,
            custom_logger=custom_logger,
        )
        if demonstrations is None:
            # Nothing to register now; `set_demonstrations` can be called later.
            return
        self.set_demonstrations(demonstrations)

    @abc.abstractmethod
    def set_demonstrations(self, demonstrations: AnyTransitions) -> None:
        """Replaces the demonstration data used by the algorithm.

        Being able to swap the data on demand is useful for interactive
        algorithms like DAgger.

        Args:
            demonstrations: One of: a Torch `DataLoader`; any other iterator
                yielding dictionaries that contain "obs" and "acts" Tensors or
                NumPy arrays; a `TransitionKind` instance; or a Sequence of
                Trajectory objects.
        """

    @property
    @abc.abstractmethod
    def policy(self) -> policies.BasePolicy:
        """The policy that imitates the demonstration data."""
class _WrappedDataLoader:
    """Batch iterable that validates the size of every batch passing through."""

    def __init__(
        self,
        data_loader: Iterable[types.TransitionMapping],
        expected_batch_size: int,
    ):
        """Stores the wrapped loader and the batch size to enforce.

        Args:
            data_loader: The batch iterable whose batches are passed through.
            expected_batch_size: Length every batch's "obs" and "acts" must have.
        """
        self.expected_batch_size = expected_batch_size
        self.data_loader = data_loader

    def __iter__(self) -> Iterator[types.TransitionMapping]:
        """Passes through batches of `self.data_loader` after a size check.

        Yields:
            The very same batches `self.data_loader` produces, unmodified.

        Raises:
            ValueError: A batch's "obs" or "acts" entry has a length other than
                `self.expected_batch_size`.
        """
        for batch in self.data_loader:
            wanted = self.expected_batch_size
            # "obs" is validated before "acts", so a batch that is wrong in
            # both is reported for "obs".
            for key in ("obs", "acts"):
                found = len(batch[key])
                if found != wanted:
                    raise ValueError(
                        f"Expected batch size {wanted} "
                        f"!= {found} = len(batch['{key}'])",
                    )
            yield batch
def make_data_loader(
    transitions: AnyTransitions,
    batch_size: int,
    data_loader_kwargs: Optional[Mapping[str, Any]] = None,
) -> Iterable[types.TransitionMapping]:
    """Turns demonstration data into an iterable of transition batches.

    Args:
        transitions: The demonstrations, given as a `types.TransitionsMinimal`
            object, a sequence of trajectories, or an iterable of transition
            batches (mappings from keywords to arrays containing observations,
            etc).
        batch_size: Size of the batches to produce. If `transitions` already is
            an iterable of transition batches, the batch size is not changed
            but only checked.
        data_loader_kwargs: Extra keyword arguments for `th_data.DataLoader`.
            They take precedence over the defaults `shuffle=True` and
            `drop_last=True`.

    Returns:
        An iterable of transition batches.

    Raises:
        ValueError: if `batch_size` is not positive; if `transitions` is an
            iterable over transition batches with batch size not equal to
            `batch_size`; or if `transitions` is transitions or a sequence of
            trajectories with total timesteps less than `batch_size`.
        TypeError: if `transitions` is an unsupported type.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size={batch_size} must be positive.")

    if isinstance(transitions, Iterable):
        # Look at the first element to tell trajectories from batches. The
        # helper also hands back the iterable to keep using from here on.
        # Inferring the correct type here is difficult with generics.
        peeked = util.get_first_iter_element(transitions)
        first_item, transitions = peeked  # type: ignore[assignment]
        if isinstance(first_item, types.Trajectory):
            trajectories = list(cast(Iterable[types.Trajectory], transitions))
            transitions = rollout.flatten_trajectories(trajectories)

    if isinstance(transitions, types.TransitionsMinimal):
        num_transitions = len(transitions)
        if num_transitions < batch_size:
            raise ValueError(
                f"Number of transitions in `demonstrations` {num_transitions} "
                f"is smaller than batch size {batch_size}.",
            )
        overrides = data_loader_kwargs or {}
        # Later entries win, so caller-supplied options replace the defaults.
        loader_kwargs: Mapping[str, Any] = {
            "shuffle": True,
            "drop_last": True,
            **overrides,
        }
        return th_data.DataLoader(
            transitions,
            batch_size=batch_size,
            collate_fn=types.transitions_collate_fn,
            **loader_kwargs,
        )

    if not isinstance(transitions, Iterable):
        raise TypeError(f"`demonstrations` unexpected type {type(transitions)}")

    # Trajectories were flattened above, so what remains is treated as an
    # iterable of transition batches; the arg-type error is safe to ignore.
    return _WrappedDataLoader(transitions, batch_size)  # type: ignore[arg-type]