"""Buffers to store NumPy arrays and transitions in."""
from typing import Any, Mapping, Optional, Tuple
import numpy as np
from stable_baselines3.common import vec_env
from imitation.data import types
def num_samples(data: Mapping[Any, np.ndarray]) -> int:
"""Computes the number of samples contained in `data`.
Args:
data: A Mapping from keys to NumPy arrays.
Returns:
The unique length of the first dimension of arrays contained in `data`.
Raises:
ValueError: The length is not unique.
"""
n_samples_list = [arr.shape[0] for arr in data.values()]
n_samples_np = np.unique(n_samples_list)
if len(n_samples_np) > 1:
raise ValueError("Keys map to different length values.")
return int(n_samples_np[0])
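For illustration, a minimal numpy-only sketch of the uniqueness check that `num_samples` performs; `check_num_samples` and the `data` dict below are hypothetical stand-ins, not part of this module:

```python
import numpy as np

def check_num_samples(data):
    """Mirrors `num_samples`: returns the shared first-axis length, or raises."""
    lengths = np.unique([arr.shape[0] for arr in data.values()])
    if len(lengths) > 1:
        raise ValueError("Keys map to different length values.")
    return int(lengths[0])

# Both arrays have 5 rows, so the sample count is unambiguous.
data = {"obs": np.zeros((5, 3)), "acts": np.zeros((5, 2))}
print(check_num_samples(data))  # 5
```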
class Buffer:
"""A FIFO ring buffer for NumPy arrays of a fixed shape and dtype.
Supports random sampling with replacement.
"""
capacity: int
"""The number of data samples that can be stored in this buffer."""
sample_shapes: Mapping[str, Tuple[int, ...]]
"""The shapes of each data sample stored in this buffer."""
_arrays: Mapping[str, np.ndarray]
"""The underlying NumPy arrays (which actually store the data)."""
_n_data: int
"""The number of samples currently stored in this buffer.
An integer in `range(0, self.capacity + 1)`. This attribute is the return
value of `self.size()`.
"""
_idx: int
"""The index of the first row that new data should be written to.
An integer in `range(0, self.capacity)`.
"""
def __init__(
self,
capacity: int,
sample_shapes: Mapping[str, Tuple[int, ...]],
dtypes: Mapping[str, np.dtype],
):
"""Constructs a Buffer.
Args:
capacity: The number of samples that can be stored.
sample_shapes: A dictionary mapping string keys to the shape of
samples associated with that key.
dtypes (`np.dtype`-like): A dictionary mapping string keys to the dtype
of samples associated with that key.
Raises:
KeyError: `sample_shapes` and `dtypes` have different keys.
"""
if sample_shapes.keys() != dtypes.keys():
raise KeyError("sample_shapes and dtypes keys don't match")
self.capacity = capacity
self.sample_shapes = {k: tuple(shape) for k, shape in sample_shapes.items()}
self._arrays = {
k: np.zeros((capacity,) + shape, dtype=dtypes[k])
for k, shape in self.sample_shapes.items()
}
self._n_data = 0
self._idx = 0
@classmethod
def from_data(
cls,
data: Mapping[str, np.ndarray],
capacity: Optional[int] = None,
truncate_ok: bool = False,
) -> "Buffer":
"""Constructs and returns a Buffer containing the provided data.
Shapes and dtypes are automatically inferred.
Args:
data: A dictionary mapping keys to data arrays. The arrays may differ
in their shape, but should agree in the first axis.
capacity: The Buffer capacity. If not provided, then this is automatically
set to the size of the data, so that the returned Buffer is at full
capacity.
truncate_ok: Whether to error if `capacity` < the number of samples in
`data`. If False, then only store the last `capacity` samples from
`data` when overcapacity.
Examples:
In the following examples, suppose the arrays in `data` are length-1000.
`Buffer` with same capacity as arrays in `data`::
Buffer.from_data(data)
`Buffer` with larger capacity than arrays in `data`::
Buffer.from_data(data, 10000)
`Buffer` with smaller capacity than arrays in `data`. Without
`truncate_ok=True`, `from_data` will error::
Buffer.from_data(data, 5, truncate_ok=True)
Returns:
Buffer of specified `capacity` containing provided `data`.
Raises:
ValueError: `data` is empty.
ValueError: `data` has items mapping to arrays differing in the
length of their first axis.
"""
data_capacities_list = [arr.shape[0] for arr in data.values()]
data_capacities = np.unique(data_capacities_list)
if len(data) == 0:
raise ValueError("No keys in data.")
if len(data_capacities) > 1:
raise ValueError("Keys map to different length values")
if capacity is None:
capacity = data_capacities[0]
sample_shapes = {k: arr.shape[1:] for k, arr in data.items()}
dtypes = {k: arr.dtype for k, arr in data.items()}
buf = cls(capacity, sample_shapes, dtypes)
buf.store(data, truncate_ok=truncate_ok)
return buf
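The truncation path used by `store` (and hence by `from_data` with `truncate_ok=True`) keeps only the newest rows of every array; a numpy-only sketch, with a hypothetical capacity of 3:

```python
import numpy as np

capacity = 3
data = {"obs": np.arange(5), "acts": np.arange(5) * 10}
# Overcapacity with truncate_ok=True: keep only the last `capacity`
# rows of each array, dropping the oldest samples.
truncated = {k: arr[-capacity:] for k, arr in data.items()}
print(truncated["obs"])   # [2 3 4]
print(truncated["acts"])  # [20 30 40]
```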
def store(self, data: Mapping[str, np.ndarray], truncate_ok: bool = False) -> None:
"""Stores new data samples, replacing old samples with FIFO priority.
Args:
data: A dictionary mapping keys `k` to arrays with shape
`(n_samples,) + self.sample_shapes[k]`, where `n_samples` is less
than or equal to `self.capacity`.
truncate_ok: If False, then error if the length of `data` is
greater than `self.capacity`. Otherwise, store only the final
`self.capacity` samples.
Raises:
ValueError: `data` is empty.
ValueError: If `n_samples` is greater than `self.capacity`.
ValueError: data is the wrong shape.
"""
expected_keys = set(self.sample_shapes.keys())
missing_keys = expected_keys.difference(data.keys())
unexpected_keys = set(data.keys()).difference(expected_keys)
if len(missing_keys) > 0:
raise ValueError(f"Missing keys {missing_keys}")
if len(unexpected_keys) > 0:
raise ValueError(f"Unexpected keys {unexpected_keys}")
n_samples = num_samples(data)
if n_samples == 0:
raise ValueError("Trying to store empty data.")
if n_samples > self.capacity:
if not truncate_ok:
raise ValueError("Not enough capacity to store data.")
else:
data = {k: arr[-self.capacity :] for k, arr in data.items()}
for k, arr in data.items():
if arr.shape[1:] != self.sample_shapes[k]:
raise ValueError(f"Wrong data shape for {k}")
new_idx = self._idx + n_samples
if new_idx > self.capacity:
n_remain = self.capacity - self._idx
# Need to loop around the buffer. Break into two "easy" calls.
self._store_easy({k: arr[:n_remain] for k, arr in data.items()})
assert self._idx == 0
self._store_easy({k: arr[n_remain:] for k, arr in data.items()})
else:
self._store_easy(data)
def _store_easy(self, data: Mapping[str, np.ndarray]) -> None:
"""Stores new data samples, replacing old samples with FIFO priority.
Requires that `size(data) <= self.capacity - self._idx`, where `size(data)` is
the number of rows in every array in `data.values()`. Updates `self._idx`
to be the insertion point of the next `_store_easy` call,
looping back to `self._idx = 0` if necessary.
Also updates `self._n_data`.
Args:
data: Same as in `self.store`'s docstring, except with the additional
constraint `size(data) <= self.capacity - self._idx`.
"""
n_samples = num_samples(data)
assert n_samples <= self.capacity - self._idx
idx_hi = self._idx + n_samples
for k, arr in data.items():
self._arrays[k][self._idx : idx_hi] = arr
self._idx = idx_hi % self.capacity
self._n_data = min(self._n_data + n_samples, self.capacity)
def sample(self, n_samples: int) -> Mapping[str, np.ndarray]:
"""Uniformly sample `n_samples` samples from the buffer with replacement.
Args:
n_samples: The number of samples to randomly sample.
Returns:
A dictionary mapping each key `k` to an array with shape
`(n_samples,) + self.sample_shapes[k]`.
Raises:
ValueError: The buffer is empty.
"""
if self.size() == 0:
raise ValueError("Buffer is empty")
ind = np.random.randint(self.size(), size=n_samples)
return {k: buffer[ind] for k, buffer in self._arrays.items()}
def size(self) -> int:
"""Returns the number of samples stored in the buffer."""
assert 0 <= self._n_data <= self.capacity
return self._n_data
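The FIFO wrap-around performed by `store`/`_store_easy` can be sketched with a single-array ring buffer; `MiniRing` below is a hypothetical numpy-only stand-in, not part of this module:

```python
import numpy as np

class MiniRing:
    """Stripped-down version of the wrap-around logic in `Buffer.store`."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.arr = np.zeros(capacity, dtype=int)
        self.idx = 0     # next write position, like `Buffer._idx`
        self.n_data = 0  # samples stored, like `Buffer._n_data`

    def store(self, samples):
        new_idx = self.idx + len(samples)
        if new_idx > self.capacity:
            # Loop around the buffer: split into two contiguous writes.
            n_remain = self.capacity - self.idx
            self._store_easy(samples[:n_remain])
            self._store_easy(samples[n_remain:])
        else:
            self._store_easy(samples)

    def _store_easy(self, samples):
        hi = self.idx + len(samples)
        self.arr[self.idx:hi] = samples
        self.idx = hi % self.capacity
        self.n_data = min(self.n_data + len(samples), self.capacity)

ring = MiniRing(4)
ring.store([1, 2, 3])
ring.store([4, 5])  # wraps around: 5 overwrites the oldest slot
print(ring.arr)     # [5 2 3 4]
```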
class ReplayBuffer:
"""Buffer for Transitions."""
capacity: int
"""The number of data samples that can be stored in this buffer."""
def __init__(
self,
capacity: int,
venv: Optional[vec_env.VecEnv] = None,
*,
obs_shape: Optional[Tuple[int, ...]] = None,
act_shape: Optional[Tuple[int, ...]] = None,
obs_dtype: Optional[np.dtype] = None,
act_dtype: Optional[np.dtype] = None,
):
"""Constructs a ReplayBuffer.
Args:
capacity: The number of samples that can be stored.
venv: The environment whose action and observation
spaces can be used to determine the data shapes of the underlying
buffers. Mutually exclusive with shape and dtype arguments.
obs_shape: The shape of the observation space.
act_shape: The shape of the action space.
obs_dtype: The dtype of the observation space.
act_dtype: The dtype of the action space.
Raises:
ValueError: Couldn't infer the observation and action shapes and dtypes
from the arguments.
ValueError: Specified both venv and shapes/dtypes.
"""
params = (obs_shape, act_shape, obs_dtype, act_dtype)
if venv is not None:
if venv.observation_space.shape is not None:
if obs_shape is not None:
raise ValueError(
"Cannot specify both observation shape and also environment "
"with an observation space that has a shape.",
)
obs_shape = tuple(venv.observation_space.shape)
if venv.observation_space.dtype is not None:
if obs_dtype is not None:
raise ValueError(
"Cannot specify both observation dtype and also environment "
"with an observation space that has a dtype.",
)
obs_dtype = venv.observation_space.dtype
if venv.action_space.shape is not None:
if act_shape is not None:
raise ValueError(
"Cannot specify both action shape and also environment "
"with an action space that has a shape.",
)
act_shape = tuple(venv.action_space.shape)
if venv.action_space.dtype is not None:
if act_dtype is not None:
raise ValueError(
"Cannot specify both action dtype and also environment "
"with an action space that has a dtype.",
)
act_dtype = venv.action_space.dtype
else:
if any(x is None for x in params):
raise ValueError("Shape or dtype missing and no environment specified.")
assert obs_shape is not None
assert act_shape is not None
assert obs_dtype is not None
assert act_dtype is not None
self.capacity = capacity
sample_shapes = {
"obs": obs_shape,
"acts": act_shape,
"next_obs": obs_shape,
"dones": (),
"infos": (),
}
dtypes = {
"obs": obs_dtype,
"acts": act_dtype,
"next_obs": obs_dtype,
"dones": np.dtype(bool),
"infos": np.dtype(object),
}
self._buffer = Buffer(capacity, sample_shapes=sample_shapes, dtypes=dtypes)
@classmethod
def from_data(
cls,
transitions: types.Transitions,
capacity: Optional[int] = None,
truncate_ok: bool = False,
) -> "ReplayBuffer":
"""Construct and return a ReplayBuffer containing the provided data.
Shapes and dtypes are automatically inferred, and the returned ReplayBuffer is
ready for sampling.
Args:
transitions: Transitions to store.
capacity: The ReplayBuffer capacity. If not provided, then this is
automatically set to the size of the data, so that the returned Buffer
is at full capacity.
truncate_ok: Whether to error if `capacity` < the number of samples in
`transitions`. If False, then only store the last `capacity` samples
from `transitions` when overcapacity.
Examples:
`ReplayBuffer` with same capacity as arrays in `data`::
ReplayBuffer.from_data(data)
`ReplayBuffer` with larger capacity than arrays in `data`::
ReplayBuffer.from_data(data, 10000)
`ReplayBuffer` with smaller capacity than arrays in `data`. Without
`truncate_ok=True`, `from_data` will error::
ReplayBuffer.from_data(data, 5, truncate_ok=True)
Returns:
A new ReplayBuffer.
"""
obs = types.assert_not_dictobs(transitions.obs)
obs_shape = obs.shape[1:]
act_shape = transitions.acts.shape[1:]
if capacity is None:
capacity = obs.shape[0]
instance = cls(
capacity=capacity,
obs_shape=obs_shape,
act_shape=act_shape,
obs_dtype=obs.dtype,
act_dtype=transitions.acts.dtype,
)
instance.store(transitions, truncate_ok=truncate_ok)
return instance
def sample(self, n_samples: int) -> types.Transitions:
"""Sample obs-act-obs triples.
Args:
n_samples: The number of samples.
Returns:
A Transitions named tuple containing n_samples transitions.
"""
sample = self._buffer.sample(n_samples)
return types.Transitions(**sample)
def store(self, transitions: types.Transitions, truncate_ok: bool = True) -> None:
"""Store obs-act-obs triples.
Args:
transitions: Transitions to store.
truncate_ok: If False, then error if the length of `transitions` is
greater than `self.capacity`. Otherwise, store only the final
`self.capacity` transitions.
Raises:
ValueError: The arguments didn't have the same length.
""" # noqa: DAR402
trans_dict = types.dataclass_quick_asdict(transitions)
# Remove unnecessary fields
trans_dict = {k: trans_dict[k] for k in self._buffer.sample_shapes.keys()}
self._buffer.store(trans_dict, truncate_ok=truncate_ok)
def size(self) -> int:
"""Returns the number of samples stored in the buffer."""
return self._buffer.size()
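A numpy-only sketch of the per-key schema that `ReplayBuffer.__init__` hands to its underlying `Buffer`; the observation and action shapes below are hypothetical, chosen for a 4-dim observation, 2-dim action environment:

```python
import numpy as np

obs_shape, act_shape = (4,), (2,)
sample_shapes = {
    "obs": obs_shape,
    "acts": act_shape,
    "next_obs": obs_shape,
    "dones": (),
    "infos": (),
}
dtypes = {
    "obs": np.dtype(np.float32),
    "acts": np.dtype(np.float32),
    "next_obs": np.dtype(np.float32),
    "dones": np.dtype(bool),
    "infos": np.dtype(object),
}
# Preallocate storage the way `Buffer.__init__` does:
# one (capacity, *sample_shape) array per key.
capacity = 10
arrays = {
    k: np.zeros((capacity,) + shape, dtype=dtypes[k])
    for k, shape in sample_shapes.items()
}
print(arrays["obs"].shape)    # (10, 4)
print(arrays["dones"].shape)  # (10,)
```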