imitation.algorithms.sqil#

Soft Q Imitation Learning (SQIL) (https://arxiv.org/abs/1905.11108).

Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards.

Classes

SQIL(*, venv, demonstrations, policy[, ...])

Soft Q Imitation Learning (SQIL).

SQILReplayBuffer(buffer_size, ...[, device, ...])

A replay buffer that injects 50% expert demonstrations when sampling.

class imitation.algorithms.sqil.SQIL(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]#

Bases: DemonstrationAlgorithm[Transitions]

Soft Q Imitation Learning (SQIL).

Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards.

__init__(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]#

Builds SQIL.

Parameters
  • venv (VecEnv) – The vectorized environment to train on.

  • demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations to use for training.

  • policy (Union[str, Type[BasePolicy]]) – The policy model to use (a Stable Baselines3 policy class or its name).

  • custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.

  • rl_algo_class (Type[OffPolicyAlgorithm]) – Off-policy RL algorithm to use.

  • rl_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the RL algorithm constructor.

Raises

ValueError – if rl_kwargs includes a key replay_buffer_class or replay_buffer_kwargs.
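
A minimal usage sketch (not part of the generated reference above), assuming recent gymnasium-based versions of imitation and stable-baselines3, and that demos already holds expert demonstrations (a sequence of Trajectory objects or a Transitions instance):

import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import sqil

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
trainer = sqil.SQIL(
    venv=venv,
    demonstrations=demos,  # assumed: pre-collected expert data
    policy="MlpPolicy",    # forwarded to the underlying DQN
)
trainer.train(total_timesteps=100_000)
learned_policy = trainer.policy  # BasePolicy imitating the demonstrations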

expert_buffer: ReplayBuffer#
property policy: BasePolicy#

Returns a policy imitating the demonstration data.

Return type

BasePolicy

set_demonstrations(demonstrations)[source]#

Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.

Parameters

demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.

Return type

None
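
An illustrative sketch of the accepted formats, assuming trainer is the SQIL instance from the sketch above and trajs is a sequence of imitation.data.types.Trajectory objects:

from imitation.data import rollout

trainer.set_demonstrations(trajs)  # a Sequence of Trajectory objects
trainer.set_demonstrations(rollout.flatten_trajectories(trajs))  # flattened Transitions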

train(*, total_timesteps, tb_log_name='SQIL', **kwargs)[source]#
class imitation.algorithms.sqil.SQILReplayBuffer(buffer_size, observation_space, action_space, demonstrations, device='auto', n_envs=1, optimize_memory_usage=False)[source]#

Bases: ReplayBuffer

A replay buffer that injects 50% expert demonstrations when sampling.

This buffer is essentially the same as ReplayBuffer, but it maintains an internal buffer of expert demonstrations. When a batch is sampled, half of it comes from the expert demonstrations and half from the collected learner data.

It can be used with off-policy algorithms such as DQN, SAC, or TD3.

Within SQIL, it serves as the replay buffer for the DQN being trained.
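
A hedged sketch of direct use with a Stable Baselines3 off-policy algorithm (this mirrors how SQIL configures its DQN internally); venv and demos are assumed to be defined as in the SQIL example above:

from stable_baselines3 import DQN

from imitation.algorithms.sqil import SQILReplayBuffer

dqn = DQN(
    policy="MlpPolicy",
    env=venv,
    replay_buffer_class=SQILReplayBuffer,
    replay_buffer_kwargs={"demonstrations": demos},
)
dqn.learn(total_timesteps=100_000)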

__init__(buffer_size, observation_space, action_space, demonstrations, device='auto', n_envs=1, optimize_memory_usage=False)[source]#

Create a SQILReplayBuffer instance.

Parameters
  • buffer_size (int) – Maximum number of elements in the buffer.

  • observation_space (Space) – Observation space.

  • action_space (Space) – Action space.

  • demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Expert demonstrations.

  • device (Union[device, str]) – PyTorch device.

  • n_envs (int) – Number of parallel environments. Defaults to 1.

  • optimize_memory_usage (bool) – Enable a memory-efficient variant of the replay buffer that reduces memory usage by almost a factor of two, at the cost of more complexity.

actions: ndarray#
add(obs, next_obs, action, reward, done, infos)[source]#

Add elements to the buffer.

Return type

None

dones: ndarray#
next_observations: ndarray#
observations: ndarray#
rewards: ndarray#
sample(batch_size, env=None)[source]#

Sample a batch of data.

Half of the batch is drawn from expert transitions and the other half from learner transitions.

Parameters
  • batch_size (int) – Total number of elements to sample.

  • env (Optional[VecNormalize]) – Associated VecNormalize environment, used to normalize the observations/rewards when sampling.

Return type

ReplayBufferSamples

Returns

A mix of transitions from the expert and from the learner.
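
For illustration only: the composition of a sampled batch can be inspected as below, assuming dqn was built with SQILReplayBuffer as in the sketch above and has already collected some learner transitions (e.g. after dqn.learn(...)):

batch = dqn.replay_buffer.sample(batch_size=64)
print(batch.observations.shape)  # (batch_size, *obs_shape)
# If rewards are relabelled as in the SQIL paper (1 for expert data, 0 for
# learner data), the batch-mean reward should sit near 0.5:
print(batch.rewards.mean().item())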

set_demonstrations(demonstrations)[source]#

Set the expert demonstrations to be injected when sampling from the buffer.

Parameters

demonstrations (algo_base.AnyTransitions) – Expert demonstrations.

Raises

NotImplementedError – If demonstrations is not a transitions object or a list of trajectories.

Return type

None

timeouts: ndarray#