imitation.data.types#
Types and helper methods for transitions and trajectories.
Functions
assert_not_dictobs(x) – Typeguard to assert x is an array, not a DictObs.
concatenate_maybe_dictobs(arrs) – Concatenates a list of observations appropriately (depending on type).
dataclass_quick_asdict(obj) – Extract dataclass to items using dataclasses.fields + dict comprehension.
map_maybe_dict(fn, maybe_dict) – Either maps fn over dictionary values or applies fn to maybe_dict.
maybe_unwrap_dictobs(maybe_dictobs) – Unwraps if a DictObs, otherwise returns the object.
maybe_wrap_in_dictobs(obs) – Converts an observation into a DictObs, if necessary.
stack_maybe_dictobs(arrs) – Stacks a list of observations appropriately (depending on type).
transitions_collate_fn(batch) – Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.
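Classes
DictObs(_d) – Stores observations from an environment with a dictionary observation space.
Trajectory(obs, acts, infos, terminal) – A trajectory, e.g. a one-episode rollout from an expert policy.
TrajectoryWithRew(obs, acts, infos, terminal, rews) – A Trajectory that additionally includes reward information.
TransitionMapping(*args, **kwargs) – Dictionary with obs and acts, maybe also next_obs, dones, rew.
TransitionMappingNoNextObs(*args, **kwargs) – Dictionary with obs and acts.
Transitions(obs, acts, infos, next_obs, dones) – A batch of obs-act-obs-done transitions.
TransitionsMinimal(obs, acts, infos) – A Torch-compatible Dataset of obs-act transitions.
TransitionsWithRew(obs, acts, infos, next_obs, dones, rews) – A batch of obs-act-obs-rew-done transitions.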
- class imitation.data.types.DictObs(_d)[source]#
Bases: object
Stores observations from an environment with a dictionary observation space.
Provides an interface that is similar to observations in a numpy array. Length, slicing, indexing, and iterating operations will operate on the first dimension of the constituent arrays, as they would for observations in a single array.
There are also utility functions for mapping, stacking, and concatenating lists of DictObs.
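The first-dimension semantics described above can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: the class name `MiniDictObs` and the example keys `image` and `pos` are hypothetical, and only length, indexing, slicing, `dict_len`, and `shape` are modelled.

```python
from typing import Dict, Tuple, Union

import numpy as np


class MiniDictObs:
    """Hypothetical stand-in for DictObs (illustration only)."""

    def __init__(self, d: Dict[str, np.ndarray]):
        self._d = d

    def __len__(self) -> int:
        # Length is the size of the first dimension, which every
        # constituent array is assumed to share.
        lengths = {len(v) for v in self._d.values()}
        if len(lengths) != 1:
            raise ValueError("arrays must share their first dimension")
        return lengths.pop()

    def __getitem__(self, key: Union[int, slice]) -> "MiniDictObs":
        # Indexing and slicing are applied to each array, key by key.
        return MiniDictObs({k: v[key] for k, v in self._d.items()})

    @property
    def dict_len(self) -> int:
        # Number of arrays, as opposed to the length of each array.
        return len(self._d)

    @property
    def shape(self) -> Dict[str, Tuple[int, ...]]:
        return {k: v.shape for k, v in self._d.items()}


obs = MiniDictObs({"image": np.zeros((5, 4, 4)), "pos": np.zeros((5, 2))})
first_two = obs[:2]  # slices the first dimension of every array
single = obs[0]      # drops the first dimension of every array
```

Note the difference between `len(obs)`, which counts entries along the first dimension, and `obs.dict_len`, which counts the arrays.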
- __init__(_d)#
- classmethod concatenate(dictobs_list, axis=0)[source]#
Returns a single DictObs that concatenates the arrays by key.
- Return type
DictObs
- property dict_len#
Returns the number of arrays in the DictObs.
- property dtype: Dict[str, dtype]#
Returns a dictionary with dtypes in place of the arrays.
- Return type
Dict[str,dtype]
- classmethod from_obs_list(obs_list)[source]#
Stacks the observation list into a single DictObs.
- Return type
DictObs
- property shape: Dict[str, Tuple[int, ...]]#
Returns a dictionary with shape-tuples in place of the arrays.
- Return type
Dict[str,Tuple[int,...]]
- classmethod stack(dictobs_list, axis=0)[source]#
Returns a single DictObs that stacks the arrays by key.
- Return type
DictObs
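The difference between stacking and concatenating by key can be sketched on plain dictionaries of arrays. The helper names `stack_by_key` and `concatenate_by_key` are hypothetical and operate on `dict` objects rather than on DictObs instances; they only illustrate the per-key behaviour described for the two classmethods.

```python
from typing import Dict, List

import numpy as np

ArrayDict = Dict[str, np.ndarray]


def stack_by_key(dicts: List[ArrayDict], axis: int = 0) -> ArrayDict:
    # Stacking adds a new axis: N arrays of shape S become one of shape (N,) + S.
    return {k: np.stack([d[k] for d in dicts], axis=axis) for k in dicts[0]}


def concatenate_by_key(dicts: List[ArrayDict], axis: int = 0) -> ArrayDict:
    # Concatenating joins along an existing axis, so no new axis appears.
    return {k: np.concatenate([d[k] for d in dicts], axis=axis) for k in dicts[0]}


a = {"pos": np.zeros((3, 2))}
b = {"pos": np.ones((3, 2))}
stacked = stack_by_key([a, b])       # "pos" has shape (2, 3, 2)
joined = concatenate_by_key([a, b])  # "pos" has shape (6, 2)
```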
- class imitation.data.types.Trajectory(obs, acts, infos, terminal)[source]#
Bases: object
A trajectory, e.g. a one-episode rollout from an expert policy.
- __init__(obs, acts, infos, terminal)#
- acts: ndarray#
Actions, shape (trajectory_len, ) + action_shape.
- infos: Optional[ndarray]#
An array of info dicts, shape (trajectory_len, ).
The info dict is returned by some environments' step() and contains auxiliary diagnostic information. For example, the monitor wrapper adds an info dict to the last step of each episode containing the episode return and length.
- terminal: bool#
Does this trajectory (fragment) end in a terminal state?
Episodes are always terminal. Trajectory fragments are also terminal when they contain the final state of an episode (even if missing the start of the episode).
- class imitation.data.types.TrajectoryWithRew(obs, acts, infos, terminal, rews)[source]#
Bases: Trajectory
A Trajectory that additionally includes reward information.
- __init__(obs, acts, infos, terminal, rews)#
- rews: ndarray#
Reward, shape (trajectory_len, ). dtype float.
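The documented shapes can be made concrete with stand-in dataclasses. This is a sketch under stated assumptions: `TrajSketch` and `TrajWithRewSketch` are hypothetical names, the validation in `__post_init__` is illustrative rather than a copy of the library's checks, and the example assumes that `obs` holds one more entry than `acts` (the initial observation plus one per step), which this page does not state.

```python
import dataclasses
from typing import Optional

import numpy as np


@dataclasses.dataclass(frozen=True)
class TrajSketch:
    """Hypothetical stand-in for Trajectory."""

    obs: np.ndarray              # assumed: one more entry than acts
    acts: np.ndarray             # shape (trajectory_len,) + action_shape
    infos: Optional[np.ndarray]  # shape (trajectory_len,) or None
    terminal: bool

    def __len__(self) -> int:
        # The trajectory length is taken to be the number of actions.
        return len(self.acts)


@dataclasses.dataclass(frozen=True)
class TrajWithRewSketch(TrajSketch):
    """Hypothetical stand-in for TrajectoryWithRew."""

    rews: np.ndarray  # shape (trajectory_len,), dtype float

    def __post_init__(self) -> None:
        if self.rews.shape != (len(self.acts),):
            raise ValueError("rews must have shape (trajectory_len,)")
        if not np.issubdtype(self.rews.dtype, np.floating):
            raise ValueError("rews must have a float dtype")


traj = TrajWithRewSketch(
    obs=np.zeros((4, 2)),
    acts=np.zeros(3, dtype=int),
    infos=None,
    terminal=True,
    rews=np.array([0.0, 0.5, 1.0]),
)
```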
- class imitation.data.types.TransitionMapping(*args, **kwargs)[source]#
Bases: dict
Dictionary with obs and acts, maybe also next_obs, dones, rew.
- acts: Union[ndarray, Tensor]#
- dones: Union[ndarray, Tensor]#
- rew: Union[ndarray, Tensor]#
- class imitation.data.types.TransitionMappingNoNextObs(*args, **kwargs)[source]#
Bases: dict
Dictionary with obs and acts.
- acts: Union[ndarray, Tensor]#
- class imitation.data.types.Transitions(obs, acts, infos, next_obs, dones)[source]#
Bases: TransitionsMinimal
A batch of obs-act-obs-done transitions.
- __init__(obs, acts, infos, next_obs, dones)#
- dones: ndarray#
Boolean array indicating episode termination. Shape: (batch_size, ).
dones[i] is true iff next_obs[i] is the last observation of an episode.
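How `dones` lines up with `next_obs` can be sketched by splitting a single trajectory's observations into transitions. The function `flatten_sketch` is hypothetical, not the library's `flatten_trajectories`, and it assumes the trajectory stores one more observation than it has steps and contains at least one step.

```python
from typing import Tuple

import numpy as np


def flatten_sketch(
    obs: np.ndarray, terminal: bool
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split trajectory observations into (obs, next_obs, dones)."""
    n_steps = len(obs) - 1
    dones = np.zeros(n_steps, dtype=bool)
    # Only the final transition can end the episode, and only if the
    # trajectory (fragment) is terminal.
    dones[-1] = terminal
    return obs[:-1], obs[1:], dones


cur, nxt, dones = flatten_sketch(np.arange(4), terminal=True)
```

Here `dones[i]` is true exactly where `nxt[i]` is the last observation of the episode.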
- class imitation.data.types.TransitionsMinimal(obs, acts, infos)[source]#
Bases: Dataset, Sequence[Mapping[str,ndarray]]
A Torch-compatible Dataset of obs-act transitions.
This class and its subclasses are usually instantiated via imitation.data.rollout.flatten_trajectories.
Indexing an instance trans of TransitionsMinimal with an integer i returns the i'th Dict[str, np.ndarray] sample, whose keys are the field names of each dataclass field and whose values are the i'th elements of each field value.
Slicing returns a possibly empty instance of TransitionsMinimal where each field has been sliced.
- __init__(obs, acts, infos)#
- acts: ndarray#
Actions. Shape: (batch_size,) + action_shape.
- infos: ndarray#
Array of info dicts. Shape: (batch_size,).
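The indexing and slicing behaviour described for TransitionsMinimal can be sketched with a small dataclass. `MiniTransitions` is a hypothetical stand-in that does not subclass the Torch Dataset; it only reproduces the rule that an integer index yields a dict keyed by field name, while a slice yields a new (possibly empty) instance.

```python
import dataclasses
from typing import Any, Dict, Union

import numpy as np


@dataclasses.dataclass(frozen=True)
class MiniTransitions:
    """Hypothetical stand-in for TransitionsMinimal."""

    obs: np.ndarray
    acts: np.ndarray
    infos: np.ndarray

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(
        self, key: Union[int, slice]
    ) -> Union[Dict[str, Any], "MiniTransitions"]:
        # Apply the same index to every dataclass field.
        items = {f.name: getattr(self, f.name)[key] for f in dataclasses.fields(self)}
        if isinstance(key, slice):
            return MiniTransitions(**items)  # slicing keeps the container type
        return items                         # integer indexing yields one sample


trans = MiniTransitions(
    obs=np.arange(6).reshape(3, 2),
    acts=np.array([0, 1, 0]),
    infos=np.array([{}, {}, {"t": 2}], dtype=object),
)
sample = trans[2]   # dict with keys "obs", "acts", "infos"
empty = trans[3:]   # an instance with zero transitions
```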
- class imitation.data.types.TransitionsWithRew(obs, acts, infos, next_obs, dones, rews)[source]#
Bases: Transitions
A batch of obs-act-obs-rew-done transitions.
- __init__(obs, acts, infos, next_obs, dones, rews)#
- rews: ndarray#
Reward. Shape: (batch_size, ). dtype float.
The reward rews[i] at the i’th timestep is received after the agent has taken action acts[i].
- imitation.data.types.assert_not_dictobs(x)[source]#
Typeguard to assert x is an array, not a DictObs.
- Return type
ndarray
- imitation.data.types.concatenate_maybe_dictobs(arrs)[source]#
Concatenates a list of observations appropriately (depending on type).
- Return type
TypeVar(ObsVar,ndarray,DictObs)
- imitation.data.types.dataclass_quick_asdict(obj)[source]#
Extract dataclass to items using dataclasses.fields + dict comprehension.
This is a quick alternative to dataclasses.asdict, which expensively and undocumentedly deep-copies every numpy array value. See https://stackoverflow.com/a/52229565/1091722.
This is also used to preserve DictObs objects, as dataclasses.asdict unwraps them recursively.
- Parameters
obj – A dataclass instance.
- Return type
Dict[str,Any]
- Returns
A dictionary mapping from obj field names to values.
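The technique named in the summary, dataclasses.fields plus a dict comprehension, can be sketched as follows. The function name `quick_asdict_sketch` and the `Batch` dataclass are hypothetical, and a plain list stands in for a numpy array so the example needs only the standard library.

```python
import dataclasses
from typing import Any, Dict, List


def quick_asdict_sketch(obj: Any) -> Dict[str, Any]:
    # One shallow lookup per field: values are returned as-is, not copied.
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


@dataclasses.dataclass
class Batch:
    values: List[float]
    label: str


batch = Batch(values=[1.0, 2.0], label="demo")
shallow = quick_asdict_sketch(batch)  # shares the original list object
deep = dataclasses.asdict(batch)      # recursively copies the list
```

The shallow version returns the very same `values` object, whereas `dataclasses.asdict` builds a new one, which is where the copying cost comes from.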
- imitation.data.types.map_maybe_dict(fn, maybe_dict)[source]#
Either maps fn over dictionary values or applies fn to maybe_dict.
- Parameters
fn – function to apply. Must take a single argument.
maybe_dict – either a dict or a value that can be passed to fn.
- Returns
Either a dict (if maybe_dict was a dict) or fn(maybe_dict).
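The two cases can be sketched in a few lines. The name `map_maybe_dict_sketch` is hypothetical and the body is an illustration of the documented behaviour, not a copy of the library source.

```python
from typing import Any, Callable


def map_maybe_dict_sketch(fn: Callable[[Any], Any], maybe_dict: Any) -> Any:
    if isinstance(maybe_dict, dict):
        # Dictionary input: apply fn to each value and keep the keys.
        return {k: fn(v) for k, v in maybe_dict.items()}
    # Any other input: apply fn to the object itself.
    return fn(maybe_dict)


doubled_dict = map_maybe_dict_sketch(lambda v: v * 2, {"a": 1, "b": 2})
doubled_value = map_maybe_dict_sketch(lambda v: v * 2, 3)
```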
- imitation.data.types.maybe_unwrap_dictobs(maybe_dictobs: DictObs) → Dict[str, ndarray][source]#
- imitation.data.types.maybe_unwrap_dictobs(maybe_dictobs: T) → T
Unwraps if a DictObs, otherwise returns the object.
- imitation.data.types.maybe_wrap_in_dictobs(obs: Union[Dict[str, ndarray], DictObs]) → DictObs[source]#
- imitation.data.types.maybe_wrap_in_dictobs(obs: ndarray) → ndarray
Converts an observation into a DictObs, if necessary.
- Return type
Union[ndarray,DictObs]
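The overloads of the wrap and unwrap helpers amount to a type dispatch, which can be sketched with a stand-in wrapper. The names `Wrapper`, `wrap_sketch`, and `unwrap_sketch` are hypothetical; `Wrapper` plays the role of DictObs here and is not the library class.

```python
from typing import Any, Dict

import numpy as np


class Wrapper:
    """Hypothetical stand-in for DictObs."""

    def __init__(self, d: Dict[str, np.ndarray]):
        self._d = d


def wrap_sketch(obs: Any) -> Any:
    # A plain dict is wrapped; arrays and already-wrapped objects pass through.
    if isinstance(obs, dict):
        return Wrapper(obs)
    return obs


def unwrap_sketch(maybe_wrapped: Any) -> Any:
    # A wrapper yields its dict; anything else is returned unchanged.
    if isinstance(maybe_wrapped, Wrapper):
        return maybe_wrapped._d
    return maybe_wrapped


arr = np.zeros(3)
wrapped = wrap_sketch({"pos": arr})
round_trip = unwrap_sketch(wrapped)
```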
- imitation.data.types.stack_maybe_dictobs(arrs)[source]#
Stacks a list of observations appropriately (depending on type).
- Return type
TypeVar(ObsVar,ndarray,DictObs)
- imitation.data.types.transitions_collate_fn(batch)[source]#
Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.
Use this as the collate_fn argument to DataLoader if using an instance of TransitionsMinimal as the dataset argument.
- Parameters
batch (Sequence[Mapping[str,ndarray]]) – The batch to collate.
- Return type
Mapping[str,Union[ndarray,Tensor]]
- Returns
A collated batch. Uses Torch’s default collate function for everything except the “infos” key. For “infos”, we join all the info dicts into a list of dicts. (The default behavior would recursively collate every info dict into a single dict, which is incorrect.)
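The special handling of the "infos" key can be sketched without Torch. In this illustration the hypothetical function `collate_sketch` uses `np.stack` as a stand-in for Torch's default collate function, so it returns arrays where the real function would return tensors; only the treatment of "infos" is the point.

```python
from typing import Any, Dict, Mapping, Sequence

import numpy as np


def collate_sketch(batch: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    # Keep the info dicts as a plain list, one entry per sample.
    infos = [sample["infos"] for sample in batch]
    # Stack every other key along a new leading batch axis
    # (np.stack is an assumed stand-in for the default collate).
    collated: Dict[str, Any] = {
        k: np.stack([sample[k] for sample in batch])
        for k in batch[0]
        if k != "infos"
    }
    collated["infos"] = infos
    return collated


samples = [
    {"obs": np.array([0.0, 1.0]), "acts": np.array(1), "infos": {"step": 0}},
    {"obs": np.array([2.0, 3.0]), "acts": np.array(0), "infos": {}},
]
out = collate_sketch(samples)
```

Keeping "infos" as a list lets samples carry info dicts with different keys, which a recursive per-key collate could not handle.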