Trainer: add logging through Weights & Biases (huggingface#3916)
* feat: add logging through Weights & Biases

* feat(wandb): make logging compatible with all scripts

* style(trainer.py): fix formatting

* [Trainer] Tweak wandb integration

Co-authored-by: Julien Chaumond <[email protected]>
borisdayma and julien-c authored May 5, 2020
1 parent 858b1d1 commit 818463e
Showing 3 changed files with 41 additions and 2 deletions.
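
Before the diffs, a minimal, self-contained sketch of the pattern this commit adds to the Trainer: attempt the wandb import once at module load, remember whether it succeeded, and guard every logging call behind that flag. The train_loop helper and its toy metrics below are hypothetical and for illustration only; wandb.init and wandb.log are the same calls used in the diff.

try:
    import wandb

    _has_wandb = True
except ImportError:
    _has_wandb = False


def is_wandb_available():
    return _has_wandb


def train_loop(config, num_steps=10):
    # Illustrative stand-in for Trainer.train(); `config` is a plain dict of hyperparameters.
    if is_wandb_available():
        # Start a run and record the hyperparameters once, as Trainer._setup_wandb does.
        wandb.init(name=config.get("run_name"), config=config)
    for step in range(num_steps):
        logs = {"loss": 1.0 / (step + 1), "learning_rate": config["learning_rate"]}
        if is_wandb_available():
            # Log scalars against the global step, mirroring the wandb.log call in train().
            wandb.log(logs, step=step)


train_loop({"run_name": "demo", "learning_rate": 5e-5})

If wandb is not installed, the guard simply skips logging, matching the informational message added to Trainer.__init__ below.
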
1 change: 1 addition & 0 deletions .gitignore
@@ -131,6 +131,7 @@ proc_data
# examples
runs
/runs_old
/wandb
examples/runs

# data
30 changes: 29 additions & 1 deletion src/transformers/trainer.py
@@ -52,6 +52,18 @@ def is_tensorboard_available():
return _has_tensorboard


try:
import wandb

_has_wandb = True
except ImportError:
_has_wandb = False


def is_wandb_available():
return _has_wandb


logger = logging.getLogger(__name__)


@@ -151,6 +163,10 @@ def __init__(
logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
)
if not is_wandb_available():
logger.info(
"You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
)
set_seed(self.args.seed)
# Create output directory if needed
if self.args.local_rank in [-1, 0]:
@@ -209,6 +225,12 @@ def get_optimizers(
)
return optimizer, scheduler

def _setup_wandb(self):
# Start a wandb run and log config parameters
wandb.init(name=self.args.logging_dir, config=vars(self.args))
# keep track of model topology and gradients
# wandb.watch(self.model)

def train(self, model_path: Optional[str] = None):
"""
Main training entry point.
@@ -263,6 +285,9 @@ def train(self, model_path: Optional[str] = None):

if self.tb_writer is not None:
self.tb_writer.add_text("args", self.args.to_json_string())
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
if is_wandb_available():
self._setup_wandb()

# Train!
logger.info("***** Running training *****")
@@ -351,6 +376,9 @@ def train(self, model_path: Optional[str] = None):
if self.tb_writer:
for k, v in logs.items():
self.tb_writer.add_scalar(k, v, global_step)
if is_wandb_available():
wandb.log(logs, step=global_step)

epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
@@ -467,7 +495,7 @@ def _rotate_checkpoints(self, use_mtime=False) -> None:
shutil.rmtree(checkpoint)

def evaluate(
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
) -> Dict[str, float]:
"""
Run evaluation and return metrics.
12 changes: 11 additions & 1 deletion src/transformers/training_args.py
@@ -2,7 +2,7 @@
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from .file_utils import cached_property, is_torch_available, torch_required

@@ -138,3 +138,13 @@ def to_json_string(self):
Serializes this instance to a JSON string.
"""
return json.dumps(dataclasses.asdict(self), indent=2)

def to_sanitized_dict(self) -> Dict[str, Any]:
"""
Sanitized serialization to use with TensorBoard’s hparams
"""
d = dataclasses.asdict(self)
valid_types = [bool, int, float, str]
if is_torch_available():
valid_types.append(torch.Tensor)
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
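
As a quick illustration of the sanitization above (a hedged sketch with a hypothetical args dict, not part of the commit): values whose type is not bool, int, float, or str (or a torch.Tensor when torch is available) are replaced by their string form, so the result is safe to pass to TensorBoard's add_hparams.

# Hypothetical example; only the filtering rule matches to_sanitized_dict above.
valid_types = [bool, int, float, str]
args = {"learning_rate": 5e-05, "output_dir": "/tmp/run", "fp16": False, "device": None}
sanitized = {k: v if type(v) in valid_types else str(v) for k, v in args.items()}
print(sanitized)
# {'learning_rate': 5e-05, 'output_dir': '/tmp/run', 'fp16': False, 'device': 'None'}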
