# Source code for imitation.util.sacred

"""Helper methods for the `sacred` experimental configuration and logging framework."""

import json
import os
import pathlib
import warnings
from typing import Any, Callable, NamedTuple, Optional, Sequence

import sacred
import sacred.observers
import sacred.run

from imitation.data import types
from imitation.util import util


class SacredDicts(NamedTuple):
    """Parsed JSON dicts from a sacred run directory.

    Each dict field `foo` is loaded from `f"{sacred_dir}/foo.json"`.
    """

    # Directory the JSON files were read from.
    sacred_dir: pathlib.Path
    # Parsed contents of `config.json`.
    config: dict
    # Parsed contents of `run.json`.
    run: dict

    @classmethod
    def load_from_dir(cls, sacred_dir: pathlib.Path) -> "SacredDicts":
        """Load `config.json` and `run.json` from `sacred_dir`.

        Args:
            sacred_dir: Directory containing the sacred JSON files. A `str` or
                other path-like value is also accepted and converted to
                `pathlib.Path`.

        Returns:
            An instance whose `sacred_dir` is `sacred_dir` as a `pathlib.Path`
            and whose `config`/`run` fields hold the parsed JSON contents.

        Raises:
            FileNotFoundError: If either JSON file does not exist.
            json.JSONDecodeError: If either file is not valid JSON.
        """
        sacred_dir = pathlib.Path(sacred_dir)

        def read_json(name: str) -> dict:
            # JSON text is UTF-8 by specification; decode it explicitly rather
            # than with the platform locale encoding (often not UTF-8 on Windows).
            return json.loads((sacred_dir / name).read_text(encoding="utf-8"))

        return cls(
            sacred_dir=sacred_dir,
            config=read_json("config.json"),
            run=read_json("run.json"),
        )
def dir_contains_sacred_jsons(dir_path: pathlib.Path) -> bool:
    """Report whether `dir_path` holds both sacred JSON output files.

    Args:
        dir_path: Directory to inspect.

    Returns:
        True only if `run.json` and `config.json` both exist in `dir_path` as
        regular files (or symlinks to regular files); False otherwise.
    """
    required_files = ("run.json", "config.json")
    return all((dir_path / fname).is_file() for fname in required_files)
def filter_subdirs(
    root_dir: pathlib.Path,
    filter_fn: Callable[[pathlib.Path], bool] = dir_contains_sacred_jsons,
    *,
    nested_ok: bool = False,
) -> Sequence[pathlib.Path]:
    """Walks through a directory tree, returning paths to filtered subdirectories.

    Does not follow symlinks. `root_dir` itself is also passed to `filter_fn`.
    Directories that cannot be listed are silently skipped (the `os.walk`
    default), so a nonexistent `root_dir` yields an empty list.

    Args:
        root_dir: The start of the directory tree walk.
        filter_fn: A function which takes a directory path and returns True if we
            should include the directory path in this function's return value.
        nested_ok: Allow returning "nested" directories, i.e. a return value where
            some elements are subdirectories of other elements.

    Returns:
        A sorted list of all subdirectory paths where `filter_fn(path) == True`.
        Sorting makes the order deterministic across runs and filesystems.

    Raises:
        ValueError: If `nested_ok` is False and one of the filtered directory paths
            is a subdirectory of another.
    """
    matched = set()
    for root_str, _, _ in os.walk(root_dir, followlinks=False):
        root = pathlib.Path(root_str)
        if filter_fn(root):
            matched.add(root)

    # A set has no stable iteration order (str hashing is randomized per
    # process), so sort to give callers a reproducible result.
    filtered_dirs = sorted(matched)

    if not nested_ok:
        # Look up each directory's ancestors in the set: O(n * depth) instead of
        # comparing every pair of matched directories.
        for dirpath in filtered_dirs:
            for ancestor in dirpath.parents:
                if ancestor in matched:
                    raise ValueError(
                        f"Found nested directories: {dirpath} and {ancestor}",
                    )
    return filtered_dirs
def get_sacred_dir_from_run(run: sacred.run.Run) -> Optional[pathlib.Path]:
    """Locate the on-disk directory that a sacred run is stored to.

    Args:
        run: The sacred run whose observers are inspected.

    Returns:
        The `dir` of the first `FileStorageObserver` attached to `run`, converted
        via `util.parse_path`, or None if no such observer is attached.
    """
    file_observers = (
        observer
        for observer in run.observers
        if isinstance(observer, sacred.observers.FileStorageObserver)
    )
    first_file_observer = next(file_observers, None)
    if first_file_observer is None:
        return None
    return util.parse_path(first_file_observer.dir)
def dict_get_nested(d: dict, nested_key: str, *, sep=".", default=None) -> Any:
    """Fetch a value from nested dicts using a separator-delimited key path.

    For example, `dict_get_nested({"a": {"b": 1}}, "a.b")` returns `1`.

    Args:
        d: The outermost dict to search.
        nested_key: Keys joined by `sep`, outermost first.
        sep: Separator used to split `nested_key` into individual keys.
        default: Value returned when the path cannot be followed.

    Returns:
        The value at the end of the key path, or `default` if some key is
        missing or an intermediate value is not a dict.
    """
    node: Any = d
    for part in nested_key.split(sep):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node