imitation.data.types

Types and helper methods for transitions and trajectories.

Functions

assert_not_dictobs(x)

Typeguard to assert x is an array, not a DictObs.

concatenate_maybe_dictobs(arrs)

Concatenates a list of observations appropriately (depending on type).

dataclass_quick_asdict(obj)

Extract dataclass to items using dataclasses.fields + dict comprehension.

map_maybe_dict(fn, maybe_dict)

Either maps fn over dictionary values or applies fn to maybe_dict.

maybe_unwrap_dictobs(maybe_dictobs)

Unwraps if a DictObs, otherwise returns the object.

maybe_wrap_in_dictobs(obs)

Converts an observation into a DictObs, if necessary.

stack_maybe_dictobs(arrs)

Stacks a list of observations appropriately (depending on type).

transitions_collate_fn(batch)

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

Classes

DictObs(_d)

Stores observations from an environment with a dictionary observation space.

Trajectory(obs, acts, infos, terminal)

A trajectory, e.g. a one-episode rollout from an expert policy.

TrajectoryWithRew(obs, acts, infos, ...)

A Trajectory that additionally includes reward information.

TransitionMapping(*args, **kwargs)

Dictionary with obs and acts, maybe also next_obs, dones, rew.

TransitionMappingNoNextObs(*args, **kwargs)

Dictionary with obs and acts.

Transitions(obs, acts, infos, next_obs, dones)

A batch of obs-act-obs-done transitions.

TransitionsMinimal(obs, acts, infos)

A Torch-compatible Dataset of obs-act transitions.

TransitionsWithRew(obs, acts, infos, ...)

A batch of obs-act-obs-rew-done transitions.

class imitation.data.types.DictObs(_d)

Bases: object

Stores observations from an environment with a dictionary observation space.

Provides an interface that is similar to observations in a numpy array. Length, slicing, indexing, and iterating operations will operate on the first dimension of the constituent arrays, as they would for observations in a single array.

There are also utility functions for mapping / stacking / concatenating lists of dictobs.
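A minimal sketch of the array-like behavior described above, assuming only what this page states: length, slicing, and indexing act on the first dimension of every constituent array, and map_arrays and shape work key-by-key. MiniDictObs is a hypothetical name; this is not the library's DictObs implementation.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union

import numpy as np


@dataclass(frozen=True)
class MiniDictObs:
    """Hypothetical stand-in for DictObs: a dict of arrays sharing axis 0."""

    _d: Dict[str, np.ndarray]

    def __len__(self) -> int:
        # Length is the size of the first dimension shared by all arrays.
        lengths = {len(v) for v in self._d.values()}
        if len(lengths) != 1:
            raise ValueError("need at least one array, all sharing their first dimension")
        return lengths.pop()

    def __getitem__(self, key: Union[int, slice]) -> "MiniDictObs":
        # Indexing and slicing are applied to every array along axis 0.
        return MiniDictObs({k: v[key] for k, v in self._d.items()})

    def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "MiniDictObs":
        # New wrapper with fn applied to each array; keys are unchanged.
        return MiniDictObs({k: fn(v) for k, v in self._d.items()})

    @property
    def shape(self) -> Dict[str, Tuple[int, ...]]:
        return {k: v.shape for k, v in self._d.items()}


obs = MiniDictObs({"pos": np.zeros((5, 3)), "img": np.zeros((5, 8, 8))})
print(len(obs))        # 5
print(obs[1:3].shape)  # {'pos': (2, 3), 'img': (2, 8, 8)}
print(obs[0].shape)    # {'pos': (3,), 'img': (8, 8)}
```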

__init__(_d)
classmethod concatenate(dictobs_list, axis=0)

Returns a single dictobs concatenating the arrays by key.

Return type

DictObs

property dict_len

Returns the number of arrays in the DictObs.

property dtype: Dict[str, dtype]

Returns a dictionary with dtypes in place of the arrays.

Return type

Dict[str, dtype]

classmethod from_obs_list(obs_list)

Stacks the observation list into a single DictObs.

Return type

DictObs

get(key)

Returns the array for the given key, or raises KeyError.

Return type

ndarray

items()
keys()
map_arrays(fn)

Returns a new DictObs with fn applied to every array.

Return type

DictObs

property shape: Dict[str, Tuple[int, ...]]

Returns a dictionary with shape-tuples in place of the arrays.

Return type

Dict[str, Tuple[int, ...]]

classmethod stack(dictobs_list, axis=0)

Returns a single dictobs stacking the arrays by key.

Return type

DictObs

unwrap()

Returns a copy of the underlying dictionary (arrays are not copied).

Return type

Dict[str, ndarray]

values()
class imitation.data.types.Trajectory(obs, acts, infos, terminal)

Bases: object

A trajectory, e.g. a one episode rollout from an expert policy.

__init__(obs, acts, infos, terminal)
acts: ndarray

Actions, shape (trajectory_len, ) + action_shape.

infos: Optional[ndarray]

An array of info dicts, shape (trajectory_len, ).

The info dict is returned by some environments' step() and contains auxiliary diagnostic information. For example, the monitor wrapper adds an info dict to the last step of each episode containing the episode return and length.

obs: Union[ndarray, DictObs]

Observations, shape (trajectory_len + 1, ) + observation_shape.

terminal: bool

Does this trajectory (fragment) end in a terminal state?

Episodes are always terminal. Trajectory fragments are also terminal when they contain the final state of an episode (even if missing the start of the episode).
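The obs array holds one more entry than acts because the observation after the final action is also recorded. A minimal sketch of that shape contract, using a hypothetical MiniTrajectory dataclass; the validation in __post_init__ and the __len__ definition are assumptions of this sketch, not a description of the real Trajectory class.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class MiniTrajectory:
    """Hypothetical sketch of the Trajectory shape contract."""

    obs: np.ndarray              # (trajectory_len + 1, ) + observation_shape
    acts: np.ndarray             # (trajectory_len, ) + action_shape
    infos: Optional[np.ndarray]  # (trajectory_len, ), or None
    terminal: bool

    def __post_init__(self) -> None:
        # One more observation than actions: the last obs follows the last action.
        if len(self.obs) != len(self.acts) + 1:
            raise ValueError(
                f"expected {len(self.acts) + 1} observations, got {len(self.obs)}"
            )
        if self.infos is not None and len(self.infos) != len(self.acts):
            raise ValueError("infos needs exactly one entry per action")

    def __len__(self) -> int:
        # In this sketch, length counts actions (transitions), not observations.
        return len(self.acts)


traj = MiniTrajectory(
    obs=np.zeros((4, 2)),
    acts=np.array([0, 1, 1]),
    infos=np.array([{}, {}, {}]),
    terminal=True,
)
print(len(traj))  # 3
```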

class imitation.data.types.TrajectoryWithRew(obs, acts, infos, terminal, rews)

Bases: Trajectory

A Trajectory that additionally includes reward information.

__init__(obs, acts, infos, terminal, rews)
rews: ndarray

Reward, shape (trajectory_len, ). dtype float.

class imitation.data.types.TransitionMapping(*args, **kwargs)

Bases: dict

Dictionary with obs and acts, maybe also next_obs, dones, rew.

acts: Union[ndarray, Tensor]
dones: Union[ndarray, Tensor]
next_obs: Union[ndarray, DictObs, Tensor]
obs: Union[ndarray, DictObs, Tensor]
rew: Union[ndarray, Tensor]
class imitation.data.types.TransitionMappingNoNextObs(*args, **kwargs)

Bases: dict

Dictionary with obs and acts.

acts: Union[ndarray, Tensor]
obs: Union[ndarray, DictObs, Tensor]
class imitation.data.types.Transitions(obs, acts, infos, next_obs, dones)

Bases: TransitionsMinimal

A batch of obs-act-obs-done transitions.

__init__(obs, acts, infos, next_obs, dones)
dones: ndarray

Boolean array indicating episode termination. Shape (batch_size, ).

dones[i] is true iff next_obs[i] is the last observation of an episode.

next_obs: Union[ndarray, DictObs]

New observation. Shape (batch_size, ) + observation_shape.

The i’th observation next_obs[i] in this array is the observation after the agent has taken action acts[i].

Invariants:
  • next_obs.dtype == obs.dtype

  • len(next_obs) == len(obs)
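Taken together, the dones and next_obs definitions describe how one trajectory lines up with a batch of transitions: obs[i] and next_obs[i] bracket acts[i], and only the last transition can be marked done. The sketch below shows that alignment with a hypothetical flatten_one helper; it is an illustration under those assumptions, not the library's flatten_trajectories.

```python
import numpy as np


def flatten_one(obs: np.ndarray, acts: np.ndarray, terminal: bool):
    """Hypothetical sketch: split one trajectory into obs/acts/next_obs/dones.

    Expects obs of length T + 1 and acts of length T, as in Trajectory.
    """
    prev_obs = obs[:-1]  # obs[i] is seen when choosing acts[i]
    next_obs = obs[1:]   # next_obs[i] is observed after taking acts[i]
    dones = np.zeros(len(acts), dtype=bool)
    if len(acts) > 0:
        # Only the final transition can end an episode, and only if the
        # trajectory (fragment) reaches a terminal state.
        dones[-1] = terminal
    return prev_obs, acts, next_obs, dones


obs = np.arange(5, dtype=np.float32).reshape(5, 1)  # T = 4 steps
acts = np.array([0, 1, 0, 1])
o, a, n, d = flatten_one(obs, acts, terminal=True)

# The documented invariants hold by construction.
assert n.dtype == o.dtype and len(n) == len(o)
print(d.tolist())  # [False, False, False, True]
```

Because prev_obs and next_obs are slices of the same array, both invariants hold automatically in this sketch.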

class imitation.data.types.TransitionsMinimal(obs, acts, infos)

Bases: Dataset, Sequence[Mapping[str, ndarray]]

A Torch-compatible Dataset of obs-act transitions.

This class and its subclasses are usually instantiated via imitation.data.rollout.flatten_trajectories.

Indexing an instance trans of TransitionsMinimal with an integer i returns the i’th Dict[str, np.ndarray] sample, whose keys are the field names of each dataclass field and whose values are the i’th elements of each field value.

Slicing returns a possibly empty instance of TransitionsMinimal where each field has been sliced.
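The integer-versus-slice behavior can be reproduced with dataclasses.fields and dataclasses.replace. MiniTransitions is a hypothetical stand-in with the same three fields; it follows the indexing rules stated above but is not the library's TransitionsMinimal.

```python
import dataclasses
from typing import Any, Dict, Union

import numpy as np


@dataclasses.dataclass(frozen=True)
class MiniTransitions:
    """Hypothetical stand-in following the indexing rules of TransitionsMinimal."""

    obs: np.ndarray
    acts: np.ndarray
    infos: np.ndarray

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(
        self, key: Union[int, slice]
    ) -> Union[Dict[str, Any], "MiniTransitions"]:
        picked = {f.name: getattr(self, f.name)[key] for f in dataclasses.fields(self)}
        if isinstance(key, slice):
            # Slicing: a (possibly empty) instance with every field sliced.
            return dataclasses.replace(self, **picked)
        # Integer indexing: one sample, keyed by field name.
        return picked


trans = MiniTransitions(
    obs=np.arange(6).reshape(3, 2),
    acts=np.array([0, 1, 1]),
    infos=np.array([{}, {}, {}]),
)
sample = trans[1]
print(sample["obs"], sample["acts"])   # [2 3] 1
print(len(trans[1:]), len(trans[5:]))  # 2 0
```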

__init__(obs, acts, infos)
acts: ndarray

Actions. Shape (batch_size,) + action_shape.

infos: ndarray

Array of info dicts. Shape (batch_size,).

obs: Union[ndarray, DictObs]

Previous observations. Shape (batch_size, ) + observation_shape.

The i’th observation obs[i] in this array is the observation seen by the agent when choosing action acts[i]. obs[i] is not required to be from the timestep preceding obs[i+1].

class imitation.data.types.TransitionsWithRew(obs, acts, infos, next_obs, dones, rews)

Bases: Transitions

A batch of obs-act-obs-rew-done transitions.

__init__(obs, acts, infos, next_obs, dones, rews)
rews: ndarray

Reward. Shape (batch_size, ). dtype float.

The reward rews[i] at the i’th timestep is received after the agent has taken action acts[i].

imitation.data.types.assert_not_dictobs(x)

Typeguard to assert x is an array, not a DictObs.

Return type

ndarray
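A typeguard of this kind is typically a runtime isinstance check whose return annotation narrows the type for static checkers. A minimal sketch under that assumption; DictObsPlaceholder and assert_not_dictobs_sketch are hypothetical names, and raising TypeError is this sketch's choice, not necessarily what the library raises.

```python
from typing import Union

import numpy as np


class DictObsPlaceholder:
    """Hypothetical stand-in for DictObs so this sketch is self-contained."""


def assert_not_dictobs_sketch(x: Union[np.ndarray, DictObsPlaceholder]) -> np.ndarray:
    # Returning x annotated as np.ndarray lets static checkers narrow the type
    # in code that cannot handle dictionary observations.
    if isinstance(x, DictObsPlaceholder):
        # The exception type is an assumption of this sketch.
        raise TypeError("expected an array observation, got a dictionary observation")
    return x


arr = np.zeros((4, 3))
flat = assert_not_dictobs_sketch(arr)  # arrays pass through unchanged
```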

imitation.data.types.concatenate_maybe_dictobs(arrs)

Concatenates a list of observations appropriately (depending on type).

Return type

TypeVar(ObsVar, ndarray, DictObs)

imitation.data.types.dataclass_quick_asdict(obj)

Extract dataclass to items using dataclasses.fields + dict comprehension.

This is a quick alternative to dataclasses.asdict, which expensively and undocumentedly deep-copies every numpy array value. See https://stackoverflow.com/a/52229565/1091722.

This is also used to preserve DictObs objects, as dataclasses.asdict unwraps them recursively.

Parameters

obj – A dataclass instance.

Return type

Dict[str, Any]

Returns

A dictionary mapping from obj field names to values.
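The difference between the two approaches is easy to observe with object identity. A minimal sketch of the fields-plus-comprehension approach described above; quick_asdict and Batch are hypothetical names, and lists stand in for numpy arrays so the example needs only the standard library.

```python
import dataclasses
from typing import Any, Dict, List


def quick_asdict(obj: Any) -> Dict[str, Any]:
    """Shallow dataclass-to-dict: field values are returned as-is, not copied."""
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


@dataclasses.dataclass
class Batch:
    obs: List[float]
    acts: List[int]


batch = Batch(obs=[0.1, 0.2], acts=[1, 0])
shallow = quick_asdict(batch)
deep = dataclasses.asdict(batch)

# The shallow version hands back the very same objects...
assert shallow["obs"] is batch.obs
# ...while dataclasses.asdict recursively builds new containers.
assert deep["obs"] is not batch.obs
assert deep == shallow
```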

imitation.data.types.map_maybe_dict(fn, maybe_dict)

Either maps fn over dictionary values or applies fn to maybe_dict.

Parameters
  • fn – function to apply. Must take a single argument.

  • maybe_dict – either a dict or a value that can be passed to fn.

Returns

Either a dict (if maybe_dict was a dict) or fn(maybe_dict).
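A minimal sketch of the documented behavior, assuming that "dict" means an instance of Python's dict; map_maybe_dict_sketch is a hypothetical name, not the library function.

```python
from typing import Callable, Dict, TypeVar, Union

T = TypeVar("T")
R = TypeVar("R")


def map_maybe_dict_sketch(
    fn: Callable[[T], R], maybe_dict: Union[Dict[str, T], T]
) -> Union[Dict[str, R], R]:
    """Hypothetical sketch of map_maybe_dict's documented behavior."""
    if isinstance(maybe_dict, dict):
        # Dict input: apply fn to each value and keep the keys.
        return {k: fn(v) for k, v in maybe_dict.items()}
    # Anything else: apply fn once to the whole value.
    return fn(maybe_dict)


print(map_maybe_dict_sketch(len, {"a": [1, 2], "b": []}))  # {'a': 2, 'b': 0}
print(map_maybe_dict_sketch(len, [1, 2, 3]))               # 3
```

This lets one transformation be written once and applied to an observation that may be either a single value or a dict of values.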

imitation.data.types.maybe_unwrap_dictobs(maybe_dictobs: DictObs) → Dict[str, ndarray]
imitation.data.types.maybe_unwrap_dictobs(maybe_dictobs: T) → T

Unwraps if a DictObs, otherwise returns the object.

imitation.data.types.maybe_wrap_in_dictobs(obs: Union[Dict[str, ndarray], DictObs]) → DictObs
imitation.data.types.maybe_wrap_in_dictobs(obs: ndarray) → ndarray

Converts an observation into a DictObs, if necessary.

Return type

Union[ndarray, DictObs]
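The overloads listed for these two functions let a static type checker know that dict observations come back wrapped (or as a plain dict after unwrapping) while plain arrays keep their ndarray type. A minimal sketch with typing.overload; WrappedObs, wrap, and unwrap are hypothetical names, the generic T overload of maybe_unwrap_dictobs is narrowed to np.ndarray here, and the runtime checks are assumptions rather than the library's code.

```python
from dataclasses import dataclass
from typing import Dict, Union, overload

import numpy as np


@dataclass(frozen=True)
class WrappedObs:
    """Hypothetical stand-in for DictObs: holds a dict of arrays."""

    d: Dict[str, np.ndarray]


@overload
def wrap(obs: Union[Dict[str, np.ndarray], WrappedObs]) -> WrappedObs: ...
@overload
def wrap(obs: np.ndarray) -> np.ndarray: ...
def wrap(obs):
    # Plain dicts get wrapped; wrapped values and bare arrays pass through.
    if isinstance(obs, dict):
        return WrappedObs(obs)
    return obs


@overload
def unwrap(obs: WrappedObs) -> Dict[str, np.ndarray]: ...
@overload
def unwrap(obs: np.ndarray) -> np.ndarray: ...
def unwrap(obs):
    # Wrapped values become a shallow dict copy (arrays are shared).
    if isinstance(obs, WrappedObs):
        return dict(obs.d)
    return obs


raw = {"pos": np.zeros(2)}
wrapped = wrap(raw)
assert wrap(wrapped) is wrapped              # already wrapped: unchanged
arr = np.ones(3)
assert wrap(arr) is arr and unwrap(arr) is arr
assert unwrap(wrapped)["pos"] is raw["pos"]  # arrays are not copied
```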

imitation.data.types.stack_maybe_dictobs(arrs)

Stacks a list of observations appropriately (depending on type).

Return type

TypeVar(ObsVar, ndarray, DictObs)
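Both stack_maybe_dictobs and concatenate_maybe_dictobs choose between a single numpy call and a key-by-key combination depending on the observation type. A minimal sketch of that dispatch, assuming np.stack semantics for stacking (a new leading axis) and np.concatenate semantics for concatenation (the existing leading axis); plain dicts stand in for DictObs, and stack_sketch, concatenate_sketch, and _combine are hypothetical names.

```python
from typing import Callable, Dict, Sequence, Union

import numpy as np

# In this sketch a plain dict of arrays stands in for DictObs.
Obs = Union[np.ndarray, Dict[str, np.ndarray]]


def _combine(arrs: Sequence[Obs], np_fn: Callable[..., np.ndarray]) -> Obs:
    if len(arrs) == 0:
        raise ValueError("need at least one observation")
    if isinstance(arrs[0], dict):
        # Dict observations: combine key-by-key, so each key keeps its own shape.
        return {k: np_fn([a[k] for a in arrs]) for k in arrs[0]}
    # Plain arrays: a single numpy call.
    return np_fn(list(arrs))


def stack_sketch(arrs: Sequence[Obs]) -> Obs:
    """Adds a new leading axis, like np.stack."""
    return _combine(arrs, np.stack)


def concatenate_sketch(arrs: Sequence[Obs]) -> Obs:
    """Joins along the existing leading axis, like np.concatenate."""
    return _combine(arrs, np.concatenate)


step = {"pos": np.zeros(3), "img": np.zeros((8, 8))}
stacked = stack_sketch([step] * 4)
print(stacked["pos"].shape, stacked["img"].shape)  # (4, 3) (4, 8, 8)
print(concatenate_sketch([np.zeros((2, 3))] * 2).shape)  # (4, 3)
```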

imitation.data.types.transitions_collate_fn(batch)

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

Use this as the collate_fn argument to DataLoader if using an instance of TransitionsMinimal as the dataset argument.

Parameters

batch (Sequence[Mapping[str, ndarray]]) – The batch to collate.

Return type

Mapping[str, Union[ndarray, Tensor]]

Returns

A collated batch. Uses Torch’s default collate function for everything except the “infos” key. For “infos”, we join all the info dicts into a list of dicts. (The default behavior would recursively collate every info dict into a single dict, which is incorrect.)
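A minimal sketch of the collation rule described above. collate_sketch is a hypothetical name, and np.stack stands in for Torch's default collate function so the example runs without Torch; it therefore returns numpy arrays where the real function would produce tensors.

```python
from typing import Any, Dict, Mapping, Sequence

import numpy as np


def collate_sketch(batch: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate_fn: infos stay a list of dicts, other keys are stacked.

    np.stack stands in for Torch's default collate function in this sketch.
    """
    collated: Dict[str, Any] = {}
    for k in batch[0].keys():
        values = [sample[k] for sample in batch]
        if k == "infos":
            # Keep one info dict per sample instead of merging them.
            collated[k] = list(values)
        else:
            collated[k] = np.stack([np.asarray(v) for v in values])
    return collated


samples = [
    {"obs": np.zeros(2), "acts": 0, "infos": {"t": 0}},
    {"obs": np.ones(2), "acts": 1, "infos": {"t": 1}},
]
out = collate_sketch(samples)
print(out["obs"].shape, out["acts"].tolist(), out["infos"])
# (2, 2) [0, 1] [{'t': 0}, {'t': 1}]
```

With Torch installed, it is the library's transitions_collate_fn that is passed as the collate_fn argument to torch.utils.data.DataLoader when the dataset is a TransitionsMinimal instance.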