Maximum Causal Entropy Inverse Reinforcement Learning (MCE IRL)
Implements the algorithm from Modeling Interaction via the Principle of Maximum Causal Entropy (Ziebart et al., 2010).
Example
Detailed example notebook: Learn a Reward Function using Maximum Causal Entropy Inverse Reinforcement Learning
from functools import partial

import numpy as np
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms.mce_irl import (
    MCEIRL,
    mce_occupancy_measures,
    mce_partition_fh,
)
from imitation.data import rollout
from imitation.rewards import reward_nets

rng = np.random.default_rng(0)
env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())

# This needs to be a vectorized environment because `generate_trajectories` expects one.
state_venv = DummyVecEnv([state_env_creator] * 4)

# Compute the optimal policy and its state occupancy measure analytically.
_, _, pi = mce_partition_fh(env_single)
_, om = mce_occupancy_measures(env_single, pi=pi)
reward_net = reward_nets.BasicRewardNet(
    env_single.observation_space,
    env_single.action_space,
    hid_sizes=[256],
    use_action=False,
    use_done=False,
    use_next_state=False,
)
# Train MCE IRL on the analytically computed occupancy measure.
mce_irl = MCEIRL(
    om,
    env_single,
    reward_net,
    log_interval=250,
    optimizer_kwargs={"lr": 0.01},
    rng=rng,
)
occ_measure = mce_irl.train()
imitation_trajs = rollout.generate_trajectories(
    policy=mce_irl.policy,
    venv=state_venv,
    sample_until=rollout.make_min_timesteps(5000),
    rng=rng,
)
print("Imitation stats: ", rollout.rollout_stats(imitation_trajs))
API
- class imitation.algorithms.mce_irl.MCEIRL(demonstrations, env, reward_net, rng, optimizer_cls=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, discount=1.0, linf_eps=0.001, grad_l2_eps=0.0001, log_interval=100, *, custom_logger=None)[source]
Bases: DemonstrationAlgorithm[TransitionsMinimal]
Tabular MCE IRL.
Reward is a function of observations, but policy is a function of states.
The “observations” effectively exist just to let MCE IRL learn a reward in a reasonable feature space, giving a helpful inductive bias, e.g. that similar states have similar reward.
Since we are performing planning to compute the policy, there is no need for function approximation in the policy.
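Concretely, once training has finished you can evaluate the learned reward over every state's observation vector. A minimal sketch, assuming the variables from the example above and that the tabular environment exposes an observation_matrix attribute:

# Evaluate the learned reward for each state. The dummy action/done arrays are
# effectively ignored because the net was built with
# use_action=use_next_state=use_done=False.
obs = env_single.observation_matrix  # shape (n_states, obs_dim)
dummy_acts = np.zeros(len(obs), dtype=int)
dummy_dones = np.zeros(len(obs), dtype=bool)
state_rewards = reward_net.predict(obs, dummy_acts, obs, dummy_dones)
print("Learned reward per state:", state_rewards)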
- __init__(demonstrations, env, reward_net, rng, optimizer_cls=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, discount=1.0, linf_eps=0.001, grad_l2_eps=0.0001, log_interval=100, *, custom_logger=None)[source]
Creates MCE IRL.
- Parameters
  - demonstrations (Union[ndarray, Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations from an expert (optional). Can be a sequence of trajectories or transitions, an iterable over mappings that represent a batch of transitions, or a state occupancy measure. The demonstrations must have one-hot coded observations unless demonstrations is a state-occupancy measure.
  - env (TabularModelPOMDP) – a tabular MDP.
  - reward_net (RewardNet) – a neural network that computes rewards for the supplied observations.
  - rng (Generator) – random state used for sampling from policy.
  - optimizer_cls (Type[Optimizer]) – optimizer to use for supervised training.
  - optimizer_kwargs (Optional[Mapping[str, Any]]) – keyword arguments for optimizer construction.
  - discount (float) – the discount factor to use when computing occupancy measure. If not 1.0 (undiscounted), then demonstrations must either be a (discounted) state-occupancy measure, or trajectories. Transitions are not allowed, as we cannot discount them appropriately without knowing the timestep they were drawn from. See the sketch after this list for the discounted setting.
  - linf_eps (float) – optimisation terminates if the $\ell_\infty$ distance between the demonstrator’s state occupancy measure and the state occupancy measure for the current reward falls below this value.
  - grad_l2_eps (float) – optimisation also terminates if the $\ell_2$ norm of the MCE IRL gradient falls below this value.
  - log_interval (Optional[int]) – how often to log current loss stats (using logging). None to disable.
  - custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
- Raises
  ValueError – if the env horizon is not finite (or not an integer).
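For the discounted setting described under discount above, a minimal sketch, reusing the environment, reward network, and rng from the example; it assumes mce_partition_fh and mce_occupancy_measures accept a discount keyword argument:

# Sketch: discounted MCE IRL. With discount != 1.0, demonstrations must be a
# discounted occupancy measure (as here) or trajectories, never bare transitions.
_, _, pi_disc = mce_partition_fh(env_single, discount=0.99)
_, om_disc = mce_occupancy_measures(env_single, pi=pi_disc, discount=0.99)
mce_irl_disc = MCEIRL(om_disc, env_single, reward_net, rng=rng, discount=0.99)
mce_irl_disc.train(max_iter=100)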
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
- demo_state_om: Optional[ndarray]
- property logger: HierarchicalLogger
- Return type
  HierarchicalLogger
- property policy: BasePolicy
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- set_demonstrations(demonstrations)[source]
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
  demonstrations (Union[ndarray, Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
  None
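A small usage sketch, reusing names from the example above (state_dim and action_dim are assumed attributes of the tabular environment): swap in the occupancy measure of a uniformly random policy and retrain.

# Sketch: replace the demonstrations on an existing learner. A raw ndarray is
# interpreted as a state occupancy measure.
uniform_pi = np.full(
    (env_single.horizon, env_single.state_dim, env_single.action_dim),
    1.0 / env_single.action_dim,
)
_, om_uniform = mce_occupancy_measures(env_single, pi=uniform_pi)
mce_irl.set_demonstrations(om_uniform)
mce_irl.train(max_iter=100)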
- train(max_iter=1000)[source]
Runs MCE IRL.
- Parameters
  max_iter (int) – The maximum number of iterations to train for. May terminate earlier if the self.linf_eps or self.grad_l2_eps thresholds are reached.
- Return type
  ndarray
- Returns
State occupancy measure for the final reward function. self.reward_net and self.optimizer will be updated in-place during optimisation.
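A usage sketch: since train returns the state occupancy measure for the final reward, convergence can be checked directly against the stored demonstration occupancy (demo_state_om, documented above), continuing the example:

final_om = mce_irl.train(max_iter=5000)
# Small if optimisation stopped via the linf_eps threshold rather than max_iter.
gap = np.max(np.abs(final_om - mce_irl.demo_state_om))
print(f"L-inf gap to demonstration occupancy: {gap:.4f}")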
- class imitation.algorithms.base.DemonstrationAlgorithm(*, demonstrations, custom_logger=None, allow_variable_horizon=False)[source]
Bases: BaseImitationAlgorithm, Generic[TransitionKind]
An algorithm that learns from demonstrations: BC, IRL, etc.
- __init__(*, demonstrations, custom_logger=None, allow_variable_horizon=False)[source]
Creates an algorithm that learns from demonstrations.
- Parameters
  - demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal, None]) – Demonstrations from an expert (optional). Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc).
  - custom_logger (Optional[HierarchicalLogger]) – Where to log to; if None (default), creates a new logger.
  - allow_variable_horizon (bool) – If False (default), algorithm will raise an exception if it detects trajectories of different length during training. If True, overrides this safety check. WARNING: variable horizon episodes leak information about the reward via termination condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html before overriding this.
- allow_variable_horizon: bool
If True, allow variable horizon trajectories; otherwise error if detected.
- abstract property policy: BasePolicy
Returns a policy imitating the demonstration data.
- Return type
BasePolicy
- abstract set_demonstrations(demonstrations)[source]
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Parameters
  demonstrations (Union[Iterable[Trajectory], Iterable[TransitionMapping], TransitionsMinimal]) – Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
- Return type
  None
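A minimal subclass sketch showing the abstract surface (policy and set_demonstrations); everything beyond the two required overrides is hypothetical:

from imitation.algorithms import base
from imitation.data import types

class ConstantPolicyLearner(base.DemonstrationAlgorithm[types.Transitions]):
    """Hypothetical learner: stores demonstrations, exposes a fixed policy."""

    def __init__(self, *, demonstrations, policy, custom_logger=None):
        self._policy = policy
        super().__init__(demonstrations=demonstrations, custom_logger=custom_logger)

    def set_demonstrations(self, demonstrations) -> None:
        # A real implementation would convert/validate the demonstration data here.
        self._demos = demonstrations

    @property
    def policy(self):
        # Must return a stable_baselines3 BasePolicy imitating the demonstrations.
        return self._policy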