feat(wandb): logging and configuration improvements (huggingface#10826)
* feat: ensure unique artifact id

* feat: allow manual init

* fix: simplify reinit logic

* fix: no dropped value + immediate commits

* fix: wandb use in sagemaker

* docs: improve documentation and formatting

* fix: typos

* docs: improve formatting
borisdayma authored Mar 22, 2021
1 parent b230181 commit 125ccea
Showing 2 changed files with 29 additions and 57 deletions.
33 changes: 5 additions & 28 deletions examples/README.md
@@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metrics

Advanced configuration is possible by setting environment variables:

<table>
<thead>
<tr>
<th style="text-align:left">Environment Variables</th>
<th style="text-align:left">Options</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left">WANDB_LOG_MODEL</td>
<td style="text-align:left">Log the model as artifact at the end of training (<b>false</b> by default)</td>
</tr>
<tr>
<td style="text-align:left">WANDB_WATCH</td>
<td style="text-align:left">
<ul>
<li><b>gradients</b> (default): Log histograms of the gradients</li>
<li><b>all</b>: Log histograms of gradients and parameters</li>
<li><b>false</b>: No gradient or parameter logging</li>
</ul>
</td>
</tr>
<tr>
<td style="text-align:left">WANDB_PROJECT</td>
<td style="text-align:left">Organize runs by project</td>
</tr>
</tbody>
</table>
| Environment Variable | Options |
|---|---|
| WANDB_LOG_MODEL | Log the model as an artifact at the end of training (`false` by default) |
| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
| WANDB_PROJECT | Organize runs by project |

Set run names with the `run_name` argument, available in the example scripts and as part of `TrainingArguments`.
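
For readers who prefer configuring this from Python rather than the shell, a rough equivalent of the table above plus the `run_name` argument is sketched below; the project, watch mode, and run name are placeholder values, not defaults:

```python
import os

from transformers import TrainingArguments

# Placeholder values for illustration; none of these are library defaults.
os.environ["WANDB_PROJECT"] = "my-glue-experiments"  # group runs under one project
os.environ["WANDB_WATCH"] = "all"                    # histograms of gradients and parameters
os.environ["WANDB_LOG_MODEL"] = "true"               # upload the trained model as an artifact

training_args = TrainingArguments(
    output_dir="./out",
    run_name="bert-base-mrpc-seed-42",  # becomes the W&B run name
)
```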

53 changes: 24 additions & 29 deletions src/transformers/integrations.py
@@ -19,7 +19,6 @@
import json
import numbers
import os
import re
import tempfile
from copy import deepcopy
from pathlib import Path
@@ -559,20 +558,12 @@ def __init__(self):
if has_wandb:
import wandb

wandb.ensure_configured()
if wandb.api.api_key is None:
has_wandb = False
logger.warning(
"W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
)
self._wandb = None
else:
self._wandb = wandb
self._wandb = wandb
self._initialized = False
# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

def setup(self, args, state, model, reinit, **kwargs):
def setup(self, args, state, model, **kwargs):
"""
Setup the optional Weights & Biases (`wandb`) integration.
@@ -581,7 +572,8 @@ def setup(self, args, state, model, reinit, **kwargs):
Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log model as artifact at the end of training.
Whether or not to log model as artifact at the end of training. Use along with
`TrainingArguments.load_best_model_at_end` to upload best model.
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
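
A short sketch of how `WANDB_LOG_MODEL` is meant to combine with `TrainingArguments.load_best_model_at_end`, as the docstring above describes; the argument values are illustrative assumptions, not defaults:

```python
import os

from transformers import TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "true"  # ask the callback to upload the model at the end of training

# Illustrative settings: with load_best_model_at_end the weights uploaded by the
# callback are those of the best checkpoint rather than the last one.
training_args = TrainingArguments(
    output_dir="./out",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
```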
@@ -610,13 +602,19 @@ def setup(self, args, state, model, **kwargs):
else:
run_name = args.run_name

self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
config=combined_dict,
name=run_name,
reinit=reinit,
**init_args,
)
if self._wandb.run is None:
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
name=run_name,
**init_args,
)
# add config parameters (run may have been created manually)
self._wandb.config.update(combined_dict, allow_val_change=True)

# define default x-axis (for latest wandb versions)
if getattr(self._wandb, "define_metric", None):
self._wandb.define_metric("train/global_step")
self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
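
A minimal sketch of the manual-init workflow this hunk enables, assuming a recent `wandb` client; the project name and tags are made up for illustration:

```python
import wandb

# Create the run yourself; since wandb.run is no longer None, the callback skips
# wandb.init() and only merges the trainer/model parameters into run.config.
run = wandb.init(
    project="my-project",  # placeholder project name
    tags=["baseline"],     # any extra metadata you want on the run
)

# Standalone equivalent of the default x-axis defined above (recent wandb versions only).
if getattr(wandb, "define_metric", None):
    wandb.define_metric("train/global_step")
    wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# ... build model, datasets and TrainingArguments, then:
# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()  # metrics are logged to the run created above
```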
@@ -628,23 +626,20 @@ def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
self.setup(args, state, model, reinit=hp_search, **kwargs)
if hp_search:
self._wandb.finish()
if not self._initialized:
self.setup(args, state, model, **kwargs)

def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
# commit last step
if state.is_world_process_zero:
self._wandb.log({})
if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer

fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
metadata = (
{
k: v
@@ -657,7 +652,7 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
"train/total_floss": state.total_flos,
}
)
artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
@@ -668,10 +663,10 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, model, reinit=False)
self.setup(args, state, model)
if state.is_world_process_zero:
logs = rewrite_logs(logs)
self._wandb.log(logs, step=state.global_step)
self._wandb.log({**logs, "train/global_step": state.global_step})


class CometCallback(TrainerCallback):

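Outside the callback, the unique-artifact pattern introduced above (`model-{run.id}` instead of a sanitized run name) can be reproduced with plain `wandb` calls; this is a rough sketch with placeholder paths and metadata, not the callback's exact code:

```python
from pathlib import Path

import wandb

run = wandb.init(project="my-project")  # placeholder project name

# Keyed to the run id so artifacts from different runs never collide,
# mirroring the f"model-{run.id}" naming used by the callback.
artifact = wandb.Artifact(
    name=f"model-{run.id}",
    type="model",
    metadata={"eval/accuracy": 0.91},  # illustrative metadata only
)

model_dir = Path("./out")  # placeholder: a directory written by trainer.save_model()
for f in model_dir.glob("*"):
    if f.is_file():
        artifact.add_file(str(f))  # add_file() is a simpler route than streaming via new_file()

run.log_artifact(artifact)
run.finish()
```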