imitation.data.types

Types and helper methods for transitions and trajectories.

Functions

dataclass_quick_asdict(obj)

Extract dataclass to items using dataclasses.fields + dict comprehension.

transitions_collate_fn(batch)

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

Classes

Trajectory(obs, acts, infos, terminal)

A trajectory, e.g. a one-episode rollout from an expert policy.

TrajectoryWithRew(obs, acts, infos, ...)

A Trajectory that additionally includes reward information.

Transitions(obs, acts, infos, next_obs, dones)

A batch of obs-act-obs-done transitions.

TransitionsMinimal(obs, acts, infos)

A Torch-compatible Dataset of obs-act transitions.

TransitionsWithRew(obs, acts, infos, ...)

A batch of obs-act-obs-rew-done transitions.

class imitation.data.types.Trajectory(obs, acts, infos, terminal)

Bases: object

A trajectory, e.g. a one-episode rollout from an expert policy.

__init__(obs, acts, infos, terminal)
acts: ndarray

Actions, shape (trajectory_len, ) + action_shape.

infos: Optional[ndarray]

An array of info dicts, shape (trajectory_len, ).

The info dict is returned by some environments' step() and contains auxiliary diagnostic information. For example, the monitor wrapper adds an info dict to the last step of each episode containing the episode return and length.

obs: ndarray

Observations, shape (trajectory_len + 1, ) + observation_shape.

terminal: bool

Does this trajectory (fragment) end in a terminal state?

Episodes are always terminal. Trajectory fragments are also terminal when they contain the final state of an episode (even if missing the start of the episode).
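The shape relationship between the fields can be sketched as follows. This is a minimal, standard-library-only illustration, not the library's implementation: ToyTrajectory is a hypothetical stand-in that uses plain lists instead of ndarrays, and the validation shown is an assumption derived from the documented shapes (one more observation than actions, one info dict per action).

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass(frozen=True)
class ToyTrajectory:
    """Hypothetical stand-in for Trajectory, using lists instead of ndarrays."""

    obs: List[Any]
    acts: List[Any]
    infos: Optional[List[dict]]
    terminal: bool

    def __post_init__(self) -> None:
        # One observation per action, plus the final observation.
        if len(self.obs) != len(self.acts) + 1:
            raise ValueError("expected len(obs) == len(acts) + 1")
        # If present, there is one info dict per action.
        if self.infos is not None and len(self.infos) != len(self.acts):
            raise ValueError("expected len(infos) == len(acts)")

    def __len__(self) -> int:
        # In this sketch, trajectory length counts actions, not observations.
        return len(self.acts)


# Three actions require four observations.
traj = ToyTrajectory(obs=[0, 1, 2, 3], acts=[1, 1, 1], infos=None, terminal=True)
```

Here len(traj) is the number of actions, matching the trajectory_len used in the shape descriptions above.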

class imitation.data.types.TrajectoryWithRew(obs, acts, infos, terminal, rews)

Bases: Trajectory

A Trajectory that additionally includes reward information.

__init__(obs, acts, infos, terminal, rews)
rews: ndarray

Reward, shape (trajectory_len, ). dtype float.

class imitation.data.types.Transitions(obs, acts, infos, next_obs, dones)

Bases: TransitionsMinimal

A batch of obs-act-obs-done transitions.

__init__(obs, acts, infos, next_obs, dones)
dones: ndarray

Boolean array indicating episode termination. Shape: (batch_size, ).

dones[i] is true iff next_obs[i] is the last observation of an episode.

next_obs: ndarray

New observation. Shape: (batch_size, ) + observation_shape.

The i’th observation next_obs[i] in this array is the observation after the agent has taken action acts[i].

Invariants:
  • next_obs.dtype == obs.dtype
  • len(next_obs) == len(obs)
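To make the relationship between obs, next_obs, and dones concrete, here is a hedged sketch of how a single trajectory could be split into per-step transitions. toy_flatten is a hypothetical helper written for illustration with plain lists; it is not the library's flatten_trajectories, and the assumption that only the final step of a terminal trajectory is marked done follows from the field descriptions above.

```python
from typing import Any, Dict, List


def toy_flatten(obs: List[Any], acts: List[Any], terminal: bool) -> Dict[str, list]:
    """Hypothetical helper: split one trajectory into per-step transitions."""
    if len(obs) != len(acts) + 1:
        raise ValueError("expected len(obs) == len(acts) + 1")
    dones = [False] * len(acts)
    if dones:
        # Only the final step can end the episode, and only if the
        # trajectory itself is terminal.
        dones[-1] = terminal
    return {
        "obs": obs[:-1],      # observation seen when choosing acts[i]
        "acts": list(acts),
        "next_obs": obs[1:],  # observation after taking acts[i]
        "dones": dones,
    }


flat = toy_flatten(obs=[10, 11, 12, 13], acts=[0, 1, 0], terminal=True)
```

In this sketch obs and next_obs are the same sequence offset by one step, so they always have equal length.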

class imitation.data.types.TransitionsMinimal(obs, acts, infos)

Bases: Dataset, Sequence[Mapping[str, ndarray]]

A Torch-compatible Dataset of obs-act transitions.

This class and its subclasses are usually instantiated via imitation.data.rollout.flatten_trajectories.

Indexing an instance trans of TransitionsMinimal with an integer i returns the i’th Dict[str, np.ndarray] sample, whose keys are the names of the dataclass fields and whose values are the i’th elements of each field value.

Slicing returns a possibly empty instance of TransitionsMinimal where each field has been sliced.
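The indexing and slicing behaviour described above can be sketched with a small dataclass. ToyTransitions is a hypothetical stand-in that holds plain lists rather than ndarrays; it illustrates the described contract only and is not the library's code.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, List, Union


@dataclass(frozen=True)
class ToyTransitions:
    """Hypothetical stand-in for TransitionsMinimal, backed by lists."""

    obs: List[Any]
    acts: List[Any]
    infos: List[dict]

    def __len__(self) -> int:
        return len(self.obs)

    def __getitem__(self, key: Union[int, slice]):
        if isinstance(key, slice):
            # Slice every field and rebuild an instance of the same class.
            sliced = {f.name: getattr(self, f.name)[key] for f in fields(self)}
            return replace(self, **sliced)
        # Integer index: one sample as a dict keyed by field name.
        return {f.name: getattr(self, f.name)[key] for f in fields(self)}


trans = ToyTransitions(obs=[0, 1, 2], acts=[5, 6, 7], infos=[{}, {}, {}])
sample = trans[1]   # dict with keys "obs", "acts", "infos"
head = trans[:2]    # ToyTransitions with two samples
empty = trans[3:]   # ToyTransitions with zero samples
```

Returning a dict for integer keys and an instance for slices keeps per-sample access cheap while letting slices be passed anywhere a full batch is expected.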

__init__(obs, acts, infos)
acts: ndarray

Actions. Shape: (batch_size,) + action_shape.

infos: ndarray

Array of info dicts. Shape: (batch_size,).

obs: ndarray

Previous observations. Shape: (batch_size, ) + observation_shape.

The i’th observation obs[i] in this array is the observation seen by the agent when choosing action acts[i]. obs[i] is not required to be from the timestep preceding obs[i+1].

class imitation.data.types.TransitionsWithRew(obs, acts, infos, next_obs, dones, rews)

Bases: Transitions

A batch of obs-act-obs-rew-done transitions.

__init__(obs, acts, infos, next_obs, dones, rews)
rews: ndarray

Reward. Shape: (batch_size, ). dtype float.

The reward rews[i] at the i’th timestep is received after the agent has taken action acts[i].

imitation.data.types.dataclass_quick_asdict(obj)

Extract dataclass to items using dataclasses.fields + dict comprehension.

This is a quick alternative to dataclasses.asdict, which deep-copies every numpy array value; that copying is expensive and not documented. See https://stackoverflow.com/a/52229565/1091722.

Parameters

obj – A dataclass instance.

Return type

Dict[str, Any]

Returns

A dictionary mapping from obj field names to values.
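The difference from dataclasses.asdict can be illustrated with the standard library alone. quick_asdict below is an illustrative re-implementation of the "fields + dict comprehension" approach named above, and Holder is a hypothetical example dataclass; a plain list stands in for a numpy array.

```python
import dataclasses
from typing import Any, Dict


def quick_asdict(obj: Any) -> Dict[str, Any]:
    """Shallow dataclass-to-dict conversion (illustrative re-implementation)."""
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


@dataclasses.dataclass
class Holder:  # hypothetical example dataclass
    values: list


holder = Holder(values=[1, 2, 3])
shallow = quick_asdict(holder)      # values are the same objects as the fields
deep = dataclasses.asdict(holder)   # values are recursively copied
```

In this sketch shallow["values"] is the very same object as holder.values, whereas deep["values"] is an equal but separate copy, which is the cost the shallow version avoids.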

imitation.data.types.transitions_collate_fn(batch)

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

Use this as the collate_fn argument to DataLoader if using an instance of TransitionsMinimal as the dataset argument.

Parameters

batch (Sequence[Mapping[str, ndarray]]) – The batch to collate.

Return type

Mapping[str, Union[ndarray, Tensor]]

Returns

A collated batch. Uses Torch’s default collate function for everything except the “infos” key. For “infos”, we join all the info dicts into a list of dicts. (The default behavior would recursively collate every info dict into a single dict, which is incorrect.)
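The treatment of the “infos” key can be sketched without Torch. toy_collate is a hypothetical, simplified stand-in: it gathers every key into a plain list, whereas the real function is described as using Torch's default collation for the non-“infos” keys. The point of the sketch is only that the info dicts stay separate, one per sample.

```python
from typing import Any, Dict, List, Mapping, Sequence


def toy_collate(batch: Sequence[Mapping[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical collate: gather each key into a list, including "infos"."""
    keys = batch[0].keys()
    # A real collate function would stack the non-"infos" values into tensors;
    # plain lists stand in for that here.
    return {key: [sample[key] for sample in batch] for key in keys}


batch = [
    {"obs": 0, "acts": 1, "infos": {}},
    {"obs": 1, "acts": 0, "infos": {"episode": {"r": 1.0, "l": 2}}},
]
collated = toy_collate(batch)
```

Keeping “infos” as a list of dicts also tolerates samples whose info dicts have different keys, as in the batch above, where only the second sample carries an "episode" entry.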