forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmicrobatches.py
144 lines (115 loc) · 6.18 KB
/
microbatches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
def build_num_microbatches_calculator(args):
# Constant num micro-batches.
if args.rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size,
args.data_parallel_size)
if args.rank == 0:
print('setting number of micro-batches to constant {}'.format(
num_microbatches_calculator.get()), flush=True)
else:
assert len(args.rampup_batch_size) == 3, 'expected the following ' \
'format: --rampup-batch-size <start batch size> ' \
'<batch size incerement> <ramp-up samples>'
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
print('will use batch size rampup starting from global batch '
'size {} to global batch size {} with batch size increments '
'{} over {} samples.'.format(start_batch_size,
args.global_batch_size,
batch_size_increment,
ramup_samples), flush=True)
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
start_batch_size, batch_size_increment, ramup_samples,
args.global_batch_size, args.micro_batch_size,
args.data_parallel_size)
return num_microbatches_calculator
class NumMicroBatchesCalculator(ABC):
def __init__(self):
self.num_micro_batches = None
self.current_global_batch_size = None
def get(self):
return self.num_micro_batches
def get_current_global_batch_size(self):
return self.current_global_batch_size
@abstractmethod
def update(self, consumed_samples, consistency_check):
pass
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
micro_batch_times_data_parallel = micro_batch_size * \
data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, \
'global batch size ({}) is not divisible by micro batch size ({})' \
' times data parallel size ({})'.format(global_batch_size,
micro_batch_size,
data_parallel_size)
self.num_micro_batches = global_batch_size // \
micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
def update(self, consumed_samples, consistency_check):
pass
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
global_batch_size, micro_batch_size, data_parallel_size):
"""Batch size ramp up.
Over
steps = (global-batch-size - start-batch-size) / batch_size_increment
increment batch size from start-batch-size to global-batch-size using
rampup-samples / steps
samples.
Arguments:
start_batch_size: global batch size to start with
batch_size_increment: global batch size increments
ramup_samples: number of samples to use ramp up global
batch size from `start_batch_size` to `global_batch_size`
global_batch_size: global batch size post rampup
micro_batch_size: micro batch size
data_parallel_size: data parallel size.
"""
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
self.data_parallel_size
assert self.micro_batch_times_data_parallel_size > 0
assert start_batch_size > 0
self.start_batch_size = start_batch_size
assert global_batch_size > 0
self.global_batch_size = global_batch_size
diff_batch_size = self.global_batch_size - self.start_batch_size
assert diff_batch_size >= 0
assert batch_size_increment > 0
self.batch_size_increment = batch_size_increment
assert diff_batch_size % batch_size_increment == 0, 'expected ' \
'global batch size interval ({}) to be divisible by global batch ' \
'size increment ({})'.format(diff_batch_size, batch_size_increment)
num_increments = diff_batch_size // self.batch_size_increment
self.ramup_samples = ramup_samples
assert self.ramup_samples >= 0
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0, False)
def update(self, consumed_samples, consistency_check):
if consumed_samples > self.ramup_samples:
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert self.current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
'batch size ({}) is not divisible by micro-batch-size ({}) times' \
'data parallel size ({})'.format(self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size)
self.num_micro_batches = self.current_global_batch_size // \
self.micro_batch_times_data_parallel_size