Skip to content

Commit

Permalink
[FSDP] Zero 3 Optimization Support (facebookresearch#4903)
Browse files Browse the repository at this point in the history
* zero3 init commit

* minor cleanup:

* handle mpeval

* remove fairscale dependence

* fsdp avail

* update reqs

* better reqs

* autoformat

* autoformat
  • Loading branch information
klshuster authored Dec 5, 2022
1 parent 25df082 commit 96aa1bb
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 48 deletions.
12 changes: 6 additions & 6 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -219,26 +219,26 @@ commands:
- setupcuda
- fixgit
- restore_cache:
key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
- setup
- installdeps
- << parameters.more_installs >>
- save_cache:
key: deps-20221130-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20221202-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
- findtests:
marker: << parameters.marker >>
- restore_cache:
key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
- run:
name: Run tests
no_output_timeout: 60m
command: |
coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml
- save_cache:
key: data-20221130-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20221202-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
paths:
- "~/ParlAI/data"
- codecov
Expand All @@ -255,12 +255,12 @@ commands:
- checkout
- fixgit
- restore_cache:
key: deps-20221130-bw-{{ checksum "requirements.txt" }}
key: deps-20221202-bw-{{ checksum "requirements.txt" }}
- setup
- installdeps
- installtorchgpu
- save_cache:
key: deps-20221130-bw-{{ checksum "requirements.txt" }}
key: deps-20221202-bw-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
Expand Down
8 changes: 5 additions & 3 deletions parlai/agents/hugging_face/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch, TorchAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel
from parlai.utils.fsdp import is_fsdp
from parlai.utils.fsdp import is_fsdp, delay_halving


def check_hf_version(v: Tuple[int, int]) -> bool:
Expand All @@ -41,7 +41,9 @@ def check_hf_version(v: Tuple[int, int]) -> bool:
def build_t5(opt: Opt) -> T5ForConditionalGeneration:
if not check_hf_version(HF_VERSION):
raise RuntimeError('Must use transformers package >= 4.3 to use t5')
torch_dtype = torch.float16 if opt['fp16'] else torch.float32
torch_dtype = (
torch.float16 if (opt['fp16'] and not delay_halving(opt)) else torch.float32
)
try:
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'],
Expand Down Expand Up @@ -369,7 +371,7 @@ def output(self, tensor):
"""
# Taken directly from HuggingFace
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
# See https://github.com/tensorflow/mesh/blob/fa19d69/mesh_tensorflow/transformer/transformer.py#L586
tensor = tensor * (self.t5.model_dim**-0.5)
lm_logits = self.t5.lm_head(tensor)
return lm_logits
Expand Down
4 changes: 2 additions & 2 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,11 +774,11 @@ def add_distributed_training_args(self):
)
grp.add_argument(
'--ddp-backend',
# TODO: add in zero3. https://github.com/facebookresearch/ParlAI/issues/3753
choices=['ddp', 'zero2'],
choices=['ddp', 'zero2', 'zero3'],
default='ddp',
help=(
'Distributed backend. Zero2 can be faster but is more experimental. '
'Zero3 significantly reduces memory pressure. '
'DDP is the most tested.'
),
)
Expand Down
13 changes: 11 additions & 2 deletions parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@
from parlai.utils.distributed import is_distributed
from parlai.utils.misc import AttrDict, warn_once
from parlai.utils.io import PathManager
from parlai.utils.fsdp import should_sync_gradnorm, is_fsdp, DEFAULT_DDP_BACKEND
from parlai.utils.fsdp import (
should_sync_gradnorm,
is_fsdp,
DEFAULT_DDP_BACKEND,
FSDP_AVAILABLE,
get_state_dict,
)
from parlai.utils.fp16 import (
SafeFP16Optimizer,
MemoryEfficientFP16Optimizer,
Expand Down Expand Up @@ -1981,8 +1987,11 @@ def state_dict(self):
if hasattr(self.model, 'module') and not is_fsdp(self.model):
# did we wrap in a DistributedDataParallel or DataParallel
states['model'] = self.model.module.state_dict()
elif is_fsdp(self.model) and FSDP_AVAILABLE:
# FSDP Model; use fancy saving
states['model'] = get_state_dict(self.model)
else:
# regular model or FSDP
# regular model
states['model'] = self.model.state_dict()

if hasattr(self, 'optimizer'):
Expand Down
8 changes: 5 additions & 3 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
from parlai.utils.fp16 import FP16SafeCrossEntropy
import parlai.utils.fsdp as fsdp_utils
from parlai.utils.torch import (
Expand Down Expand Up @@ -516,8 +516,10 @@ def __init__(self, opt: Opt, shared=None):
else:
# this is not a shared instance of this class, so do full init
self.criterion = self.build_criterion()

self.model = self.build_model()
with fsdp_utils.maybe_fsdp_wrap(opt):
self.model = fsdp_utils.fsdp_wrap(self.build_model())
self.model = fsdp_utils.fsdp_wrap(self.model)
if self.fp16 and not fsdp_utils.delay_halving(opt):
self.model = self.model.half()

Expand Down Expand Up @@ -2054,7 +2056,7 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection

class FactualNucleusSampling(NucleusSampling):
"""
Factual Nucleus Sampling
Factual Nucleus Sampling.
See https://arxiv.org/pdf/2206.04624.pdf for more information
"""
Expand Down
8 changes: 5 additions & 3 deletions parlai/scripts/distributed_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
-m seq2seq -t convai2 --dict-file /path/to/dict-file
```
"""

import parlai.scripts.eval_model as eval_model
from parlai.core.script import ParlaiScript
import parlai.scripts.eval_model as eval_model
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils


def setup_args():
Expand All @@ -51,7 +51,9 @@ def setup_args(cls):

def run(self):
with distributed_utils.slurm_distributed_context(self.opt) as opt:
return eval_model.eval_model(opt)
self.evaluator = fsdp_utils.JoinableEvaluator(opt)
with fsdp_utils.fsdp_join(self.evaluator):
return self.evaluator.eval_model()


if __name__ == '__main__':
Expand Down
6 changes: 4 additions & 2 deletions parlai/scripts/distributed_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import parlai.scripts.train_model as single_train
from parlai.core.script import ParlaiScript
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils


def setup_args():
Expand All @@ -51,8 +52,9 @@ def setup_args(cls):

def run(self):
with distributed_utils.slurm_distributed_context(self.opt) as opt:
self.train_loop = single_train.TrainLoop(opt)
return self.train_loop.train()
self.train_loop = fsdp_utils.JoinableTrainLoop(opt)
with fsdp_utils.fsdp_join(self.train_loop):
return self.train_loop.train()


if __name__ == '__main__':
Expand Down
17 changes: 15 additions & 2 deletions parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
aggregate_unnamed_reports,
Metric,
)
from parlai.core.opt import Opt
from parlai.core.worlds import create_task
from parlai.utils.misc import TimeLogger, nice_report
from parlai.utils.world_logging import WorldLogger
Expand Down Expand Up @@ -77,7 +78,10 @@ def setup_args(parser=None):
'-auc',
type=int,
default=-1,
help='a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric',
help='a positive number indicates to calculate the area under the '
'roc curve and it also determines how many decimal digits of the '
'predictions to keep (higher numbers->more precise); also used '
'to determine whether or not to calculate the AUC metric',
)
parser.add_argument(
'--area-under-curve-class',
Expand Down Expand Up @@ -291,14 +295,23 @@ def eval_model(opt):
return report


class Evaluator:
    """
    Object wrapper around the module-level ``eval_model`` function.

    The options are captured at construction time so that the evaluation can be
    triggered later through a method call on the instance.
    """

    def __init__(self, opt: Opt):
        """
        Store the options that will be handed to ``eval_model``.

        :param opt:
            options to evaluate with.
        """
        self.opt = opt

    def eval_model(self):
        """
        Evaluate using the stored options.

        :return:
            whatever the module-level ``eval_model`` returns for ``self.opt``.
        """
        # The bare name resolves to the module-level function, not this method:
        # class attributes are not part of a method's enclosing scope.
        report = eval_model(self.opt)
        return report


@register_script('eval_model', aliases=['em', 'eval'])
class EvalModel(ParlaiScript):
@classmethod
def setup_args(cls):
return setup_args()

def run(self):
return eval_model(self.opt)
self.evaluator = Evaluator(self.opt)
return self.evaluator.eval_model()


if __name__ == '__main__':
Expand Down
7 changes: 5 additions & 2 deletions parlai/scripts/multiprocessing_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
import torch
import os
import signal
import parlai.utils.distributed as distributed_utils
import parlai.scripts.eval_model as eval_model
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils
from parlai.core.script import ParlaiScript, register_script


Expand All @@ -43,7 +44,9 @@ def multiprocess_eval(
rank, opt, rank_offset, gpu, init_method=init_method
) as opt:
opt['multiprocessing'] = True
return eval_model.eval_model(opt)
evaluator = fsdp_utils.JoinableEvaluator(opt)
with fsdp_utils.fsdp_join(evaluator):
return evaluator.eval_model()


def launch_and_eval(opt, port):
Expand Down
5 changes: 4 additions & 1 deletion parlai/scripts/multiprocessing_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import traceback
import parlai.scripts.train_model as single_train
import parlai.utils.distributed as distributed_utils
import parlai.utils.fsdp as fsdp_utils
from parlai.core.script import ParlaiScript, register_script


Expand All @@ -41,8 +42,10 @@ def multiprocess_train(
) as opt:
# Run the actual training
opt['multiprocessing'] = True
loop = fsdp_utils.JoinableTrainLoop(opt)
try:
return single_train.TrainLoop(opt).train()
with fsdp_utils.fsdp_join(loop):
return loop.train()
except Exception:
import parlai.utils.logging as logging

Expand Down
Loading

0 comments on commit 96aa1bb

Please sign in to comment.