imitation.util.logger

Logging for quantitative metrics and free-form text.

Functions

configure([folder, format_strs])

Configure Stable Baselines logger to be accumulate_means()-compatible.

make_output_format(_format, log_dir[, ...])

Returns a logger for the requested format.

Classes

HierarchicalLogger(default_logger[, format_strs])

A logger supporting contexts for accumulating mean values.

WandbOutputFormat()

A Stable Baselines3 logger that writes to Weights & Biases (wandb).

class imitation.util.logger.HierarchicalLogger(default_logger, format_strs=('stdout', 'log', 'csv'))

Bases: Logger

A logger supporting contexts for accumulating mean values.

self.accumulate_means creates a context manager. While in this context, values are logged to a sub-logger, with only mean values recorded in the top-level (root) logger.

>>> import os
>>> import tempfile
>>> from imitation.util.logger import HierarchicalLogger, configure
>>> with tempfile.TemporaryDirectory() as dir:
...     logger: HierarchicalLogger = configure(dir, ('log',))
...     # record the key value pair (loss, 1.0) to path `dir`
...     # at step 1.
...     logger.record("loss", 1.0)
...     logger.dump(step=1)
...     with logger.accumulate_means("dataset"):
...         # record the key value pair `("raw/dataset/entropy", 5.0)` to path
...         # `dir/raw/dataset` at step 100
...         logger.record("entropy", 5.0)
...         logger.dump(step=100)
...         # record the key value pair `("raw/dataset/entropy", 6.0)` to path
...         # `dir/raw/dataset` at step 200
...         logger.record("entropy", 6.0)
...         logger.dump(step=200)
...     # record the key value pair `("mean/dataset/entropy", 5.5)` to path
...     # `dir` at step 1.
...     logger.dump(step=1)
...     with logger.add_accumulate_prefix("foo"), logger.accumulate_means("bar"):
...         # record the key value pair ("raw/foo/bar/biz", 42.0) to path
...         # `dir/raw/foo/bar` at step 2000
...         logger.record("biz", 42.0)
...         logger.dump(step=2000)
...     # record the key value pair `("mean/foo/bar/biz", 42.0)` to path
...     # `dir` at step 1.
...     logger.dump(step=1)
...     with open(os.path.join(dir, 'log.txt')) as f:
...         print(f.read())
-------------------
| loss | 1        |
-------------------
---------------------------------
| mean/              |          |
|    dataset/entropy | 5.5      |
---------------------------------
-----------------------------
| mean/          |          |
|    foo/bar/biz | 42       |
-----------------------------
__init__(default_logger, format_strs=('stdout', 'log', 'csv'))

Builds HierarchicalLogger.

Parameters
  • default_logger (Logger) – The default logger when not in an accumulate_means context. Also the logger to which mean values are written after exiting from a context.

  • format_strs (Sequence[str]) – A list of output format strings that should be used by every Logger initialized by this class during an accumulate_means context. For details on available output formats see stable_baselines3.logger.make_output_format.

accumulate_means(name)

Temporarily modifies this HierarchicalLogger to accumulate mean values.

Within this context manager, self.record(key, value) writes the “raw” values to the directory f"{self.default_logger.log_dir}/raw/[{accumulate_prefix}/]{name}" under the key "raw/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}", where accumulate_prefix is the concatenation of all prefixes added by add_accumulate_prefix and key_prefix is the concatenation of all prefixes added by add_key_prefix, if any. At the same time, every call to self.record also accumulates mean values on the default logger by calling:

self.default_logger.record_mean(
    f"mean/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}",
    value,
)

Multiple prefixes may be active at once. In this case the prefix is simply the concatenation of the active prefixes in the order they were created; e.g., if the active prefixes are ['foo', 'bar'], then the prefix is 'foo/bar'.

After the context exits, calling self.dump() writes the means of all the “raw” values accumulated during this context to self.default_logger under keys of the form mean/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}.

Note that the other logging methods, log and record_mean, are unmodified and write directly to the default logger.
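The key templates above can be illustrated with a small pure-Python helper. `format_keys` is hypothetical and not part of imitation; it assumes that active prefixes are joined with `/` and that an optional segment simply disappears when no prefix is active, which matches the keys shown in the doctest at the top of this class.

```python
from typing import Sequence, Tuple


def format_keys(
    name: str,
    key: str,
    accumulate_prefixes: Sequence[str] = (),
    key_prefixes: Sequence[str] = (),
) -> Tuple[str, str]:
    """Hypothetical helper: build the (raw_key, mean_key) pair for one record call.

    Mirrors the documented templates
    raw/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key} and
    mean/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}.
    """
    # Optional segments are omitted when no prefix of that kind is active.
    suffix = "/".join([*accumulate_prefixes, name, *key_prefixes, key])
    return f"raw/{suffix}", f"mean/{suffix}"


raw_key, mean_key = format_keys("bar", "biz", accumulate_prefixes=["foo"])
print(raw_key)   # raw/foo/bar/biz
print(mean_key)  # mean/foo/bar/biz
```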

Parameters

name (str) – A string that determines the folder where raw data is written, as well as the temporary logging prefixes for raw and mean data. Entering an accumulate_means context again later with the same name safely appends to the logs already written in this folder rather than overwriting them.

Yields

None when the context is entered.

Raises

RuntimeError – If this context is entered while already in an accumulate_means context.

Return type

Generator[None, None, None]
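The routing that accumulate_means sets up can be sketched without any files. `MiniHierarchicalLogger` below is a hypothetical in-memory stand-in, not imitation's implementation: plain dicts and lists replace the default logger and the per-context sub-loggers, prefixes are left out, and `dump()` only models what the default logger would write after the context exits. It does reproduce the nesting `RuntimeError` and the 5.5 mean from the doctest above.

```python
import contextlib
import statistics
from collections import defaultdict
from typing import DefaultDict, Dict, Iterator, List, Optional, Tuple


class MiniHierarchicalLogger:
    """Hypothetical in-memory sketch of accumulate_means() routing.

    Not imitation's implementation: nothing is written to disk.
    """

    def __init__(self) -> None:
        self.current_name: Optional[str] = None
        self.default: Dict[str, float] = {}  # direct records on the default logger
        self.raw_log: List[Tuple[str, float]] = []  # every raw value, in order
        self._pending: DefaultDict[str, List[float]] = defaultdict(list)

    @contextlib.contextmanager
    def accumulate_means(self, name: str) -> Iterator[None]:
        if self.current_name is not None:
            raise RuntimeError("Nested accumulate_means context")
        self.current_name = name
        try:
            yield
        finally:
            # Leave the context even if the body raised.
            self.current_name = None

    def record(self, key: str, value: float) -> None:
        if self.current_name is None:
            # Outside a context: write straight to the default logger.
            self.default[key] = value
            return
        # Inside a context: keep the raw value and queue it for averaging.
        self.raw_log.append((f"raw/{self.current_name}/{key}", value))
        self._pending[f"mean/{self.current_name}/{key}"].append(value)

    def dump(self) -> Dict[str, float]:
        """Return and clear what the default logger would write."""
        out = dict(self.default)
        out.update({k: statistics.mean(v) for k, v in self._pending.items()})
        self.default.clear()
        self._pending.clear()
        return out


logger = MiniHierarchicalLogger()
logger.record("loss", 1.0)
with logger.accumulate_means("dataset"):
    logger.record("entropy", 5.0)
    logger.record("entropy", 6.0)
print(logger.dump())  # {'loss': 1.0, 'mean/dataset/entropy': 5.5}
```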

add_accumulate_prefix(prefix)

Add a prefix to the subdirectory used to accumulate means.

This prefix only applies when an accumulate_means context is active. If there are multiple active prefixes, they are concatenated.

Parameters

prefix (str) – The prefix to add to the named subdirectory.

Yields

None when the context manager is entered.

Raises

RuntimeError – If an accumulate_means context is already active.

Return type

Generator[None, None, None]
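Both add_accumulate_prefix and add_key_prefix describe a stack of prefixes that concatenate while nested. A minimal sketch of that bookkeeping, using a hypothetical `PrefixStack` class and assuming `/` as the separator, is shown below; the `RuntimeError` check against an active accumulate_means context is deliberately left out.

```python
import contextlib
from typing import Iterator, List


class PrefixStack:
    """Hypothetical helper: a stack of prefixes joined with '/'.

    Nested prefixes concatenate in creation order, and each prefix is
    removed again when its context exits.
    """

    def __init__(self) -> None:
        self._prefixes: List[str] = []

    @contextlib.contextmanager
    def add(self, prefix: str) -> Iterator[None]:
        self._prefixes.append(prefix)
        try:
            yield
        finally:
            # Pop even if the body raised, so prefixes never leak out.
            self._prefixes.pop()

    def joined(self) -> str:
        return "/".join(self._prefixes)


stack = PrefixStack()
with stack.add("foo"), stack.add("bar"):
    print(stack.joined())  # foo/bar
print(repr(stack.joined()))  # ''
```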

add_key_prefix(prefix)

Add a prefix to the keys logged during an accumulate_means context.

This prefix only applies when an accumulate_means context is active. If there are multiple active prefixes, they are concatenated.

Parameters

prefix (str) – The prefix to add to the keys.

Yields

None when the context manager is entered.

Raises

RuntimeError – If an accumulate_means context is already active.

Return type

Generator[None, None, None]

close()

Close the file.

current_logger: Optional[Logger]
default_logger: Logger
dump(step=0)

Write all of the diagnostics from the current iteration.

format_strs: Sequence[str]
get_accumulate_prefixes()
Return type

str

get_dir()

Get the directory that log files are being written to. This is None if there is no output directory (i.e., if you didn’t call start).

Return type

str

Returns

the logging directory

log(*args, **kwargs)

Write the sequence of args, with no separators, to the console and output files (if you’ve configured an output file).

level: int (see the logger.py docs). If the global logger level is higher than the level argument here, the message is not printed to stdout.

Parameters
  • args – the values to log.

  • level – the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)

record(key, val, exclude=None)

Log a value of some diagnostic. Call this once for each diagnostic quantity, each iteration. If called many times, the last value will be used.

Parameters
  • key – the key to log.

  • val – the value to log.

  • exclude – the output formats to exclude this value from.

record_mean(key, val, exclude=None)

The same as record(), but if called many times, the values are averaged.

Parameters
  • key – the key to log.

  • val – the value to log.

  • exclude – the output formats to exclude this value from.
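The difference between record() and record_mean() is easiest to see side by side. `MiniKV` is a hypothetical sketch, not the Stable Baselines3 `Logger`: it keeps the last value for record() and an incremental running mean for record_mean(), and its `dump()` returns and clears the pending values instead of writing them anywhere.

```python
from typing import Dict, Tuple


class MiniKV:
    """Hypothetical sketch of record() vs. record_mean() semantics."""

    def __init__(self) -> None:
        self._values: Dict[str, float] = {}
        # key -> (current mean, number of values seen so far)
        self._means: Dict[str, Tuple[float, int]] = {}

    def record(self, key: str, val: float) -> None:
        self._values[key] = val  # last write wins

    def record_mean(self, key: str, val: float) -> None:
        mean, count = self._means.get(key, (0.0, 0))
        count += 1
        mean += (val - mean) / count  # incremental mean update
        self._means[key] = (mean, count)

    def dump(self) -> Dict[str, float]:
        out = dict(self._values)
        out.update({k: m for k, (m, _) in self._means.items()})
        self._values.clear()
        self._means.clear()
        return out


kv = MiniKV()
kv.record("lr", 0.1)
kv.record("lr", 0.01)
for loss in (1.0, 2.0, 3.0):
    kv.record_mean("loss", loss)
print(kv.dump())  # {'lr': 0.01, 'loss': 2.0}
```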

set_level(level)

Set logging threshold on current logger.

Parameters

level (int) – the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)

Return type

None
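The threshold rule from log() and set_level() can be sketched as follows. `LevelFilteredLog` is hypothetical; it appends messages to a list instead of printing them, and it assumes INFO both as the initial threshold and as the default `level` argument of log().

```python
from typing import List

DEBUG, INFO, WARN, ERROR, DISABLED = 10, 20, 30, 40, 50


class LevelFilteredLog:
    """Hypothetical sketch of log()/set_level() threshold semantics."""

    def __init__(self, level: int = INFO) -> None:
        self.level = level
        self.lines: List[str] = []  # stands in for stdout

    def set_level(self, level: int) -> None:
        self.level = level

    def log(self, *args: object, level: int = INFO) -> None:
        # Skip the message when the configured threshold is higher.
        if self.level > level:
            return
        # Arguments are joined with no separator, as log() documents.
        self.lines.append("".join(map(str, args)))


sink = LevelFilteredLog()
sink.log("step ", 1)              # INFO >= INFO: kept
sink.log("details", level=DEBUG)  # DEBUG < INFO: dropped
sink.set_level(DISABLED)
sink.log("boom", level=ERROR)     # ERROR < DISABLED: dropped
print(sink.lines)  # ['step 1']
```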

class imitation.util.logger.WandbOutputFormat

Bases: KVWriter

A Stable Baselines3 logger that writes to Weights & Biases (wandb).

Users need to call wandb.init() before initializing WandbOutputFormat.

__init__()

Initializes an instance of WandbOutputFormat.

Raises

ModuleNotFoundError – wandb is not installed.

close()

Close owned resources.

Return type

None

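write(key_values, key_excluded, step=0)

Write a dictionary of key-value pairs to wandb.

Parameters
  • key_values (Dict[str, Any]) – the key-value pairs to write.

  • key_excluded (Dict[str, Tuple[str, ...]]) – for each key, the output formats that should skip it.

  • step (int) – the step at which the values are logged.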

Return type

None
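Keys can be excluded from individual output formats through key_excluded. The helper below is a hypothetical sketch of that filtering, assuming each entry in key_excluded is either None or a tuple of format names, and that a writer identifying itself as `"wandb"` skips any key whose tuple contains that name. It only builds the filtered dictionary; a real writer would then hand it to wandb (for example via `wandb.log(..., step=step)`), which is omitted here so the sketch runs without wandb installed.

```python
from typing import Any, Dict, Optional, Tuple


def filter_for_writer(
    key_values: Dict[str, Any],
    key_excluded: Dict[str, Optional[Tuple[str, ...]]],
    writer_name: str = "wandb",
) -> Dict[str, Any]:
    """Hypothetical helper: drop keys whose exclusion tuple names this writer."""
    kept: Dict[str, Any] = {}
    for key, value in key_values.items():
        excluded = key_excluded.get(key)
        if excluded is not None and writer_name in excluded:
            continue  # this output format was told to skip the key
        kept[key] = value
    return kept


payload = filter_for_writer(
    {"loss": 0.5, "video": "clip.mp4", "lr": 0.001},
    {"loss": None, "video": ("wandb", "csv"), "lr": ("stdout",)},
)
print(payload)  # {'loss': 0.5, 'lr': 0.001}
```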

imitation.util.logger.configure(folder=None, format_strs=None)

Configure Stable Baselines logger to be accumulate_means()-compatible.

After this function is called, stable_baselines3.logger.{configure,reset}() are replaced with stubs that raise RuntimeError.

Parameters
  • folder (Union[str, bytes, PathLike, None]) – Argument from stable_baselines3.logger.configure.

  • format_strs (Optional[Sequence[str]]) – A list of output format strings. For details on available output formats see stable_baselines3.logger.make_output_format.

Return type

HierarchicalLogger

Returns

The configured HierarchicalLogger instance.
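The stub replacement described for configure() amounts to monkey-patching module attributes. The sketch below shows that pattern on a throwaway module created with `types.ModuleType`; `install_stubs` and `_make_stub` are hypothetical names, and nothing here touches stable_baselines3 itself.

```python
import types
from typing import Callable, Iterable


def _make_stub(name: str) -> Callable[..., None]:
    """Build a function that always raises, naming the disabled call."""

    def stub(*args: object, **kwargs: object) -> None:
        raise RuntimeError(f"{name}() has been replaced by a stub")

    return stub


def install_stubs(
    module: types.ModuleType, names: Iterable[str] = ("configure", "reset")
) -> None:
    """Hypothetical sketch: replace module-level functions with RuntimeError stubs."""
    for name in names:
        setattr(module, name, _make_stub(name))


# A throwaway stand-in module; this does not touch stable_baselines3.
fake_logger = types.ModuleType("fake_logger")
fake_logger.configure = lambda *args, **kwargs: "configured"
fake_logger.reset = lambda: "reset"

install_stubs(fake_logger)
try:
    fake_logger.configure("/tmp/logs")
except RuntimeError as err:
    print(err)  # configure() has been replaced by a stub
```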

imitation.util.logger.make_output_format(_format, log_dir, log_suffix='', max_length=50)

Returns a logger for the requested format.

Parameters
  • _format (str) – the requested format to log to (‘stdout’, ‘log’, ‘json’, ‘csv’, or ‘tensorboard’).

  • log_dir (str) – the logging directory.

  • log_suffix (str) – the suffix for the log file.

  • max_length (int) – the maximum length beyond which the keys get truncated.

Return type

KVWriter

Returns

the logger.
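The max_length parameter caps how long a key may be before it is shortened. `truncate_key` is a hypothetical helper assuming an ellipsis-style cut that keeps the result at exactly max_length characters; the rule actually applied by the underlying writers may differ.

```python
def truncate_key(key: str, max_length: int = 50) -> str:
    """Hypothetical sketch: shorten a key that is longer than max_length.

    Assumes an ellipsis-style cut; the writers' actual rule may differ.
    """
    if max_length < 3:
        raise ValueError("max_length must leave room for the '...' marker")
    if len(key) <= max_length:
        return key
    return key[: max_length - 3] + "..."


print(truncate_key("mean/dataset/entropy"))             # mean/dataset/entropy
print(truncate_key("mean/foo/bar/biz", max_length=10))  # mean/fo...
```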