forked from twitter/the-algorithm-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_pipeline.py
626 lines (532 loc) · 21.4 KB
/
train_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
"""
import abc
from dataclasses import dataclass, field
import logging
from typing import (
Any,
cast,
Dict,
Generic,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import torch
from torch.autograd.profiler import record_function
from torch.fx.node import Node
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
ShardedModule,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.streamable import Multistreamable, Pipelineable
# Module logger shared by the helpers in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Generic parameters of the pipelines below:
#   In  - batch type produced by the dataloader, bound to torchrec's `Pipelineable`.
#   Out - the non-loss part of what the model returns, handed back by `progress`.
In = TypeVar("In", bound=Pipelineable)
Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
    """Abstract interface for pipelines that advance the model by one `progress` call at a time."""

    @abc.abstractmethod
    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """Pull whatever batches are needed from `dataloader_iter` and return the model output."""
        ...
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
    """
    Move `batch` onto `device` using the batch's own `.to()` method.

    Raises AssertionError (when assertions are enabled) if `batch` is neither a
    tensor nor a `Pipelineable`.
    """
    accepted_types = (torch.Tensor, Pipelineable)
    assert isinstance(batch, accepted_types), f"{type(batch)} must implement Pipelineable interface"
    moved = batch.to(device=device, non_blocking=non_blocking)
    return cast(In, moved)
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
    """
    Make the current CUDA stream wait on `stream`, then register `batch` as used by it.

    Does nothing when `stream` is None (the pipelines pass None when the device
    is not CUDA).
    """
    if stream is None:
        return
    consumer = torch.cuda.current_stream()
    consumer.wait_stream(stream)
    # PyTorch's caching allocator, by default, only tracks a tensor's usage on the
    # stream that allocated it. `batch` is about to be read on the current stream
    # as well, so record that use; otherwise the allocator may recycle the memory
    # (e.g. for a newly created tensor) once the originating stream is done with it.
    # See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
    assert isinstance(
        batch, (torch.Tensor, Multistreamable)
    ), f"{type(batch)} must implement Multistreamable interface"
    batch.record_stream(consumer)
class TrainPipelineBase(TrainPipeline[In, Out]):
    """
    Two-stage training pipeline.

    Stage 1 issues the host-to-device copy of the upcoming batch on
    `self._memcpy_stream`; stage 2 runs forward, backward and the optimizer step
    for the batch already staged, on the current (default) stream. No side
    stream is created when `device` is not CUDA.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
    ) -> None:
        self._model = model
        self._optimizer = optimizer
        self._device = device
        if device.type == "cuda":
            self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
        else:
            self._memcpy_stream = None
        self._cur_batch: Optional[In] = None
        self._connected = False

    def _connect(self, dataloader_iter: Iterator[In]) -> None:
        """Prime the pipeline: take the first batch and launch its device copy."""
        first = next(dataloader_iter)
        self._cur_batch = first
        with torch.cuda.stream(self._memcpy_stream):
            self._cur_batch = _to_device(first, self._device, non_blocking=True)
        self._connected = True

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """
        Process the staged batch and stage the next one.

        In training mode: zero gradients, forward, backward on the summed losses,
        and an optimizer step. In eval mode only the forward pass runs.
        Returns the model output for the processed batch.
        """
        if not self._connected:
            self._connect(dataloader_iter)

        with record_function("## next_batch ##"):
            upcoming = next(dataloader_iter)
        current = self._cur_batch
        assert current is not None

        if self._model.training:
            with record_function("## zero_grad ##"):
                self._optimizer.zero_grad()

        with record_function("## wait_for_batch ##"):
            _wait_for_batch(current, self._memcpy_stream)

        with record_function("## forward ##"):
            losses, output = self._model(current)

        if self._model.training:
            with record_function("## backward ##"):
                torch.sum(losses, dim=0).backward()

        # Launch the copy of the following batch before the optimizer step.
        self._cur_batch = upcoming
        with record_function("## copy_batch_to_gpu ##"):
            with torch.cuda.stream(self._memcpy_stream):
                self._cur_batch = _to_device(upcoming, self._device, non_blocking=True)

        if self._model.training:
            with record_function("## optimizer ##"):
                self._optimizer.step()

        return output
class Tracer(torch.fx.Tracer):
    """
    FX tracer that treats `ShardedModule`s and explicitly listed module names as leaves.
    """

    # Buffers are not proxied during tracing. Ideally they would be, but some
    # models currently mutate buffer values, which causes errors during tracing.
    # If those models can be rewritten not to do that, this override can likely go.
    proxy_buffer_attributes = False

    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
        super().__init__()
        self._leaf_modules: List[str] = [] if leaf_modules is None else leaf_modules

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Sharded modules and configured names are opaque; otherwise defer to FX."""
        if isinstance(m, ShardedModule):
            return True
        if module_qualified_name in self._leaf_modules:
            return True
        return super().is_leaf_module(m, module_qualified_name)
@dataclass
class TrainPipelineContext:
    """
    State shared between a sparse-dist pipeline and the `PipelinedForward`
    wrappers it installs on sharded modules.
    """

    # Pending input_dist results, keyed by the pipelined module's name.
    # pyre-ignore [4]
    input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
    # Module contexts from `create_context()`, keyed by the pipelined module's name.
    module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
    # Callables applied in turn to `id_score_list_features` inside PipelinedForward.
    # pyre-ignore [4]
    feature_processor_forwards: List[Any] = field(default_factory=list)
@dataclass
class ArgInfo:
    """How to fetch one argument of a pipelined sharded-module call from the batch."""

    # Attribute/key path from the batch to the argument, outermost first:
    # `batch.attr1.attr2` gives ["attr1", "attr2"].
    input_attrs: List[str]
    # Parallel to `input_attrs`: True for an `x[key]` step, False for `getattr(x, key)`;
    # `batch[attr1].attr2` gives [True, False].
    is_getitems: List[bool]
    # Keyword name of the argument in the pipelined forward() call, or None if positional.
    name: Optional[str]
class PipelinedForward:
    """
    Replacement `forward` for a sharded module whose `input_dist` was launched early.

    A call ignores its own arguments. It waits for the input distribution stored in
    the shared context under this module's name, optionally runs the context's
    feature processors, and returns `compute_and_output_dist` on the result.
    """

    def __init__(
        self,
        name: str,
        args: List[ArgInfo],
        module: ShardedModule,
        context: TrainPipelineContext,
        dist_stream: Optional[torch.cuda.streams.Stream],
    ) -> None:
        self._name = name
        self._args = args
        self._module = module
        self._context = context
        self._dist_stream = dist_stream

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        pending = self._context.input_dist_requests
        assert self._name in pending
        request = pending[self._name]
        assert isinstance(request, Awaitable)

        with record_function("## wait_sparse_data_dist ##"):
            # Run wait() under the dist stream so any work it schedules lazily
            # is queued there rather than on the current stream.
            with torch.cuda.stream(self._dist_stream):
                dist_data = request.wait()

        if self._dist_stream is not None:
            # Hand the distributed data and its module context over to the current
            # stream, and register that stream as a user of both for the allocator.
            torch.cuda.current_stream().wait_stream(self._dist_stream)
            consumer = torch.cuda.current_stream()
            assert isinstance(
                dist_data, (torch.Tensor, Multistreamable)
            ), f"{type(dist_data)} must implement Multistreamable interface"
            # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
            dist_data.record_stream(consumer)
            self._context.module_contexts[self._name].record_stream(consumer)

        fp_forwards = self._context.feature_processor_forwards
        if fp_forwards:
            with record_function("## feature_processor ##"):
                for sparse_feature in dist_data:
                    if sparse_feature.id_score_list_features is None:
                        continue
                    for fp_forward in fp_forwards:
                        sparse_feature.id_score_list_features = fp_forward(
                            sparse_feature.id_score_list_features
                        )

        return self._module.compute_and_output_dist(
            self._context.module_contexts[self._name], dist_data
        )

    @property
    def name(self) -> str:
        """Name of the wrapped module; key into the shared context dicts."""
        return self._name

    @property
    def args(self) -> List[ArgInfo]:
        """How to rebuild each input_dist argument from a batch."""
        return self._args
def _start_data_dist(
    pipelined_modules: List[ShardedModule],
    batch: In,
    context: TrainPipelineContext,
) -> None:
    """
    Launch `input_dist` on `batch` for every pipelined sharded module.

    Clears `context`, then for each module rebuilds the arguments its forward was
    traced with (following the `ArgInfo` attribute/key paths from `batch`). It
    creates a module context, stores it, and stores the awaitable returned by
    `module.input_dist`. Finally every stored request is replaced with the result
    of its first `wait()` (per the original code: the tensor-splits stage).
    """
    context.input_dist_requests.clear()
    context.module_contexts.clear()
    for module in pipelined_modules:
        forward = module.forward
        assert isinstance(forward, PipelinedForward)

        args = []
        kwargs = {}
        for arg_info in forward.args:
            if arg_info.input_attrs:
                # Walk the recorded path: is_getitem True -> arg[attr], False -> getattr.
                arg = batch
                for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
                    arg = arg[attr] if is_getitem else getattr(arg, attr)
            else:
                # No path was recorded while tracing: the traced argument was a
                # literal None or the graph input itself. None is forwarded in both
                # cases. NOTE(review): for the graph-input case this looks wrong — confirm.
                arg = None
            # Keyword arguments always keep their name. Previously a keyword
            # argument without a path was appended positionally, which could shift
            # or collide with the positional parameters.
            if arg_info.name:
                kwargs[arg_info.name] = arg
            else:
                args.append(arg)

        # Start input distribution.
        module_ctx = module.create_context()
        context.module_contexts[forward.name] = module_ctx
        context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)

    # Call wait on the first awaitable in the input dist for the tensor splits.
    # Reassigning existing keys does not resize the dict, so this is safe while iterating.
    for key, awaitable in context.input_dist_requests.items():
        context.input_dist_requests[key] = awaitable.wait()
def _get_node_args_helper(
    # pyre-ignore
    arguments,
    num_found: int,
    feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
    """
    Trace each argument of a node back towards the graph input, one `ArgInfo` per argument.

    Each FX-node argument is walked backwards through `getattr` / `getitem` calls,
    prepending every step to its `ArgInfo`. Nodes in `feature_processor_arguments`
    are looked through to their first input. An argument counts towards the returned
    `num_found` (added to the incoming value) when it is None or when the walk
    reaches a placeholder. Any other node or non-Node value stops the walk
    without counting.
    """
    arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
    for node_arg, info in zip(arguments, arg_info_list):
        if node_arg is None:
            num_found += 1
            continue
        current = node_arg
        while isinstance(current, torch.fx.Node):
            if current.op == "placeholder":
                num_found += 1
                break
            if feature_processor_arguments is not None and current in feature_processor_arguments:
                # skip this fp node
                current = current.args[0]
                continue
            target = current.target
            if (
                current.op == "call_function"
                and target.__module__ == "builtins"
                # pyre-ignore[16]
                and target.__name__ == "getattr"
            ):
                via_getitem = False
            elif (
                current.op == "call_function"
                and target.__module__ == "_operator"
                # pyre-ignore[16]
                and target.__name__ == "getitem"
            ):
                via_getitem = True
            else:
                break
            info.input_attrs.insert(0, current.args[1])
            info.is_getitems.insert(0, via_getitem)
            current = current.args[0]
    return arg_info_list, num_found
def _get_node_args(
    node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
    """
    Build the `ArgInfo` list for a call node: positional args first, then kwargs.

    Keyword entries get their keyword as `name`. Also returns how many of the
    node's arguments `_get_node_args_helper` managed to trace.
    """
    positional, found = _get_node_args_helper(node.args, 0, feature_processor_nodes)
    # NOTE(review): feature-processor nodes are only looked through for positional
    # arguments; keyword arguments are traced without them. Confirm this is intended.
    keyword, found = _get_node_args_helper(node.kwargs.values(), found)
    for kw_name, info in zip(node.kwargs, keyword):
        info.name = kw_name
    return positional + keyword, found
def _get_unsharded_module_names_helper(
    model: torch.nn.Module,
    path: str,
    unsharded_module_names: Set[str],
) -> bool:
    """
    Recursively collect modules that sit next to sharded ones.

    Returns True when some direct child of `model` is, or (recursively) contains,
    a `ShardedModule`. In that case every other direct child is added to
    `unsharded_module_names` as `path + child_name`.
    """
    children_with_sharding = set()
    for child_name, child in model.named_children():
        if isinstance(child, ShardedModule) or _get_unsharded_module_names_helper(
            child, f"{path}{child_name}.", unsharded_module_names
        ):
            children_with_sharding.add(child_name)
    if not children_with_sharding:
        return False
    for child_name, _ in model.named_children():
        if child_name not in children_with_sharding:
            unsharded_module_names.add(path + child_name)
    return True
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
    """
    Return the dotted names of modules that contain no sharded submodules but are
    siblings of modules that do (order unspecified).
    """
    names: Set[str] = set()
    _get_unsharded_module_names_helper(model, "", names)
    return list(names)
def _rewrite_model(  # noqa C901
    model: torch.nn.Module,
    context: TrainPipelineContext,
    dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
    """
    Install `PipelinedForward` on sharded modules fed directly by the batch; return them.

    The model (unwrapped from `DistributedModelParallel` if needed) is FX-traced,
    with unsharded sibling modules treated as leaves. A sharded-module call is
    pipelined only when it has at least one argument and every argument traces
    back to the graph input through getattr/getitem (or feature-processor)
    nodes, or is None.
    """
    # Get underlying nn.Module
    if isinstance(model, DistributedModelParallel):
        model = model.module

    # Collect sharded modules and feature-processor modules by qualified name.
    sharded_modules = {}
    fp_modules = {}
    for name, m in model.named_modules():
        if isinstance(m, ShardedModule):
            sharded_modules[name] = m
        if isinstance(m, BaseGroupedFeatureProcessor):
            fp_modules[name] = m

    tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
    graph = tracer.trace(model)

    feature_processor_nodes = [
        node for node in graph.nodes if node.op == "call_module" and node.target in fp_modules
    ]

    ret = []
    for node in graph.nodes:
        if node.op != "call_module" or node.target not in sharded_modules:
            continue
        total_num_args = len(node.args) + len(node.kwargs)
        if total_num_args == 0:
            continue
        arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
        if num_found != total_num_args:
            continue
        # Fixed: the old message had a stray doubled quote ("'...''").
        logger.info("Module '%s' will be pipelined", node.target)
        child = sharded_modules[node.target]
        child.forward = PipelinedForward(
            node.target,
            arg_info_list,
            child,
            context,
            dist_stream,
        )
        ret.append(child)
    return ret
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
    """
    This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
    forward and backward. This helps hide the all2all latency while preserving the
    training forward / backward ordering.

    stage 3: forward, backward - uses default CUDA stream
    stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
    stage 1: device transfer - uses memcpy CUDA stream

    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
    To be considered a top-level module, a module can only depend on 'getattr' calls on
    input.

    Input model must be symbolically traceable with the exception of `ShardedModule` and
    `DistributedDataParallel` modules.

    Additional options:
      * `enable_amp`: run the forward pass under `torch.autocast` with bfloat16.
      * `enable_grad_scaling`: use a `GradScaler` (only active when AMP is enabled).
      * `grad_accum`: accumulate gradients over this many training `progress` calls
        per optimizer step.
    """

    # id(model) -> id(pipeline) whose context/stream the model's PipelinedForward
    # wrappers are currently bound to. Class-level, hence shared by all instances;
    # used to re-sync when several pipelines drive the same model.
    synced_pipeline_id: Dict[int, int] = {}

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        enable_amp: bool = False,
        enable_grad_scaling: bool = True,
        grad_accum: Optional[int] = None,
    ) -> None:
        self._model = model
        self._optimizer = optimizer
        self._device = device
        self._enable_amp = enable_amp
        # NOTE: Pending upstream feedback, but two flags because we can run AMP without CUDA but cannot scale gradients without CUDA.
        # Background on gradient/loss scaling
        # https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
        # https://pytorch.org/docs/stable/amp.html#gradient-scaling
        self._enable_grad_scaling = enable_grad_scaling
        self._grad_scaler = torch.cuda.amp.GradScaler(
            enabled=self._enable_amp and self._enable_grad_scaling
        )
        logger.info("Amp is enabled: %s", self._enable_amp)

        # use two data streams to support two concurrent batches
        if device.type == "cuda":
            self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
            self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
        else:
            if self._enable_amp:
                logger.warning("Amp is enabled, but no CUDA available")
            self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
            self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None

        self._batch_i: Optional[In] = None
        self._batch_ip1: Optional[In] = None
        self._batch_ip2: Optional[In] = None
        self._connected = False
        self._context = TrainPipelineContext()
        self._pipelined_modules: List[ShardedModule] = []

        # Number of training-mode `progress` calls so far; drives the accumulation schedule.
        self._progress_calls = 0
        if grad_accum is not None:
            assert isinstance(grad_accum, int) and grad_accum > 0
        self._grad_accum = grad_accum

    def _connect(self, dataloader_iter: Iterator[In]) -> None:
        """Prime the pipeline with two batches and start input_dist for the first."""
        # batch 1
        with torch.cuda.stream(self._memcpy_stream):
            batch_i = next(dataloader_iter)
            self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
            # Try to pipeline input data dist.
            self._pipelined_modules = _rewrite_model(
                self._model, self._context, self._data_dist_stream
            )

        with torch.cuda.stream(self._data_dist_stream):
            _wait_for_batch(batch_i, self._memcpy_stream)
            _start_data_dist(self._pipelined_modules, batch_i, self._context)

        # batch 2
        with torch.cuda.stream(self._memcpy_stream):
            batch_ip1 = next(dataloader_iter)
            self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)
        self._connected = True
        self.__class__.synced_pipeline_id[id(self._model)] = id(self)

    @staticmethod
    def _accumulation_schedule(
        progress_calls: int, grad_accum: Optional[int]
    ) -> Tuple[bool, bool]:
        """
        Return `(should_reset, should_step)` for the `progress_calls`-th training call.

        Without accumulation every call both zeroes gradients and steps. With
        accumulation, calls form windows of `grad_accum`. Gradients are zeroed on
        the first call of a window and the optimizer steps on its last call, so
        each step sees exactly `grad_accum` backward passes. The previous schedule
        zeroed one call before stepping, so only 2 micro-batches were ever
        accumulated, and with `grad_accum == 1` the first call did not step.
        """
        if grad_accum is None:
            return True, True
        position = progress_calls % grad_accum
        return position == 0, position == grad_accum - 1

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """
        NOTE: This method has been updated to perform gradient accumulation.
        If `_grad_accum` is set, losses are divided by `_grad_accum` before
        `backward`. Gradients are zeroed only on the first training call of each
        window of `_grad_accum` calls. The optimizer steps (and the grad scaler
        updates) only on the last call of that window. When `_grad_accum` is
        unset, every training call zeroes, back-propagates and steps.
        """
        should_reset_optimizer, should_step_optimizer = self._accumulation_schedule(
            self._progress_calls, self._grad_accum
        )

        if not self._connected:
            self._connect(dataloader_iter)
        elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
            # Another pipeline last bound this model's PipelinedForwards; re-bind them.
            self._sync_pipeline()
            self.__class__.synced_pipeline_id[id(self._model)] = id(self)

        if self._model.training and should_reset_optimizer:
            with record_function("## zero_grad ##"):
                self._optimizer.zero_grad()

        with record_function("## copy_batch_to_gpu ##"):
            with torch.cuda.stream(self._memcpy_stream):
                batch_ip2 = next(dataloader_iter)
                self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)
        batch_i = cast(In, self._batch_i)
        batch_ip1 = cast(In, self._batch_ip1)

        with record_function("## wait_for_batch ##"):
            _wait_for_batch(batch_i, self._data_dist_stream)

        # Forward
        with record_function("## forward ##"):
            # if using multiple streams (ie. CUDA), create an event in default stream
            # before starting forward pass
            if self._data_dist_stream:
                event = torch.cuda.current_stream().record_event()
            if self._enable_amp:
                # conditionally apply the model to the batch in the autocast context
                # it appears that `enabled=self._enable_amp` should handle this,
                # but it does not.
                with torch.autocast(
                    device_type=self._device.type,
                    dtype=torch.bfloat16,
                    enabled=self._enable_amp,
                ):
                    (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
            else:
                (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

        # Data Distribution
        with record_function("## sparse_data_dist ##"):
            with torch.cuda.stream(self._data_dist_stream):
                _wait_for_batch(batch_ip1, self._memcpy_stream)
                # Ensure event in default stream has been called before
                # starting data dist
                if self._data_dist_stream:
                    # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
                    self._data_dist_stream.wait_event(event)
                _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

        if self._model.training:
            # Backward
            with record_function("## backward ##"):
                # Loss is normalized by the number of accumulation steps.
                # The reported loss in `output['loss']` remains the unnormalized value.
                if self._grad_accum is not None:
                    losses = losses / self._grad_accum
                self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()

            if should_step_optimizer:
                # Update
                with record_function("## optimizer ##"):
                    self._grad_scaler.step(self._optimizer)
                    self._grad_scaler.update()

        self._batch_i = batch_ip1
        self._batch_ip1 = batch_ip2

        if self._model.training:
            self._progress_calls += 1

        return output

    def _sync_pipeline(self) -> None:
        """
        Syncs `PipelinedForward` for sharded modules with context and dist stream of the
        current train pipeline. Used when switching between train pipelines for the same
        model.
        """
        for module in self._pipelined_modules:
            module.forward._context = self._context
            module.forward._dist_stream = self._data_dist_stream