imitation.testing.reward_improvement

Utility functions used to check whether rewards improved with respect to previous rewards.

Functions

is_significant_reward_improvement(...[, p_value])

Checks if the new rewards are really better than the old rewards.

mean_reward_improved_by(old_rews, new_rews, ...)

Checks if mean rewards improved with respect to old rewards by a certain amount.

imitation.testing.reward_improvement.is_significant_reward_improvement(old_rewards, new_rewards, p_value=0.05)

Checks if the new rewards are really better than the old rewards.

Ensures that this is not just due to lucky sampling by a permutation test.

Parameters
  • old_rewards (Iterable[float]) – Iterable of “old” trajectory rewards (e.g. before training).

  • new_rewards (Iterable[float]) – Iterable of “new” trajectory rewards (e.g. after training).

  • p_value (float) – The maximum tolerated probability that the old rewards are just as good as the new rewards, i.e. the significance level of the test.

Return type

bool

Returns

True if the new rewards are most probably better than the old rewards, i.e. if the estimated probability that the old rewards are just as good as the new rewards is below p_value.

>>> is_significant_reward_improvement((5, 6, 7, 4, 4), (7, 5, 9, 9, 12))
True
>>> is_significant_reward_improvement((5, 6, 7, 4, 4), (7, 5, 9, 7, 4))
False
>>> is_significant_reward_improvement((5, 6, 7, 4, 4), (7, 5, 9, 7, 4), p_value=0.3)
True
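
To illustrate the idea behind the permutation test mentioned above, here is a minimal standard-library sketch. The helper name permutation_p_value is hypothetical and not part of imitation, and the library's own implementation may differ (for instance by sampling random permutations instead of enumerating all of them). The sketch pools both reward samples, enumerates every way of relabelling them as "old" and "new", and reports the fraction of relabellings whose difference in means is at least as large as the observed one.

```python
from itertools import combinations
from statistics import mean
from typing import Sequence


def permutation_p_value(old: Sequence[float], new: Sequence[float]) -> float:
    """Exact one-sided permutation test on mean(new) - mean(old).

    Hypothetical helper for illustration only; not the library's code.
    Enumerates all splits, so it is only practical for small samples.
    """
    pooled = list(old) + list(new)
    n_new = len(new)
    n_old = len(pooled) - n_new
    total = sum(pooled)
    observed = mean(new) - mean(old)

    at_least_as_extreme = 0
    n_splits = 0
    # Every choice of n_new pooled entries is one possible "new" group.
    for idx in combinations(range(len(pooled)), n_new):
        new_sum = sum(pooled[i] for i in idx)
        stat = new_sum / n_new - (total - new_sum) / n_old
        # Small tolerance guards against floating-point rounding on ties.
        if stat >= observed - 1e-12:
            at_least_as_extreme += 1
        n_splits += 1
    return at_least_as_extreme / n_splits


p_clear = permutation_p_value((5, 6, 7, 4, 4), (7, 5, 9, 9, 12))
p_weak = permutation_p_value((5, 6, 7, 4, 4), (7, 5, 9, 7, 4))
print(p_clear < 0.05, p_weak < 0.05, p_weak < 0.3)
```

Under these assumptions, the first comparison falls below the default threshold of 0.05, while the second only passes with the looser threshold of 0.3, which matches the doctest results above.
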
imitation.testing.reward_improvement.mean_reward_improved_by(old_rews, new_rews, min_improvement)

Checks if mean rewards improved with respect to old rewards by a certain amount.

Parameters
  • old_rews (Iterable[float]) – Iterable of “old” trajectory rewards (e.g. before training).

  • new_rews (Iterable[float]) – Iterable of “new” trajectory rewards (e.g. after training).

  • min_improvement (float) – The minimum amount of improvement that we expect.

Returns

True if the mean of the new rewards exceeds the mean of the old rewards by at least min_improvement.

>>> mean_reward_improved_by([5, 8, 7], [8, 9, 10], 2)
True
>>> mean_reward_improved_by([5, 8, 7], [8, 9, 10], 5)
False
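
The check described above amounts to comparing a difference of means against a threshold. A minimal sketch follows, using a hypothetical helper name (mean_improved_by_sketch) that is not part of imitation. Treating an improvement exactly equal to min_improvement as sufficient is an assumption of this sketch.

```python
from statistics import mean
from typing import Iterable


def mean_improved_by_sketch(
    old_rews: Iterable[float],
    new_rews: Iterable[float],
    min_improvement: float,
) -> bool:
    """Hypothetical re-statement of the check, for illustration only."""
    improvement = mean(new_rews) - mean(old_rews)
    # Assumption: an improvement equal to the threshold counts as a pass.
    return improvement >= min_improvement


# Means are 20/3 (about 6.67) and 9, so the improvement is about 2.33.
print(mean_improved_by_sketch([5, 8, 7], [8, 9, 10], 2))
print(mean_improved_by_sketch([5, 8, 7], [8, 9, 10], 5))
```

Unlike the permutation test, this check ignores sample variance, so it is best suited to cases where a fixed margin of improvement is the actual requirement.
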