imitation.algorithms.sqil#
Soft Q Imitation Learning (SQIL) (https://arxiv.org/abs/1905.11108).
Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards.
Classes

| SQIL | Soft Q Imitation Learning (SQIL). |
| SQILReplayBuffer | A replay buffer that injects 50% expert demonstrations when sampling. |
- class imitation.algorithms.sqil.SQIL(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]#
Bases: DemonstrationAlgorithm[Transitions]

Soft Q Imitation Learning (SQIL).
Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards.
- __init__(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]#
Builds SQIL.
- Parameters
  - venv (VecEnv) – The vectorized environment to train on.
  - demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations to use for training.
  - policy (Union[str, Type[BasePolicy]]) – The policy model to use (SB3).
  - custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
  - rl_algo_class (Type[OffPolicyAlgorithm]) – Off-policy RL algorithm to use.
  - rl_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the RL algorithm constructor.
- Raises
ValueError – if rl_kwargs includes the key replay_buffer_class or replay_buffer_kwargs (SQIL installs its own SQILReplayBuffer, so these keys would conflict).
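What "adjusting the rewards" means can be sketched in a few lines: per the SQIL paper (https://arxiv.org/abs/1905.11108), expert transitions are stored with reward 1 and the learner's own transitions with reward 0, so standard DQN-style Q-learning then drives the policy toward the expert. A minimal NumPy sketch of this relabeling, with illustrative names that are not part of the library API:

```python
import numpy as np

def sqil_relabel(env_rewards, is_expert):
    """SQIL reward adjustment: the environment reward is discarded entirely.

    Expert transitions get reward 1, learner transitions get reward 0,
    following the SQIL paper. `env_rewards` is only used for its shape.
    """
    return np.where(is_expert, 1.0, 0.0) * np.ones_like(env_rewards)

# Environment rewards are ignored; only provenance of the transition matters.
env_rewards = np.array([0.3, -1.2, 5.0, 0.0])
is_expert = np.array([True, False, True, False])
relabeled = sqil_relabel(env_rewards, is_expert)
```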
- expert_buffer: ReplayBuffer#
- property policy: BasePolicy#
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- set_demonstrations(demonstrations)[source]#
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
  - demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing "obs" and "acts" Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
  None
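Since set_demonstrations accepts either transition data or a sequence of trajectories, it can help to see how trajectories reduce to transitions: a trajectory of n steps has n+1 observations and n actions, and flattening pairs each observation with the action taken from it. A sketch under that convention, using plain dictionaries rather than the library's Trajectory type:

```python
import numpy as np

# Illustrative trajectories: each has one more observation than actions.
trajectories = [
    {"obs": np.arange(5).reshape(5, 1), "acts": np.zeros(4, dtype=int)},
    {"obs": np.arange(3).reshape(3, 1), "acts": np.ones(2, dtype=int)},
]

def flatten_trajectories(trajs):
    """Flatten trajectories into transition arrays (obs_t paired with act_t)."""
    obs = np.concatenate([t["obs"][:-1] for t in trajs])  # drop terminal obs
    acts = np.concatenate([t["acts"] for t in trajs])
    return {"obs": obs, "acts": acts}

transitions = flatten_trajectories(trajectories)
```

The real library performs this conversion internally with its own helpers; the sketch only shows why both input forms carry the same information.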
- class imitation.algorithms.sqil.SQILReplayBuffer(buffer_size, observation_space, action_space, demonstrations, device='auto', n_envs=1, optimize_memory_usage=False)[source]#
Bases: ReplayBuffer
A replay buffer that injects 50% expert demonstrations when sampling.
This buffer is fundamentally the same as ReplayBuffer, but it also maintains an internal buffer of expert demonstrations. Each sampled batch is split 50/50 between expert and collected data.
It can be used in off-policy algorithms like DQN/SAC/TD3.
Here it is used as part of SQIL, where it is used to train a DQN.
- __init__(buffer_size, observation_space, action_space, demonstrations, device='auto', n_envs=1, optimize_memory_usage=False)[source]#
Create a SQILReplayBuffer instance.
- Parameters
  - buffer_size (int) – Max number of elements in the buffer.
  - observation_space (Space) – Observation space.
  - action_space (Space) – Action space.
  - demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Expert demonstrations.
  - device (Union[device, str]) – PyTorch device.
  - n_envs (int) – Number of parallel environments. Defaults to 1.
  - optimize_memory_usage (bool) – Enable a memory-efficient variant of the replay buffer which nearly halves the memory used, at the cost of more complexity.
- actions: ndarray#
- add(obs, next_obs, action, reward, done, infos)[source]#
Add elements to the buffer.
- Return type
None
- dones: ndarray#
- next_observations: ndarray#
- observations: ndarray#
- rewards: ndarray#
- sample(batch_size, env=None)[source]#
Sample a batch of data.
Half of the batch will be from expert transitions, and the other half will be from the learner transitions.
- Parameters
  - batch_size (int) – Number of elements to sample in total.
  - env (Optional[VecNormalize]) – Associated VecNormalize instance used to normalize the observations/rewards when sampling.
- Return type
ReplayBufferSamples
- Returns
A mix of transitions from the expert and from the learner.
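The 50/50 mixing can be sketched as: half of the requested indices are drawn from the expert buffer, half from the learner buffer, and the two halves are concatenated. The names below are hypothetical stand-ins; the real method returns a ReplayBufferSamples of stacked torch tensors rather than a NumPy array:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in buffer contents: expert rows are all 1.0, learner rows all 0.0,
# so the provenance of each sampled row is visible in the result.
expert_obs = np.full((100, 3), 1.0)
learner_obs = np.full((200, 3), 0.0)

def sample_mixed(batch_size):
    """Sample a batch that is half expert data, half learner data."""
    half = batch_size // 2
    expert_half = expert_obs[rng.integers(len(expert_obs), size=half)]
    learner_half = learner_obs[rng.integers(len(learner_obs), size=half)]
    return np.concatenate([expert_half, learner_half])

batch = sample_mixed(32)
```

Sampling with replacement from each buffer independently means the split stays exactly 50/50 regardless of how full either buffer is.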
- set_demonstrations(demonstrations)[source]#
Set the expert demonstrations to be injected when sampling from the buffer.
- Parameters
demonstrations (algo_base.AnyTransitions) – Expert demonstrations.
- Raises
NotImplementedError – If demonstrations is not a transitions object or a list of trajectories.
- Return type
None
- timeouts: ndarray#