working on utils
shoeybi committed Mar 28, 2020
1 parent 94e2ca5 commit 5050203
Showing 5 changed files with 141 additions and 221 deletions.
6 changes: 5 additions & 1 deletion megatron/arguments.py
@@ -45,6 +45,11 @@ def parse_args(extra_args_provider=None, defaults={}):

# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
assert getattr(args, key) is None, \
'defaults can only be overwritten for args with None values.'
setattr(args, key, defaults[key])

# Distributed args.
@@ -60,7 +65,6 @@ def parse_args(extra_args_provider=None, defaults={}):
if args.loss_scale is None:
args.dynamic_loss_scale = True


# Checks.
assert args.hidden_size % args.num_attention_heads == 0
assert args.max_position_embeddings >= args.seq_length
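For reference, a minimal runnable sketch of the defaults-override rule added above, using a hypothetical stand-alone argparse parser rather than Megatron's full argument set: a value from the defaults dict is applied only when the corresponding argument was left as None on the command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seq-length', type=int, default=None)
args = parser.parse_args([])  # the user did not pass --seq-length

defaults = {'seq_length': 1024}
for key in defaults:
    # Mirrors the assert added in parse_args: refuse to silently override
    # a value the user set explicitly on the command line.
    assert getattr(args, key) is None, \
        'defaults can only be overwritten for args with None values.'
    setattr(args, key, defaults[key])

print(args.seq_length)  # 1024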
84 changes: 83 additions & 1 deletion megatron/global_vars.py
@@ -17,10 +17,12 @@

import os
import sys
import time

import torch

from megatron.data.tokenizer import build_tokenizer
from .arguments import parse_args
from .utils import Timers

_GLOBAL_ARGS = None
_GLOBAL_TOKENIZER = None
@@ -137,3 +139,83 @@ def _ensure_var_is_initialized(var, name):
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)


class Timers:
"""Group of timers."""

class Timer:
"""Timer."""

def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()

def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True

def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False

def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False

def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing is in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_

def __init__(self):
self.timers = {}

def __call__(self, name):
if name not in self.timers:
self.timers[name] = self.Timer(name)
return self.timers[name]

def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# Currently, torch.utils.tensorboard's add_scalars makes each timer its
# own run, which pollutes the runs list, so we add each timer as an
# individual scalar instead.
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '_time', value, iteration)

def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(string, flush=True)
else:
print(string, flush=True)
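Below is a hedged usage sketch for the Timers group added above. The timer names 'forward' and 'backward' are illustrative, and a CUDA device is assumed because start() and stop() call torch.cuda.synchronize().

import time

# Assuming Timers is imported from where this commit defines it.
from megatron.global_vars import Timers

timers = Timers()

timers('forward').start()
time.sleep(0.01)  # stand-in for real work
timers('forward').stop()

timers('backward').start()
time.sleep(0.02)
timers('backward').stop()

# elapsed() returns seconds; reset=False keeps the accumulated time.
print(timers('forward').elapsed(reset=False))

# Prints a single line such as: time (ms) | forward: 10.12 | backward: 20.34
# (on rank 0 only when torch.distributed is initialized).
timers.log(['forward', 'backward'])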
49 changes: 13 additions & 36 deletions megatron/training.py
@@ -22,28 +22,28 @@
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron.arguments import parse_args
from megatron.global_vars import get_args
from megatron.global_vars import get_timers
from megatron.global_vars import get_tensorboard_writer
from megatron.global_vars import get_adlr_autoresume
from megatron.initialize import initialize_megatron

from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import enable_adlr_autoresume
from megatron.utils import get_tensorboard_writer
from megatron.utils import initialize_distributed
from megatron.utils import load_checkpoint
from megatron.utils import print_args
from megatron.utils import print_rank_0
from megatron.utils import report_memory
from megatron.utils import save_checkpoint
from megatron.utils import set_random_seed
from megatron.utils import Timers


def run(top_level_message, train_val_test_data_provider,
model_provider, forward_step_func, extra_args_provider=None):
model_provider, forward_step_func, extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the followings in the order provided:
@@ -72,8 +72,11 @@ def run(top_level_message, train_val_test_data_provider,
"""

# Initialize and get arguments, timers, and Tensorboard writer.
args = parse_args(extra_args_provider=extra_args_provider)
timers, writer = initialize_megatron(top_level_message, args)
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()

# Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -116,32 +119,6 @@ def run(top_level_message, train_val_test_data_provider,
args, None, 0, timers, True)


def initialize_megatron(message, args):
""""Initialize distributed, random seed, and autoresume."""

# Timer.
timers = Timers()

# Tensorboard writer.
writer = get_tensorboard_writer(args)

# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print(message, flush=True)
print_args(args, writer)

# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)

# Random seeds for reproducability.
set_random_seed(args.seed)

return timers, writer


def get_model(model_provider_func, args):
"""Build the model."""

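As a hedged illustration of the new entry-point pattern above, initialize_megatron() now performs the one-time setup and callers fetch the shared globals through the get_* accessors instead of threading parse_args() results through every function. The extra_args_provider hook and the args_defaults values below are hypothetical examples, not taken from this commit.

from megatron.initialize import initialize_megatron
from megatron.global_vars import get_args, get_timers, get_tensorboard_writer

def my_extra_args(parser):
    # Hypothetical hook: a training script can extend the argument parser here.
    parser.add_argument('--my-flag', action='store_true')
    return parser

initialize_megatron(extra_args_provider=my_extra_args,
                    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

args = get_args()                  # parsed once, shared globally
timers = get_timers()              # the Timers group defined in global_vars.py
writer = get_tensorboard_writer()  # may be None if no Tensorboard directory is configured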