# Source code for imitation.policies.interactive

"""Interactive policies that query the user for actions."""

import abc
import collections
from typing import Dict, Optional, Union

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from shimmy import atari_env
from stable_baselines3.common import vec_env

import imitation.policies.base as base_policies
from imitation.util import util


class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
    """Abstract class for interactive policies with discrete actions.

    For each query, the observation is rendered and then the action is provided
    as a keyboard input.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        action_keys_names: collections.OrderedDict,
        clear_screen_on_query: bool = True,
    ):
        """Builds DiscreteInteractivePolicy.

        Args:
            observation_space: Observation space.
            action_space: Action space. Must be a `gym.spaces.Discrete` space
                with exactly `len(action_keys_names)` actions.
            action_keys_names: `OrderedDict` containing pairs (key, name) for every
                action, where key will be used in the console interface, and name
                is a semantic action name. The index of the pair in the dictionary
                will be used as the discrete, integer action.
            clear_screen_on_query: If `True`, console will be cleared on every query.

        Raises:
            AssertionError: If `action_space` is not discrete, if action names
                are not unique, or if the number of (key, name) pairs does not
                match the number of actions.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
        )
        assert isinstance(
            action_space,
            gym.spaces.Discrete,
        ), "DiscreteInteractivePolicy requires a Discrete action space."
        # Keys are unique by construction (dict keys); this also checks that
        # names are unique and that every action gets exactly one binding.
        assert (
            len(action_keys_names)
            == len(set(action_keys_names.values()))
            == action_space.n
        ), "action_keys_names must give one unique key and name to every action."

        self.action_keys_names = action_keys_names
        # Position in the ordered mapping is the integer action id.
        self.action_key_to_index = {
            k: i for i, k in enumerate(action_keys_names.keys())
        }
        self.clear_screen_on_query = clear_screen_on_query

    def _choose_action(
        self,
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
    ) -> np.ndarray:
        """Renders `obs`, asks the user for a key and returns the chosen action.

        Args:
            obs: A single observation. Dictionary observations are rejected.

        Returns:
            A one-element array holding the integer index of the chosen action.

        Raises:
            ValueError: If `obs` is a dictionary.
        """
        # Validate before any side effect, so a bad call does not wipe the console.
        if isinstance(obs, dict):
            raise ValueError("Dictionary observations are not supported here")
        if self.clear_screen_on_query:
            util.clear_screen()

        context = self._render(obs)
        try:
            key = self._get_input_key()
        finally:
            # Always release render resources (e.g. an open figure), even when
            # input is interrupted, so aborted queries do not leak windows.
            self._clean_up(context)
        return np.array([self.action_key_to_index[key]])

    def _get_input_key(self) -> str:
        """Obtains input key for action selection.

        Prompts repeatedly until the user enters one of the keys in
        `self.action_keys_names`.

        Returns:
            The validated key string.
        """
        print(
            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
            ", ".join([f"{n}:{k}" for k, n in self.action_keys_names.items()]),
        )

        key = input("Your choice (enter key):")
        while key not in self.action_keys_names:
            key = input("Invalid key, please try again! Your choice (enter key):")

        return key

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> Optional[object]:
        """Renders an observation, optionally returns a context for later cleanup."""

    def _clean_up(self, context: object) -> None:
        """Cleans up after the input has been captured, e.g. stops showing the image.

        The default implementation does nothing.

        Args:
            context: Whatever `_render` returned for this query.
        """
class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """DiscreteInteractivePolicy that renders image observations."""

    def _render(self, obs: np.ndarray) -> plt.Figure:
        """Displays the (possibly preprocessed) observation in a fresh figure.

        Args:
            obs: Observation to display.

        Returns:
            The created figure, which `_clean_up` later closes.
        """
        picture = self._prepare_obs_image(obs)
        figure, axes = plt.subplots()
        # `cmap` only applies to single-channel images; RGB arrays ignore it.
        axes.imshow(picture, cmap="gray", vmin=0, vmax=255)
        axes.axis("off")
        figure.show()
        return figure

    def _clean_up(self, context: plt.Figure) -> None:
        """Closes the figure opened by `_render`."""
        plt.close(context)

    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
        """Hook for turning an observation into a displayable image.

        Subclasses may override this; the default returns `obs` unchanged.
        """
        return obs
# Keyboard binding for each of the 18 Atari action names (name -> key).
# NOOP and FIRE use "1" and "2". Plain moves use the ring of keys around "s"
# (q w e / a . d / z x c). Moves combined with FIRE use the ring around "g"
# (r t y / f . h / v b n). Both rings assume a QWERTY layout.
ATARI_ACTION_NAMES_TO_KEYS = dict(
    NOOP="1",
    FIRE="2",
    UP="w",
    RIGHT="d",
    LEFT="a",
    DOWN="x",
    UPRIGHT="e",
    UPLEFT="q",
    DOWNRIGHT="c",
    DOWNLEFT="z",
    UPFIRE="t",
    RIGHTFIRE="h",
    LEFTFIRE="f",
    DOWNFIRE="b",
    UPRIGHTFIRE="y",
    UPLEFTFIRE="r",
    DOWNRIGHTFIRE="n",
    DOWNLEFTFIRE="v",
)
class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
    """Interactive policy for Atari environments."""

    def __init__(
        self,
        env: Union[atari_env.AtariEnv, gym.Wrapper, vec_env.VecEnv],
        *args,
        **kwargs,
    ):
        """Builds AtariInteractivePolicy.

        Args:
            env: An Atari environment, a gymnasium wrapper around one, or a
                `VecEnv` whose first sub-environment exposes
                `get_action_meanings`. Observation and action spaces are read
                from `env` itself, so any wrapper-level preprocessing is
                reflected in them.
            *args: Forwarded to `ImageObsDiscreteInteractivePolicy.__init__`
                after the spaces and the key mapping.
            **kwargs: Forwarded to `ImageObsDiscreteInteractivePolicy.__init__`.

        Raises:
            KeyError: If the environment reports an action name that has no
                entry in `ATARI_ACTION_NAMES_TO_KEYS`.
        """
        if isinstance(env, vec_env.VecEnv):
            action_names = env.env_method("get_action_meanings", indices=[0])[0]
        else:
            # `unwrapped` is the env itself for a bare AtariEnv and the
            # innermost env for wrappers, which is where the method lives.
            action_names = env.unwrapped.get_action_meanings()

        unmapped = [
            name for name in action_names if name not in ATARI_ACTION_NAMES_TO_KEYS
        ]
        if unmapped:
            raise KeyError(f"No keyboard binding for Atari action(s): {unmapped}")

        # Preserve the environment's action order: position == action index.
        action_keys_names = collections.OrderedDict(
            (ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names
        )
        super().__init__(
            env.observation_space,
            env.action_space,
            action_keys_names,
            *args,
            **kwargs,
        )