"""Miscellaneous utility methods."""
import datetime
import functools
import itertools
import os
import pathlib
import uuid
import warnings
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common import monitor, policies
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from imitation.data.types import AnyPath
def save_policy(policy: policies.BasePolicy, policy_path: AnyPath) -> None:
"""Save policy to a path.
Args:
policy: policy to save.
policy_path: path to save policy to.
"""
th.save(policy, parse_path(policy_path))
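
# Illustrative usage sketch, not part of the module: persisting a trained policy.
# The `policy` and `tmp_dir` arguments are placeholders supplied by the caller;
# `save_policy` accepts str, bytes, or os.PathLike paths.
def _example_save_policy(policy: policies.BasePolicy, tmp_dir: pathlib.Path) -> None:
    save_policy(policy, tmp_dir / "policy.pt")  # written via th.save
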
def oric(x: np.ndarray) -> np.ndarray:
"""Optimal rounding under integer constraints.
Given a vector of real numbers such that the sum is an integer, returns a vector
of rounded integers that preserves the sum and which minimizes the Lp-norm of the
difference between the rounded and original vectors for all p >= 1. Algorithm from
https://arxiv.org/abs/1501.00014. Runs in O(n log n) time.
Args:
x: A 1D vector of real numbers that sum to an integer.
Returns:
A 1D vector of rounded integers, preserving the sum.
"""
rounded = np.floor(x)
shortfall = x - rounded
# The total shortfall should be *exactly* an integer, but we
# round to account for numerical error.
total_shortfall = np.round(shortfall.sum()).astype(int)
indices = np.argsort(-shortfall)
# Apportion the total shortfall to the elements in order of
# decreasing shortfall.
rounded[indices[:total_shortfall]] += 1
return rounded.astype(int)
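
# Illustrative sketch (the helper name below is hypothetical, not part of the
# module): rounding a fractional allocation whose entries sum to an integer,
# e.g. apportioning 7 rollouts across 3 workers in proportion to real weights.
def _example_oric() -> None:
    x = np.array([1.3, 2.4, 3.3])  # entries sum to exactly 7
    rounded = oric(x)
    assert rounded.sum() == 7  # the integer sum is preserved
    assert np.array_equal(rounded, np.array([1, 3, 3]))  # largest shortfall rounded up
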
def make_unique_timestamp() -> str:
"""Timestamp, with random uuid added to avoid collisions."""
ISO_TIMESTAMP = "%Y%m%d_%H%M%S"
timestamp = datetime.datetime.now().strftime(ISO_TIMESTAMP)
random_uuid = uuid.uuid4().hex[:6]
return f"{timestamp}_{random_uuid}"
def make_vec_env(
env_name: str,
*,
rng: np.random.Generator,
n_envs: int = 8,
parallel: bool = False,
log_dir: Optional[str] = None,
max_episode_steps: Optional[int] = None,
post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
env_make_kwargs: Optional[Mapping[str, Any]] = None,
) -> VecEnv:
"""Makes a vectorized environment.
Args:
env_name: The Env's string id in Gym.
rng: The random state to use to seed the environment.
n_envs: The number of duplicate environments.
parallel: If True, uses SubprocVecEnv; otherwise, DummyVecEnv.
log_dir: If specified, saves Monitor output to this directory.
max_episode_steps: If specified, wraps each env in a TimeLimit wrapper
with this episode length. If not specified and `max_episode_steps`
exists for this `env_name` in the Gym registry, uses the registry
`max_episode_steps` for every TimeLimit wrapper (this automatic
wrapper is the default behavior when calling `gym.make`). Otherwise
the environments are passed into the VecEnv unwrapped.
post_wrappers: If specified, iteratively wraps each environment with each
of the wrappers specified in the sequence. Each wrapper should be a Callable
accepting two arguments, the Env to be wrapped and the environment index,
and returning the wrapped Env.
env_make_kwargs: The kwargs passed to `spec.make`.
Returns:
A VecEnv initialized with `n_envs` environments.
"""
# Resolve the spec outside of the subprocess first, so that it is available to
# subprocesses running `make_env` via automatic pickling.
# Just to ensure packages are imported and spec is properly resolved
tmp_env = gym.make(env_name)
tmp_env.close()
spec = tmp_env.spec
env_make_kwargs = env_make_kwargs or {}
def make_env(i: int, this_seed: int) -> gym.Env:
# Previously, we directly called `gym.make(env_name)`, but running
# `imitation.scripts.train_adversarial` within `imitation.scripts.parallel`
# created a weird interaction between Gym and Ray -- `gym.make` would fail
# inside this function for any of our custom environment unless those
# environments were also `gym.register()`ed inside `make_env`. Even
# registering the custom environment in the scope of `make_vec_env` didn't
# work. For more discussion and hypotheses on this issue see PR #160:
# https://github.com/HumanCompatibleAI/imitation/pull/160.
assert env_make_kwargs is not None # Note: to satisfy mypy
assert spec is not None # Note: to satisfy mypy
env = gym.make(spec, max_episode_steps=max_episode_steps, **env_make_kwargs)
# Seed each environment with a different, non-sequential seed for diversity
# (even if caller is passing us sequentially-assigned base seeds). int() is
# necessary to work around gym bug where it chokes on numpy int64s.
env.reset(seed=int(this_seed))
# NOTE: we do it here rather than on the final VecEnv, because
# that would set the same seed for all the environments.
# Use Monitor to record statistics needed for Baselines algorithms logging
# Optionally, save to disk
log_path = None
if log_dir is not None:
log_subdir = os.path.join(log_dir, "monitor")
os.makedirs(log_subdir, exist_ok=True)
log_path = os.path.join(log_subdir, f"mon{i:03d}")
env = monitor.Monitor(env, log_path)
if post_wrappers:
for wrapper in post_wrappers:
env = wrapper(env, i)
return env
env_seeds = make_seeds(rng, n_envs)
env_fns: List[Callable[[], gym.Env]] = [
functools.partial(make_env, i, s) for i, s in enumerate(env_seeds)
]
if parallel:
# See GH hill-a/stable-baselines issue #217
return SubprocVecEnv(env_fns, start_method="forkserver")
else:
return DummyVecEnv(env_fns)
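
# Illustrative usage sketch, not part of the module: building two in-process
# CartPole-v1 environments (standard Gymnasium registration assumed), with
# Monitor logs written under a hypothetical "output/" directory. Setting
# parallel=True would use SubprocVecEnv instead of DummyVecEnv.
def _example_make_vec_env() -> VecEnv:
    rng = np.random.default_rng(seed=0)
    return make_vec_env(
        "CartPole-v1",
        rng=rng,
        n_envs=2,
        parallel=False,
        log_dir="output/make_vec_env_example",
        max_episode_steps=200,
    )
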
@overload
def make_seeds(
rng: np.random.Generator,
) -> int:
...
@overload
def make_seeds(rng: np.random.Generator, n: int) -> List[int]:
...
def make_seeds(
rng: np.random.Generator,
n: Optional[int] = None,
) -> Union[Sequence[int], int]:
"""Generate n random seeds from a random state.
Args:
rng: The random state to use to generate seeds.
n: The number of seeds to generate.
Returns:
A single random seed if `n` is None; otherwise, a list of `n` random seeds.
"""
seeds_arr = rng.integers(0, (1 << 31) - 1, (n if n is not None else 1,))
seeds: List[int] = seeds_arr.tolist()
if n is None:
return seeds[0]
else:
return seeds
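
# Illustrative sketch (hypothetical helper, not part of the module): the
# overloads above mean a bare call returns a single int seed, while passing
# `n` returns a list of `n` seeds drawn from the same generator.
def _example_make_seeds() -> None:
    rng = np.random.default_rng(seed=42)
    single = make_seeds(rng)
    several = make_seeds(rng, 3)
    assert isinstance(single, int)
    assert len(several) == 3 and all(isinstance(s, int) for s in several)
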
def docstring_parameter(*args, **kwargs):
"""Treats the docstring as a format string, substituting in the arguments."""
def helper(obj):
obj.__doc__ = obj.__doc__.format(*args, **kwargs)
return obj
return helper
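
# Illustrative sketch (hypothetical helper, not part of the module): the
# decorator treats the docstring as a format string, so placeholders are
# substituted at definition time.
def _example_docstring_parameter() -> None:
    @docstring_parameter("L2")
    def norm_fn() -> None:
        """Computes the {0} norm (illustration only)."""
    assert norm_fn.__doc__ == "Computes the L2 norm (illustration only)."
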
T = TypeVar("T")
def endless_iter(iterable: Iterable[T]) -> Iterator[T]:
"""Generator that endlessly yields elements from `iterable`.
>>> x = range(2)
>>> it = endless_iter(x)
>>> next(it)
0
>>> next(it)
1
>>> next(it)
0
Args:
iterable: The non-iterator iterable object to endlessly iterate over.
Returns:
An iterator that repeats the elements in `iterable` forever.
Raises:
ValueError: if iterable is an iterator -- that will be exhausted, so
cannot be iterated over endlessly.
"""
if iter(iterable) == iterable:
raise ValueError("endless_iter needs a non-iterator Iterable.")
_, iterable = get_first_iter_element(iterable)
return itertools.chain.from_iterable(itertools.repeat(iterable))
def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor:
"""Converts a NumPy array to a PyTorch tensor.
The data is copied when the array is non-writable. If you simply use
`th.as_tensor` on such an array, an ugly warning is logged and writing to the
resulting tensor is undefined behavior.
Args:
array: The array to convert to a PyTorch tensor.
kwargs: Additional keyword arguments to pass to `th.as_tensor`.
Returns:
A PyTorch tensor with the same content as `array`.
"""
if isinstance(array, th.Tensor):
return array
if not array.flags.writeable:
array = array.copy()
return th.as_tensor(array, **kwargs)
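
# Illustrative sketch (hypothetical helper, not part of the module): converting
# a read-only array, as sometimes returned by dataset loaders, without the
# warning and undefined behavior that direct `th.as_tensor` use would risk.
def _example_safe_to_tensor() -> None:
    arr = np.arange(4.0)
    arr.setflags(write=False)  # simulate a non-writable array
    tensor = safe_to_tensor(arr, dtype=th.float32)
    tensor += 1.0  # safe to write: the read-only array was copied first
    assert th.allclose(tensor, th.tensor([1.0, 2.0, 3.0, 4.0]))
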
@overload
def safe_to_numpy(obj: Union[np.ndarray, th.Tensor], warn: bool = False) -> np.ndarray:
...
@overload
def safe_to_numpy(obj: None, warn: bool = False) -> None:
...
def safe_to_numpy(
obj: Optional[Union[np.ndarray, th.Tensor]],
warn: bool = False,
) -> Optional[np.ndarray]:
"""Convert torch tensor to numpy.
If the object is already a numpy array, return it as is.
If the object is none, returns none.
Args:
obj: torch tensor object to convert to numpy array
warn: if True, warn if the object is not already a numpy array. Useful for
warning the user of a potential performance hit if a torch tensor is
not the expected input type.
Returns:
Object converted to numpy array
"""
if obj is None:
# We ignore the type due to https://github.com/google/pytype/issues/445
return None # pytype: disable=bad-return-type
elif isinstance(obj, np.ndarray):
return obj
else:
if warn:
warnings.warn(
"Converted tensor to numpy array, might affect performance. "
"Make sure this is the intended behavior.",
)
return obj.detach().cpu().numpy()
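
# Illustrative sketch (hypothetical helper, not part of the module): callers can
# normalize optional tensor-or-array inputs in one place, opting into a warning
# when an unexpected tensor-to-numpy conversion happens.
def _example_safe_to_numpy() -> None:
    assert safe_to_numpy(None) is None
    arr = np.zeros(3)
    assert safe_to_numpy(arr) is arr  # already numpy: returned unchanged
    converted = safe_to_numpy(th.zeros(3), warn=True)  # emits a warning
    assert isinstance(converted, np.ndarray)
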
def tensor_iter_norm(
tensor_iter: Iterable[th.Tensor],
ord: Union[int, float] = 2, # noqa: A002
) -> th.Tensor:
"""Compute the norm of a big vector that is produced one tensor chunk at a time.
Args:
tensor_iter: an iterable that yields tensors.
ord: order of the p-norm (can be any int or float except 0 and NaN).
Returns:
Norm of the concatenated tensors.
Raises:
ValueError: ord is 0 (unsupported).
"""
if ord == 0:
raise ValueError("This function cannot compute p-norms for p=0.")
norms = []
for tensor in tensor_iter:
norms.append(th.norm(tensor.flatten(), p=ord))
norm_tensor = th.as_tensor(norms)
# Norm of the norms is equal to the norm of the concatenated tensor.
# th.norm(norm_tensor) = sum(norm**ord for norm in norm_tensor)**(1/ord)
# = sum(sum(x**ord for x in tensor) for tensor in tensor_iter)**(1/ord)
# = sum(x**ord for x in tensor for tensor in tensor_iter)**(1/ord)
# = th.norm(concatenated tensors)
return th.norm(norm_tensor, p=ord)
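
# Illustrative sketch (hypothetical helper, not part of the module): the
# chunk-wise norm equals the norm of the concatenated vector, so large
# collections of parameters or gradients can be processed one tensor at a time.
def _example_tensor_iter_norm() -> None:
    chunks = [th.tensor([3.0, 4.0]), th.tensor([12.0])]
    chunked = tensor_iter_norm(chunks)  # norm of [5, 12] -> 13
    full = th.norm(th.cat([t.flatten() for t in chunks]))  # norm of [3, 4, 12] -> 13
    assert th.isclose(chunked, full)
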
def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
"""Get first element of an iterable and a new fresh iterable.
The fresh iterable has the first element added back using ``itertools.chain``.
If the iterable is not an iterator, this is equivalent to
``(next(iter(iterable)), iterable)``.
Args:
iterable: The iterable to get the first element of.
Returns:
A tuple containing the first element of the iterable, and a fresh iterable
with all the elements.
Raises:
ValueError: `iterable` is empty -- the first call to it returns no elements.
"""
iterator = iter(iterable)
try:
first_element = next(iterator)
except StopIteration:
raise ValueError(f"iterable {iterable} had no elements to iterate over.")
return_iterable: Iterable[T]
if iterator == iterable:
# `iterable` was an iterator. Getting `first_element` will have removed it
# from `iterator`, so we need to add a fresh iterable with `first_element`
# added back in.
return_iterable = itertools.chain([first_element], iterator)
else:
# `iterable` was not an iterator; we can just return `iterable`.
# `iter(iterable)` will give a fresh iterator containing the first element.
# It's preferable to return `iterable` without modification so that users
# can generate new iterators from it as needed.
return_iterable = iterable
return first_element, return_iterable
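
# Illustrative sketch (hypothetical helper, not part of the module): peeking at
# the first element of a one-shot generator without losing it, e.g. to inspect
# a batch's shape before consuming the stream.
def _example_get_first_iter_element() -> None:
    gen = (i * i for i in range(3))  # an iterator: would normally be consumed
    first, fresh = get_first_iter_element(gen)
    assert first == 0
    assert list(fresh) == [0, 1, 4]  # the first element was chained back in
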
def parse_path(
path: AnyPath,
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> pathlib.Path:
"""Parse a path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object.
Raises:
ValueError: If `allow_relative` is False and the path is not absolute.
ValueError: If `base_directory` is specified and `allow_relative` is
False.
"""
if base_directory is not None and not allow_relative:
raise ValueError(
"If `base_directory` is specified, then `allow_relative` must be True.",
)
parsed_path: pathlib.Path
if isinstance(path, pathlib.Path):
parsed_path = path
elif isinstance(path, str):
parsed_path = pathlib.Path(path)
elif isinstance(path, bytes):
parsed_path = pathlib.Path(path.decode())
else:
parsed_path = pathlib.Path(str(path))
if parsed_path.is_absolute():
return parsed_path
else:
if allow_relative:
base_directory = base_directory or pathlib.Path.cwd()
# relative to current working directory
return base_directory / parsed_path
else:
raise ValueError(f"Path {str(parsed_path)} is not absolute")
def parse_optional_path(
path: Optional[AnyPath],
allow_relative: bool = True,
base_directory: Optional[pathlib.Path] = None,
) -> Optional[pathlib.Path]:
"""Parse an optional path to a `pathlib.Path` object.
All resulting paths are resolved, absolute paths. If `allow_relative` is True,
then relative paths are allowed as input, and are resolved relative to the
current working directory, or relative to `base_directory` if it is
specified.
Args:
path: The path to parse. Can be a string, bytes, or `os.PathLike`.
allow_relative: If True, then relative paths are allowed as input, and
are resolved relative to the current working directory. If False,
an error is raised if the path is not absolute.
base_directory: If specified, then relative paths are resolved relative
to this directory, instead of the current working directory.
Returns:
A `pathlib.Path` object, or None if `path` is None.
"""
if path is None:
return None
else:
return parse_path(path, allow_relative, base_directory)
def split_in_half(x: int) -> Tuple[int, int]:
    """Split an integer in half, rounding the second half up.
This ensures that the two halves sum to the original integer.
Args:
x: The integer to split.
Returns:
A tuple containing the two halves of `x`.
"""
half = x // 2
return half, x - half
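
# Illustrative sketch (hypothetical helper, not part of the module): for odd
# inputs the extra unit goes to the second half, so the parts always sum back
# to the original value.
def _example_split_in_half() -> None:
    assert split_in_half(10) == (5, 5)
    assert split_in_half(7) == (3, 4)
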
def clear_screen() -> None:
"""Clears the console screen."""
if os.name == "nt": # Windows
os.system("cls")
else:
os.system("clear")