Commit
Use the new FSDP policy API (Lightning-AI#270)
carmocca authored Jul 17, 2023
1 parent 19bc3a5 commit 6eb745a
Showing 11 changed files with 17 additions and 50 deletions.
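
The diff replaces the hand-built torch FSDP wrap policy with Lightning's newer policy arguments, which accept a set of layer classes directly; activation_checkpointing is likewise renamed to activation_checkpointing_policy. A rough before/after sketch of the training-script configuration follows; the import path for Block (the repository's transformer block class) is an assumption for illustration, since each script already imports its own Block:

    from functools import partial

    from lightning.fabric.strategies import FSDPStrategy
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    from lit_gpt.model import Block  # hypothetical import path; each script imports its own Block

    # Old API: build the torch auto-wrap policy by hand and pass a layer class for checkpointing
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    strategy = FSDPStrategy(
        auto_wrap_policy=auto_wrap_policy,
        activation_checkpointing=Block,
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )

    # New API: pass the layer classes directly as a set
    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )
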
7 changes: 2 additions & 5 deletions finetune/adapter.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple, Dict, List

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
7 changes: 2 additions & 5 deletions finetune/adapter_v2.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, List, Dict, Tuple

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -65,10 +63,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
7 changes: 2 additions & 5 deletions finetune/full.py
@@ -1,14 +1,12 @@
 import os
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple, Dict, List

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         fabric_devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
5 changes: 1 addition & 4 deletions generate/adapter.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -61,8 +59,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()

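For the generation scripts only the wrapping policy is needed, so the strategy collapses to a one-liner, as in the hunk above. A minimal sketch of how the updated strategy might be wired into Fabric; the import path and the GPT/from_name names are assumptions for illustration:

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    from lit_gpt.model import GPT, Block  # hypothetical import path and class names

    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
    fabric = L.Fabric(devices=2, precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = GPT.from_name("pythia-70m")  # hypothetical config; the scripts load a real checkpoint
    model = fabric.setup(model)          # FSDP wraps each Block per the auto_wrap_policy set
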
5 changes: 1 addition & 4 deletions generate/adapter_v2.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -63,8 +61,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()

5 changes: 1 addition & 4 deletions generate/base.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Optional, Literal

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -125,8 +123,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()

5 changes: 1 addition & 4 deletions generate/full.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -61,8 +59,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()

5 changes: 1 addition & 4 deletions generate/lora.py
@@ -2,14 +2,12 @@
 import sys
 import time
 import warnings
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -66,8 +64,7 @@ def main(
         precision: Indicates the Fabric precision setting to use.
     """
     if strategy == "fsdp":
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, cpu_offload=False)
+        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
     fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
     fabric.launch()

7 changes: 2 additions & 5 deletions pretrain/openwebtext.py
@@ -1,15 +1,13 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Tuple, Optional, Union

 import lightning as L
 import numpy as np
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -60,10 +58,9 @@ def setup(
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
7 changes: 2 additions & 5 deletions pretrain/openwebtext_trainer.py
@@ -1,7 +1,6 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Optional, Any

@@ -11,7 +10,6 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -108,10 +106,9 @@ def main(devices: int = 1, precision: Optional[str] = None, tpu: bool = False) -
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             # the argument is not available in the Trainer strategy, but it's the default anyways
             # state_dict_type="full",
             limit_all_gathers=True,
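The Trainer-based pretraining script takes the same two policy arguments, but through lightning.pytorch.strategies.FSDPStrategy, which (per the inline comment in the hunk above) does not expose state_dict_type. A hedged sketch of passing the updated strategy to a Trainer; the Block import path and the LightningModule wrapper are assumptions for illustration:

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    from lit_gpt.model import Block  # hypothetical import path for the transformer block class

    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        limit_all_gathers=True,
        cpu_offload=False,
    )
    trainer = L.Trainer(devices=4, precision="bf16-mixed", strategy=strategy)
    # trainer.fit(module)  # `module` would be the script's LightningModule wrapper around the model
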
7 changes: 2 additions & 5 deletions pretrain/redpajama.py
@@ -2,14 +2,12 @@
 import math
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Tuple, Optional, Union

 import lightning as L
 import torch
 from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import DataLoader

 # support running without installing as a package
@@ -77,10 +75,9 @@ def setup(
         devices = "auto"
         strategy = XLAStrategy(sync_module_states=False)
     else:
-        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
         strategy = FSDPStrategy(
-            auto_wrap_policy=auto_wrap_policy,
-            activation_checkpointing=Block,
+            auto_wrap_policy={Block},
+            activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
             cpu_offload=False,
