# Source code for imitation.scripts.ingredients.wb

"""This ingredient provides Weights & Biases logging."""

import logging
from typing import Any, Mapping, Optional

import sacred

# Sacred ingredient that groups the W&B configuration (`wandb_config`) and the
# captured `wandb_init` function defined below.
# NOTE(review): the dotted name "logging.wandb" presumably nests this
# ingredient's config under a parent `logging` ingredient — confirm.
wandb_ingredient = sacred.Ingredient("logging.wandb")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@wandb_ingredient.config
def wandb_config():
    # Default configuration values for the W&B ingredient. The names assigned
    # here match the parameters of `wandb_init` (except `_run` and `log_dir`,
    # which are not defined in this function).
    # Other users can overwrite this function to customize their wandb.init() call.
    wandb_tag = None  # User-specified tag for this run
    wandb_name_prefix = ""  # User-specified prefix for the run name
    wandb_kwargs = dict(
        project="imitation",
        monitor_gym=False,
        save_code=False,
    )  # Other kwargs to pass to wandb.init()
    # Extra entries merged into the `config` dict that wandb_init passes to wandb.init()
    wandb_additional_info = dict()

    # The return value is discarded.
    # NOTE(review): presumably called only so linters treat the locals above as
    # used — confirm before removing.
    locals()


@wandb_ingredient.capture
def wandb_init(
    _run,
    wandb_name_prefix: str,
    wandb_tag: Optional[str],
    wandb_kwargs: Mapping[str, Any],
    wandb_additional_info: Mapping[str, Any],
    log_dir: str,
) -> None:
    """Putting everything together to get the W&B kwargs for wandb.init().

    The run name is ``{wandb_name_prefix}-{env_name}-seed{seed}``. When
    ``wandb_name_prefix`` is empty, the prefix and its separator are omitted,
    so the name does not start with a stray hyphen.

    Args:
        _run: Run object whose ``config`` mapping supplies the environment id
            (``config["environment"]["gym_id"]``) and the seed
            (``config["seed"]``), and is itself logged as the W&B ``config``.
        wandb_name_prefix: User-specified prefix for wandb run name.
        wandb_tag: User-specified tag for this run.
        wandb_kwargs: User-specified kwargs for wandb.init(). The ``name``,
            ``tags`` and ``dir`` entries are always overridden by the values
            computed here.
        wandb_additional_info: User-specific additional info to add to wandb
            experiment ``config``.
        log_dir: W&B logs will be stored in directory `{log_dir}/wandb/`.

    Raises:
        ModuleNotFoundError: wandb is not installed.
    """
    env_name = _run.config["environment"]["gym_id"]
    root_seed = _run.config["seed"]

    # Only prepend the prefix (and its separator) when one was actually given;
    # the default prefix is "", which would otherwise yield "-<env>-seed<n>".
    run_name = f"{env_name}-seed{root_seed}"
    if wandb_name_prefix:
        run_name = f"{wandb_name_prefix}-{run_name}"

    run_tags = [env_name, f"seed{root_seed}"]
    if wandb_tag:
        run_tags.append(wandb_tag)

    # Entries listed after the unpacking take precedence over user kwargs.
    updated_wandb_kwargs: Mapping[str, Any] = {
        **wandb_kwargs,
        "name": run_name,
        "tags": run_tags,
        "dir": log_dir,
    }

    # Imported lazily so that the module can be loaded without wandb installed.
    try:
        import wandb
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "Trying to call `wandb.init()` but `wandb` not installed: "
            "try `pip install wandb`.",
        ) from e

    # Log the full run config, with the additional info layered on top.
    wandb_config_dict = dict(**_run.config)
    wandb_config_dict.update(wandb_additional_info)
    wandb.init(config=wandb_config_dict, **updated_wandb_kwargs)