Soft Q Imitation Learning (SQIL)

Soft Q Imitation Learning (SQIL) learns to imitate a policy from demonstrations by running the DQN algorithm with modified rewards. During each policy update, half of the batch is sampled from the demonstrations and half from the agent's own interaction with the environment. Transitions from the expert demonstrations are assigned a reward of 1, and transitions collected in the environment are assigned a reward of 0. This encourages the policy to imitate the demonstrations while avoiding states not seen in them.
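The batch construction described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `sample_sqil_batch` and the toy transition tuples are hypothetical names introduced here, and a real replay buffer would store observations, actions, next observations and done flags rather than labels.

```python
import random


def sample_sqil_batch(expert_transitions, env_transitions, batch_size, rng):
    """Return (transition, reward) pairs: half expert, half environment.

    Hypothetical helper for illustration only. Expert transitions get
    reward 1.0 and environment transitions get reward 0.0.
    """
    half = batch_size // 2
    # Sample with replacement from each source.
    expert_half = [(t, 1.0) for t in rng.choices(expert_transitions, k=half)]
    env_half = [(t, 0.0) for t in rng.choices(env_transitions, k=batch_size - half)]
    return expert_half + env_half


rng = random.Random(0)
expert = [("expert", i) for i in range(100)]  # stand-ins for demonstration data
env = [("env", i) for i in range(100)]        # stand-ins for collected rollouts
batch = sample_sqil_batch(expert, env, 32, rng)
```

Any reward the environment itself emits is ignored in this sketch; only the source of a transition determines its reward.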

Note

This implementation is based on the DQN implementation in Stable Baselines 3, which does not implement soft Q-learning and therefore does not support continuous actions. As a result, this implementation supports only discrete actions, and the “soft” in the name could be misleading.
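To make the distinction in the note concrete, the sketch below contrasts a DQN-style bootstrap target, which uses the maximum next-state Q-value, with a soft Q-learning target, which uses a log-sum-exp over next-state Q-values. It assumes a temperature of 1 for the soft variant, and both functions are hypothetical helpers written for illustration; neither is taken from Stable Baselines 3 or from this library.

```python
import math


def hard_target(reward, next_q_values, gamma=0.99, done=False):
    """DQN-style target: bootstrap from the maximum next-state Q-value."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def soft_target(reward, next_q_values, gamma=0.99, done=False):
    """Soft target: bootstrap from log-sum-exp of next-state Q-values.

    Assumes a temperature of 1. Subtracting the maximum before
    exponentiating keeps the computation numerically stable.
    """
    if done:
        return reward
    m = max(next_q_values)
    soft_value = m + math.log(sum(math.exp(q - m) for q in next_q_values))
    return reward + gamma * soft_value


next_q = [1.0, 2.0]
hard = hard_target(1.0, next_q)  # expert-style transition with reward 1
soft = soft_target(1.0, next_q)
```

The soft target is never smaller than the hard one, because log-sum-exp is an upper bound on the maximum.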

Example

Detailed example notebook: Train an Agent using Soft Q Imitation Learning

import datasets
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import sqil
from imitation.data import huggingface_utils

# Download some expert trajectories from the HuggingFace Datasets Hub.
dataset = datasets.load_dataset("HumanCompatibleAI/ppo-CartPole-v1")
rollouts = huggingface_utils.TrajectoryDatasetSequence(dataset["train"])

sqil_trainer = sqil.SQIL(
    venv=DummyVecEnv([lambda: gym.make("CartPole-v1")]),
    demonstrations=rollouts,
    policy="MlpPolicy",
)
# Hint: set to 1_000_000 to match the expert performance.
sqil_trainer.train(total_timesteps=1_000)
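# evaluate_policy returns the mean and standard deviation of the episode reward.
mean_reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, n_eval_episodes=10)
print("Reward:", mean_reward)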

API

class imitation.algorithms.sqil.SQIL(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]

Bases: DemonstrationAlgorithm[Transitions]

Soft Q Imitation Learning (SQIL).

Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards.

__init__(*, venv, demonstrations, policy, custom_logger=None, rl_algo_class=<class 'stable_baselines3.dqn.dqn.DQN'>, rl_kwargs=None)[source]

Builds SQIL.

Parameters
  • venv (VecEnv) – The vectorized environment to train on.

  • demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations to use for training.

  • policy (Union[str, Type[BasePolicy]]) – The policy model to use (SB3).

  • custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.

  • rl_algo_class (Type[OffPolicyAlgorithm]) – Off-policy RL algorithm to use.

  • rl_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the RL algorithm constructor.

Raises

ValueError – if rl_kwargs includes a key replay_buffer_class or replay_buffer_kwargs.
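The restriction on replay_buffer_class and replay_buffer_kwargs can be sketched as a simple check on the keyword-argument dictionary. `validate_rl_kwargs` and `RESERVED_KEYS` are hypothetical names used only for this illustration, and the error message is invented; the keys are presumably reserved because SQIL manages the replay buffer itself. The `learning_rate` and `batch_size` entries are ordinary DQN constructor arguments in Stable Baselines 3.

```python
from typing import Any, Dict, Optional

# Keys named in the Raises entry; hypothetical constant for this sketch.
RESERVED_KEYS = ("replay_buffer_class", "replay_buffer_kwargs")


def validate_rl_kwargs(rl_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of rl_kwargs, rejecting the reserved replay buffer keys."""
    kwargs = dict(rl_kwargs or {})
    for key in RESERVED_KEYS:
        if key in kwargs:
            raise ValueError(f"{key!r} must not be set in rl_kwargs")
    return kwargs


# Ordinary hyperparameters pass through unchanged.
ok = validate_rl_kwargs({"learning_rate": 1e-3, "batch_size": 64})

# A reserved key is rejected.
try:
    validate_rl_kwargs({"replay_buffer_kwargs": {}})
    rejected = False
except ValueError:
    rejected = True
```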

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.

expert_buffer: ReplayBuffer

property policy: BasePolicy

Returns a policy imitating the demonstration data.

Return type

BasePolicy

set_demonstrations(demonstrations)[source]

Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.

Parameters

demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.

Return type

None

train(*, total_timesteps, tb_log_name='SQIL', **kwargs)[source]