From 125ccead714002224167415f6e80573e161f5146 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 22 Mar 2021 09:45:17 -0500
Subject: [PATCH] feat(wandb): logging and configuration improvements (#10826)

* feat: ensure unique artifact id
* feat: allow manual init
* fix: simplify reinit logic
* fix: no dropped value + immediate commits
* fix: wandb use in sagemaker
* docs: improve documentation and formatting
* fix: typos
* docs: improve formatting
---
 examples/README.md               | 33 +++----------------
 src/transformers/integrations.py | 53 +++++++++++++++-----------------
 2 files changed, 29 insertions(+), 57 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index 4e2e4afc452782..49e693e731583f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metri
 Advanced configuration is possible by setting environment variables:
 
-<table>
-<tr>
-<td>Environment Variables</td>
-<td>Options</td>
-</tr>
-<tr>
-<td>WANDB_LOG_MODEL</td>
-<td>Log the model as artifact at the end of training (false by default)</td>
-</tr>
-<tr>
-<td>WANDB_WATCH</td>
-<td>
-<ul>
-<li>gradients (default): Log histograms of the gradients</li>
-<li>all: Log histograms of gradients and parameters</li>
-<li>false: No gradient or parameter logging</li>
-</ul>
-</td>
-</tr>
-<tr>
-<td>WANDB_PROJECT</td>
-<td>Organize runs by project</td>
-</tr>
-</table>
-
+| Environment Variable | Value |
+|---|---|
+| WANDB_LOG_MODEL | Log the model as artifact at the end of training (`false` by default) |
+| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
+| WANDB_PROJECT | Organize runs by project |
 
 Set run names with `run_name` argument present in scripts or as part of `TrainingArguments`.
 
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 86b3b27b23ba35..cdde91021b4103 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -19,7 +19,6 @@
 import json
 import numbers
 import os
-import re
 import tempfile
 from copy import deepcopy
 from pathlib import Path
@@ -559,20 +558,12 @@ def __init__(self):
         if has_wandb:
             import wandb
 
-            wandb.ensure_configured()
-            if wandb.api.api_key is None:
-                has_wandb = False
-                logger.warning(
-                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
-                )
-                self._wandb = None
-            else:
-                self._wandb = wandb
+            self._wandb = wandb
         self._initialized = False
         # log outputs
         self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
 
-    def setup(self, args, state, model, reinit, **kwargs):
+    def setup(self, args, state, model, **kwargs):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
 
@@ -581,7 +572,8 @@ def setup(self, args, state, model, reinit, **kwargs):
 
         Environment:
             WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether or not to log model as artifact at the end of training.
+                Whether or not to log model as artifact at the end of training. Use along with
+                `TrainingArguments.load_best_model_at_end` to upload best model.
             WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                 Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                 logging or :obj:`"all"` to log gradients and parameters.
@@ -610,13 +602,19 @@ def setup(self, args, state, model, reinit, **kwargs):
             else:
                 run_name = args.run_name
 
-            self._wandb.init(
-                project=os.getenv("WANDB_PROJECT", "huggingface"),
-                config=combined_dict,
-                name=run_name,
-                reinit=reinit,
-                **init_args,
-            )
+            if self._wandb.run is None:
+                self._wandb.init(
+                    project=os.getenv("WANDB_PROJECT", "huggingface"),
+                    name=run_name,
+                    **init_args,
+                )
+            # add config parameters (run may have been created manually)
+            self._wandb.config.update(combined_dict, allow_val_change=True)
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._wandb, "define_metric", None):
+                self._wandb.define_metric("train/global_step")
+                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
 
             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
@@ -628,23 +626,20 @@ def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
             return
         hp_search = state.is_hyper_param_search
-        if not self._initialized or hp_search:
-            self.setup(args, state, model, reinit=hp_search, **kwargs)
+        if hp_search:
+            self._wandb.finish()
+        if not self._initialized:
+            self.setup(args, state, model, **kwargs)
 
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        # commit last step
-        if state.is_world_process_zero:
-            self._wandb.log({})
         if self._log_model and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer
 
             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
-                # use run name and ensure it's a valid Artifact name
-                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                 metadata = (
                     {
                         k: v
@@ -657,7 +652,7 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwarg
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
@@ -668,10 +663,10 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
         if self._wandb is None:
             return
         if not self._initialized:
-            self.setup(args, state, model, reinit=False)
+            self.setup(args, state, model)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            self._wandb.log(logs, step=state.global_step)
+            self._wandb.log({**logs, "train/global_step": state.global_step})
 
 
 class CometCallback(TrainerCallback):
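
A minimal usage sketch (not part of the patch) of the manual-init workflow the new `wandb.run is None` check enables: the run is created before the `Trainer`, and `WandbCallback` reuses it and merges the training config instead of re-initializing. The checkpoint name, run name, and output directory below are placeholder examples, and the environment variables are the ones documented in the README table above.

```python
# Sketch only: assumes wandb and transformers are installed and you are
# logged in to W&B (`wandb login`).
import os

import wandb
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Environment variables read by WandbCallback (see the README table above).
os.environ["WANDB_PROJECT"] = "huggingface"  # organize runs by project
os.environ["WANDB_LOG_MODEL"] = "true"       # upload the model as an artifact at the end of training

# Manual init: with this patch, WandbCallback sees wandb.run is not None,
# reuses this run, and updates wandb.config with the Trainer arguments
# (config.update(..., allow_val_change=True)).
wandb.init(name="my-manual-run")

model_name = "distilbert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

args = TrainingArguments(
    output_dir="wandb-demo",
    report_to=["wandb"],       # enable the W&B integration explicitly
    run_name="my-manual-run",  # only used when no run was created manually
    logging_steps=50,
)

trainer = Trainer(model=model, args=args, tokenizer=tokenizer)
# trainer.train()  # pass train_dataset/eval_dataset before actually training
```

With `WANDB_LOG_MODEL` enabled, combining it with `TrainingArguments.load_best_model_at_end` uploads the best checkpoint, as noted in the updated docstring; because metrics are now logged against the `train/global_step` metric registered via `define_metric`, values reported after the last step are charted instead of being dropped.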