# Source code for imitation.data.types

"""Types and helper methods for transitions and trajectories."""

import dataclasses
import os
import warnings
from typing import (
    Any,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import numpy as np
import torch as th
from torch.utils import data as th_data

# Generic type variable; `Pair` below uses it to build a 2-tuple of one type.
T = TypeVar("T")

# Alias for path-like arguments: a `str`, `bytes`, or `os.PathLike` object.
AnyPath = Union[str, bytes, os.PathLike]


def dataclass_quick_asdict(obj) -> Dict[str, Any]:
    """Shallowly convert a dataclass instance into a ``{field_name: value}`` dict.

    This is a cheap alternative to `dataclasses.asdict`. That function silently
    deep-copies every value, which is expensive for large numpy arrays.
    See https://stackoverflow.com/a/52229565/1091722.

    Args:
        obj: A dataclass instance.

    Returns:
        A new dictionary. Its keys are the dataclass field names of `obj` in
        declaration order. Its values are the corresponding attribute values
        (the same objects, not copies).
    """
    field_names = (field.name for field in dataclasses.fields(obj))
    return {name: getattr(obj, name) for name in field_names}
@dataclasses.dataclass(frozen=True)
class Trajectory:
    """A trajectory, e.g. a one episode rollout from an expert policy."""

    obs: np.ndarray
    """Observations, shape (trajectory_len + 1, ) + observation_shape."""

    acts: np.ndarray
    """Actions, shape (trajectory_len, ) + action_shape."""

    infos: Optional[np.ndarray]
    """An array of info dicts, shape (trajectory_len, ).

    The info dict is returned by some environments `step()` and contains auxiliary
    diagnostic information. For example the monitor wrapper adds an info dict
    to the last step of each episode containing the episode return and length.
    """

    terminal: bool
    """Does this trajectory (fragment) end in a terminal state?

    Episodes are always terminal. Trajectory fragments are also terminal when they
    contain the final state of an episode (even if missing the start of the episode).
    """

    def __len__(self) -> int:
        """Returns number of transitions, equal to the number of actions."""
        return len(self.acts)

    def __eq__(self, other) -> bool:
        """Compare two trajectories field by field.

        Field values are compared with `np.array_equal`. An `infos` of `None` is
        treated as equal to a sequence of empty dicts of the same length.

        Args:
            other: Object to compare against.

        Returns:
            True iff `other` is a `Trajectory` with the same length, the same set
            of dataclass fields, and equal values for every field. Objects that
            are not `Trajectory` instances compare unequal.
        """
        if not isinstance(other, Trajectory):
            return False
        # Short-circuit: trajectories of different length are never equal. The
        # per-field checks below would also catch this. Running it before any
        # field extraction keeps the common unequal case cheap.
        if len(self) != len(other):
            return False
        # Shallow extraction. `dataclasses.asdict` would deep-copy every field.
        # That is wasted work for values we only compare. It also breaks
        # equality for info values whose `__eq__` is identity-based, since the
        # copies are distinct objects.
        dict_self = dataclass_quick_asdict(self)
        dict_other = dataclass_quick_asdict(other)
        # Trajectory objects may still have different keys if different subclasses.
        if dict_self.keys() != dict_other.keys():
            return False
        for key, self_v in dict_self.items():
            other_v = dict_other[key]
            if key == "infos":
                # Treat None equivalent to sequence of empty dicts.
                self_v = [{}] * len(self) if self_v is None else self_v
                other_v = [{}] * len(other) if other_v is None else other_v
            if not np.array_equal(self_v, other_v):
                return False
        return True

    def __post_init__(self):
        """Validate that field lengths agree with the shapes documented above.

        Raises:
            ValueError: If `obs` does not have exactly one more entry than
                `acts`. Also raised if `infos` is given and its length differs
                from `acts`, or if the trajectory has no actions.
        """
        if len(self.obs) != len(self.acts) + 1:
            raise ValueError(
                "expected one more observations than actions: "
                f"{len(self.obs)} != {len(self.acts)} + 1",
            )
        if self.infos is not None and len(self.infos) != len(self.acts):
            raise ValueError(
                "infos when present must be present for each action: "
                f"{len(self.infos)} != {len(self.acts)}",
            )
        if len(self.acts) == 0:
            raise ValueError("Degenerate trajectory: must have at least one action.")

    def __setstate__(self, state):
        """Restore from pickled state, upgrading pickles that predate `terminal`.

        A state without a `terminal` entry is loaded with `terminal=True` and
        emits a `DeprecationWarning`.
        """
        if "terminal" not in state:
            warnings.warn(
                "Loading old version of Trajectory. "
                "Support for this will be removed in future versions.",
                DeprecationWarning,
            )
            state["terminal"] = True
        # Writing `__dict__` directly bypasses the frozen dataclass `__setattr__`.
        self.__dict__.update(state)
def _rews_validation(rews: np.ndarray, acts: np.ndarray): if rews.shape != (len(acts),): raise ValueError( "rewards must be 1D array, one entry for each action: " f"{rews.shape} != ({len(acts)},)", ) if not np.issubdtype(rews.dtype, np.floating): raise ValueError(f"rewards dtype {rews.dtype} not a float")
@dataclasses.dataclass(frozen=True, eq=False)
class TrajectoryWithRew(Trajectory):
    """A `Trajectory` whose every step also carries a scalar reward."""

    rews: np.ndarray
    """Per-step rewards: float array of shape (trajectory_len, )."""

    def __post_init__(self):
        """Run the base `Trajectory` checks, then validate `rews` against `acts`."""
        super().__post_init__()
        _rews_validation(self.rews, self.acts)
# A 2-tuple whose two elements share the same type `T`.
Pair = Tuple[T, T]
# A pair of `Trajectory` objects.
TrajectoryPair = Pair[Trajectory]
# A pair of `TrajectoryWithRew` objects (trajectories that include rewards).
TrajectoryWithRewPair = Pair[TrajectoryWithRew]
def transitions_collate_fn(
    batch: Sequence[Mapping[str, np.ndarray]],
) -> Mapping[str, Union[np.ndarray, th.Tensor]]:
    """Collate a list of `TransitionsMinimal` samples into a single batch.

    Pass this as the `collate_fn` argument of a `torch.utils.data.DataLoader`
    whose `dataset` is a `TransitionsMinimal` instance.

    Every key except ``"infos"`` is wrapped in `np.array` and then batched by
    Torch's default collate function. ``"infos"`` is instead gathered into a
    plain list holding one info dict per sample. The default collate function
    would recursively merge all info dicts into a single dict, which is
    incorrect here.

    Args:
        batch: The samples to collate, each a mapping from field name to value.

    Returns:
        A dict mapping each non-info field name to its collated batch. It also
        maps ``"infos"`` to the list of per-sample info dicts.
    """

    def _strip_infos(sample: Mapping[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Drop "infos" so default_collate never tries to merge the info dicts.
        return {key: np.array(value) for key, value in sample.items() if key != "infos"}

    collated = th_data.dataloader.default_collate([_strip_infos(s) for s in batch])
    assert isinstance(collated, dict)
    collated["infos"] = [s["infos"] for s in batch]
    return collated
# Self-type for `TransitionsMinimal` methods that return an instance of the
# caller's own (sub)class, e.g. the slice overload of `__getitem__`.
TransitionsMinimalSelf = TypeVar("TransitionsMinimalSelf", bound="TransitionsMinimal")
@dataclasses.dataclass(frozen=True)
class TransitionsMinimal(th_data.Dataset, Sequence[Mapping[str, np.ndarray]]):
    """A Torch-compatible `Dataset` of obs-act transitions.

    This class and its subclasses are usually instantiated via
    `imitation.data.rollout.flatten_trajectories`.

    Indexing an instance `trans` of TransitionsMinimal with an integer `i`
    returns the `i`th `Dict[str, np.ndarray]` sample. Its keys are the field
    names of each dataclass field, and its values are the ith elements of each
    field value. Both Python `int` and NumPy integer scalars are accepted as
    indices.

    Slicing returns a possibly empty instance of `TransitionsMinimal` where each
    field has been sliced.
    """

    obs: np.ndarray
    """
    Previous observations. Shape: (batch_size, ) + observation_shape.

    The i'th observation `obs[i]` in this array is the observation seen
    by the agent when choosing action `acts[i]`. `obs[i]` is not required to
    be from the timestep preceding `obs[i+1]`.
    """

    acts: np.ndarray
    """Actions. Shape: (batch_size,) + action_shape."""

    infos: np.ndarray
    """Array of info dicts. Shape: (batch_size,)."""

    def __len__(self) -> int:
        """Returns number of transitions. Non-negative; zero for an empty slice."""
        return len(self.obs)

    def __post_init__(self):
        """Performs input validation: check batch sizes agree across fields.

        Also makes every `np.ndarray` field value read-only. This is done in
        place on the stored arrays, so the arrays the caller passed in become
        read-only too.

        Raises:
            ValueError: if batch size (array length) is inconsistent
                between `obs`, `acts` and `infos`.
        """
        for val in vars(self).values():
            if isinstance(val, np.ndarray):
                val.setflags(write=False)
        if len(self.obs) != len(self.acts):
            raise ValueError(
                "obs and acts must have same number of timesteps: "
                f"{len(self.obs)} != {len(self.acts)}",
            )
        if len(self.infos) != len(self.obs):
            raise ValueError(
                "obs and infos must have same number of timesteps: "
                f"{len(self.obs)} != {len(self.infos)}",
            )

    @overload
    def __getitem__(self, key: int) -> Mapping[str, np.ndarray]:
        pass  # pragma: no cover

    @overload
    def __getitem__(self: TransitionsMinimalSelf, key: slice) -> TransitionsMinimalSelf:
        pass  # pragma: no cover

    def __getitem__(self, key):
        """Index or slice every field at once.

        Args:
            key: A slice, or an integer index (Python `int` or NumPy integer
                scalar).

        Returns:
            For a slice, a new instance of the same (sub)class with every field
            sliced (possibly empty). For an integer, a dict mapping each field
            name to that field's `key`-th element, without a batch dimension.

        Raises:
            TypeError: if `key` is neither a slice nor an integer.
            IndexError: if an integer `key` is out of range.
        """
        # Validate the key up front with a real exception rather than an
        # `assert`. Asserts are stripped under `python -O`, which would let
        # list/array keys silently return fancy-indexed dicts.
        if not isinstance(key, (slice, int, np.integer)):
            raise TypeError(
                f"{type(self).__name__} indices must be integers or slices, "
                f"not {type(key).__name__}",
            )
        d = dataclass_quick_asdict(self)
        d_item = {k: v[key] for k, v in d.items()}
        if isinstance(key, slice):
            # Return type is the same as this dataclass. Replace field value with
            # slices.
            return dataclasses.replace(self, **d_item)
        # Return type is a dictionary. Array values have no batch dimension.
        #
        # Dictionary of np.ndarray values is a convenient
        # torch.util.data.Dataset return type, as a torch.util.data.DataLoader
        # taking in this `Dataset` as its first argument knows how to
        # automatically concatenate several dictionaries together to make
        # a single dictionary batch with `torch.Tensor` values.
        return d_item
@dataclasses.dataclass(frozen=True)
class Transitions(TransitionsMinimal):
    """A batch of obs-act-obs-done transitions."""

    next_obs: np.ndarray
    """New observation. Shape: (batch_size, ) + observation_shape.

    `next_obs[i]` is the observation reached after the agent took action
    `acts[i]` upon seeing `obs[i]`.

    Invariants:

        * `next_obs.dtype == obs.dtype`
        * `len(next_obs) == len(obs)`
    """

    dones: np.ndarray
    """
    Boolean array indicating episode termination. Shape: (batch_size, ).

    `dones[i]` is true iff `next_obs[i]` is the last observation of an episode.
    """

    def __post_init__(self):
        """Validate `next_obs` and `dones` after the base-class checks.

        Raises:
            ValueError: if `next_obs` differs from `obs` in shape or dtype, if
                `dones` does not have shape `(len(acts),)`, or if `dones` is
                not boolean. Checks run in that order.
        """
        super().__post_init__()
        obs, next_obs, dones = self.obs, self.next_obs, self.dones
        if obs.shape != next_obs.shape:
            raise ValueError(
                "obs and next_obs must have same shape: "
                f"{obs.shape} != {next_obs.shape}",
            )
        if obs.dtype != next_obs.dtype:
            raise ValueError(
                "obs and next_obs must have the same dtype: "
                f"{obs.dtype} != {next_obs.dtype}",
            )
        n_acts = len(self.acts)
        if dones.shape != (n_acts,):
            raise ValueError(
                "dones must be 1D array, one entry for each timestep: "
                f"{dones.shape} != ({n_acts},)",
            )
        if dones.dtype != bool:
            raise ValueError(f"dones must be boolean, not {dones.dtype}")
@dataclasses.dataclass(frozen=True)
class TransitionsWithRew(Transitions):
    """A batch of obs-act-obs-rew-done transitions."""

    rews: np.ndarray
    """
    Reward. Shape: (batch_size, ). dtype float.

    `rews[i]` is the reward received after the agent took action `acts[i]`.
    """

    def __post_init__(self):
        """Run the `Transitions` checks, then validate `rews` against `acts`."""
        super().__post_init__()
        _rews_validation(self.rews, self.acts)