Source code for imitation.rewards.reward_function
"""Reward function protocol shared by reward-related code."""
import abc
from typing import Protocol
import numpy as np
class RewardFn(Protocol):
    """Protocol (structural interface) for batched reward functions.

    Because this subclasses ``typing.Protocol``, static type checkers accept
    any callable whose ``__call__`` matches the signature below, without
    requiring explicit inheritance. Classes that do inherit from ``RewardFn``
    explicitly must implement ``__call__``. ``__call__`` is marked with
    ``abc.abstractmethod``, so instantiating such a subclass without it raises
    ``TypeError``.
    """

    @abc.abstractmethod
    def __call__(
        self,
        state: np.ndarray,
        action: np.ndarray,
        next_state: np.ndarray,
        done: np.ndarray,
    ) -> np.ndarray:
        """Compute rewards for a batch of transitions.

        Args:
            state: Current states of shape `(batch_size,) + state_shape`.
            action: Actions of shape `(batch_size,) + action_shape`.
            next_state: Successor states of shape `(batch_size,) + state_shape`.
            done: End-of-episode (terminal state) indicator of shape `(batch_size,)`.

        Returns:
            Computed rewards of shape `(batch_size,)`.
        """  # noqa: DAR202