diff --git a/finetune/adapter.py b/finetune/adapter.py
index d07a98c67e..ac7e958711 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple, Dict, List
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py
index 7eff9951a8..704519c406 100644
--- a/finetune/adapter_v2.py
+++ b/finetune/adapter_v2.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, List, Dict, Tuple
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -65,10 +63,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
diff --git a/finetune/full.py b/finetune/full.py
index b546a65253..7ed29d7fa8 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple, Dict, List
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
diff --git a/generate/adapter.py b/generate/adapter.py
index 7b36d41598..2e123bd584 100644
--- a/generate/adapter.py
+++ b/generate/adapter.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -61,8 +59,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()
 
diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py
index 15b68266e2..214e7b9a85 100644
--- a/generate/adapter_v2.py
+++ b/generate/adapter_v2.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -63,8 +61,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()
 
diff --git a/generate/base.py b/generate/base.py
index c7e076914f..0fb2f87609 100644
--- a/generate/base.py
+++ b/generate/base.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Optional, Literal
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -125,8 +123,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()
 
diff --git a/generate/full.py b/generate/full.py
index e27185512c..53cf577d5f 100644
--- a/generate/full.py
+++ b/generate/full.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -61,8 +59,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()
 
diff --git a/generate/lora.py b/generate/lora.py
index bf2e063231..9fbb86709a 100644
--- a/generate/lora.py
+++ b/generate/lora.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -66,8 +64,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()
 
diff --git a/pretrain/openwebtext.py b/pretrain/openwebtext.py
index d31d6ae2ba..6a2f90fa2e 100644
--- a/pretrain/openwebtext.py
+++ b/pretrain/openwebtext.py
@@ -1,7 +1,6 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Tuple, Optional, Union
 
@@ -9,7 +8,6 @@
 import numpy as np
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
diff --git a/pretrain/openwebtext_trainer.py b/pretrain/openwebtext_trainer.py
index d8255d9954..61577c1c1f 100644
--- a/pretrain/openwebtext_trainer.py
+++ b/pretrain/openwebtext_trainer.py
@@ -1,7 +1,6 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Any
 
@@ -11,7 +10,6 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -108,10 +106,9 @@ def main(devices: int = 1, precision: Optional[str] = None, tpu: bool = False) -> None:
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             # the argument is not available in the Trainer strategy, but it's the default anyways
             # state_dict_type="full",
             limit_all_gathers=True,
diff --git a/pretrain/redpajama.py b/pretrain/redpajama.py
index e82c17b3ee..c7a70b46cc 100644
--- a/pretrain/redpajama.py
+++ b/pretrain/redpajama.py
@@ -2,14 +2,12 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Tuple, Optional, Union
 
 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import DataLoader
 
 # support running without installing as a package
@@ -77,10 +75,9 @@ def setup(
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
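Every hunk above applies the same change: Fabric's FSDPStrategy now receives the transformer layer class directly instead of a partial over torch.distributed.fsdp.wrap.transformer_auto_wrap_policy, and activation checkpointing moves from activation_checkpointing=Block to activation_checkpointing_policy={Block}. A minimal standalone sketch of the resulting configuration follows; it assumes a Lightning version that accepts class sets for both arguments (as the diff implies), the import path for Block is illustrative, and the device count and precision are placeholder values, not ones taken from these scripts.

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Illustrative import: use whichever module defines Block (the transformer layer) in this repo.
from lit_parrot.model import Block

# Passing the layer class directly replaces the old
# partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) construction.
strategy = FSDPStrategy(
    auto_wrap_policy={Block},                 # shard the model at Block granularity
    activation_checkpointing_policy={Block},  # recompute Block activations in the backward pass
    state_dict_type="full",                   # gather a single consolidated checkpoint on save
    limit_all_gathers=True,
    cpu_offload=False,
)

fabric = L.Fabric(devices=8, precision="bf16-mixed", strategy=strategy)  # placeholder values
fabric.launch()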