2D parallelism (TP + DP)

Using FSDP and TP together is quite simple code-wise when starting from our chapter 6 TP script.

Disclaimer: this only works with PyTorch's newer FSDP 2 API, which is still in alpha.

What does using these two together mean exactly? Let's walk through an example with 6 GPUs, 2-way FSDP and 3-way TP:

[Figure: sharding a model across 6 GPUs with 3-way TP and 2-way FSDP]

When we first start out, every GPU holds the full model. Then we shard the model into 3 pieces along our TP dimension. The 3 shards in the graphic above are red+orange, yellow+green, and blue+purple. Note that GPU 0 and GPU 3 hold the exact same shard! This is because they have the same tensor parallel rank but different data parallel ranks, so at this point the model is duplicated across the data parallel dimension.
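To make that layout concrete, here is a small pure-Python sketch of how the 6 ranks map onto a (dp, tp) = (2, 3) mesh. It assumes ranks are laid out row-major over the mesh shape (rank = dp_rank * tp_size + tp_rank), which is how init_device_mesh arranges them by default; the helper names are hypothetical and nothing here touches a GPU.

```python
# Hypothetical helpers: map each global rank onto a (dp, tp) mesh, assuming
# ranks are laid out row-major, i.e. rank = dp_rank * tp_size + tp_rank.


def mesh_coords(rank: int, tp_size: int) -> tuple[int, int]:
    """Return (dp_rank, tp_rank) for a global rank."""
    return divmod(rank, tp_size)


def ranks_sharing_tp_shard(world_size: int, tp_size: int) -> dict[int, list[int]]:
    """Group ranks by TP rank. Ranks in the same group hold the same TP shard
    (before FSDP is applied) and differ only in their DP rank."""
    groups: dict[int, list[int]] = {}
    for rank in range(world_size):
        _, tp_rank = mesh_coords(rank, tp_size)
        groups.setdefault(tp_rank, []).append(rank)
    return groups


if __name__ == "__main__":
    # The example above: 6 GPUs, 2-way FSDP (dp) and 3-way TP.
    for rank in range(6):
        dp_rank, tp_rank = mesh_coords(rank, tp_size=3)
        print(f"GPU {rank}: dp_rank={dp_rank}, tp_rank={tp_rank}")
    print(ranks_sharing_tp_shard(world_size=6, tp_size=3))
```

The last line prints {0: [0, 3], 1: [1, 4], 2: [2, 5]}: GPU 0 and GPU 3 share TP rank 0, which is why they start out with the same shard.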

When we apply FSDP in the next step, we split those duplicated shards! So shard red+orange (duplicated on GPU 0 & 3) is split into two pieces: shard red and shard orange.

By the end, each of the 6 GPUs holds a distinct shard of the model.
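The two-step split can be mimicked with plain Python lists. This sketch uses the six colored blocks from the figure as stand-ins for slices of the model; the helper names are hypothetical, and it assumes a row-major rank layout (rank = dp_rank * tp_size + tp_rank) with DP rank 0 taking the first piece of each TP shard. Real FSDP 2 shards each parameter tensor along dim 0 rather than whole colored blocks, so treat this as an illustration of the bookkeeping only.

```python
def split(items: list, n: int) -> list[list]:
    """Split items into n equal, contiguous chunks."""
    assert len(items) % n == 0, "items must divide evenly into n chunks"
    size = len(items) // n
    return [items[i * size:(i + 1) * size] for i in range(n)]


def shard_2d(blocks: list[str], dp_size: int, tp_size: int) -> dict[int, list[str]]:
    """Shard blocks across dp_size * tp_size GPUs: first by TP, then each
    TP shard (duplicated across DP ranks) is split again by FSDP."""
    per_gpu: dict[int, list[str]] = {}
    for tp_rank, tp_shard in enumerate(split(blocks, tp_size)):
        for dp_rank, fsdp_piece in enumerate(split(tp_shard, dp_size)):
            rank = dp_rank * tp_size + tp_rank  # assumed row-major (dp, tp) layout
            per_gpu[rank] = fsdp_piece
    return per_gpu


if __name__ == "__main__":
    blocks = ["red", "orange", "yellow", "green", "blue", "purple"]
    for rank, piece in sorted(shard_2d(blocks, dp_size=2, tp_size=3).items()):
        print(f"GPU {rank}: {piece}")
```

Under these assumptions GPU 0 ends up with red and GPU 3 with orange, and no two GPUs hold the same piece.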

Recall that FSDP does an all-gather of the shards before the forward pass. When GPU 0 and GPU 3 execute their forward passes, they gather the two pieces (shard red and shard orange) into local memory to rebuild shard red+orange, so each of them can use the full TP shard during computation.
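A toy version of that all-gather: every GPU with the same TP rank contributes its FSDP piece, and concatenating the pieces in DP-rank order rebuilds the TP shard. The per-GPU assignment below (GPU 0 holds red, GPU 3 holds orange, and so on) is an assumption that matches a row-major mesh; the function name is hypothetical and no real communication happens here.

```python
def all_gather_dp(per_gpu: dict[int, list[str]], rank: int,
                  dp_size: int, tp_size: int) -> list[str]:
    """Simulate the FSDP all-gather seen by `rank`: concatenate the pieces
    held by every rank with the same TP rank, in DP-rank order."""
    tp_rank = rank % tp_size
    gathered: list[str] = []
    for dp_rank in range(dp_size):
        gathered += per_gpu[dp_rank * tp_size + tp_rank]
    return gathered


if __name__ == "__main__":
    # Assumed per-GPU FSDP pieces after 3-way TP + 2-way FSDP.
    per_gpu = {0: ["red"], 1: ["yellow"], 2: ["blue"],
               3: ["orange"], 4: ["green"], 5: ["purple"]}
    print(all_gather_dp(per_gpu, rank=0, dp_size=2, tp_size=3))
    print(all_gather_dp(per_gpu, rank=3, dp_size=2, tp_size=3))
```

Both calls return ['red', 'orange']: GPU 0 and GPU 3 each end up with the full red+orange TP shard for the forward pass.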

Applying FSDP after TP

We are starting from our chapter 6 code, which already supports TP, so we just need to add FSDP to the script.

The API is much simpler than the FSDP 1 API; this is all we need to add after our TP code:

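```python
from torch.distributed._composable.fsdp import fully_shard

# Only shard when there is more than one data parallel rank.
if mesh["dp"].size() > 1:
    # Shard each transformer block separately so FSDP can all-gather
    # (and free) one block's parameters at a time.
    for layer in model.model.layers:
        fully_shard(layer, mesh=mesh["dp"])
    # Shard the root last so it covers the remaining params (e.g. embeddings, head).
    fully_shard(model, mesh=mesh["dp"])
```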

Note that we pass mesh["dp"] here so that the sharding happens across our data parallel dimension.

Controlling TP size

When creating our mesh, we set the TP size from a CLI argument:

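```python
# Every rank must belong to exactly one TP group, so the TP size must divide the world size.
assert world_size % args.tp == 0, "--tp must evenly divide the world size"

# 2D mesh: the outer dim is data parallel (FSDP), the inner dim is tensor parallel.
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (world_size // args.tp, args.tp),
    mesh_dim_names=("dp", "tp"),
)
```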

and add the flag to our argument parser:

```python
parser.add_argument("--tp", default=8, type=int)
```
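Putting the flag and the divisibility check together, a small sketch might look like the following. It derives the DP size as world_size // tp, mirroring the mesh shape (world_size // args.tp, args.tp) above. The helper name parse_mesh_shape and the use of parser.error instead of a bare assert are our own choices, and world_size is passed in rather than read from torch.distributed so the sketch runs without GPUs.

```python
import argparse


def parse_mesh_shape(argv: list[str], world_size: int) -> tuple[int, int]:
    """Parse --tp and return the (dp, tp) mesh shape for world_size ranks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp", default=8, type=int)
    args = parser.parse_args(argv)
    if args.tp < 1 or world_size % args.tp != 0:
        # parser.error prints the message to stderr and exits with status 2.
        parser.error(f"--tp={args.tp} must evenly divide world_size={world_size}")
    return world_size // args.tp, args.tp


if __name__ == "__main__":
    for tp in (8, 4, 2, 1):
        print(f"--tp {tp} -> (dp, tp) = {parse_mesh_shape(['--tp', str(tp)], world_size=8)}")
```

On 8 GPUs this yields the four shapes used in the runs below: (1, 8), (2, 4), (4, 2) and (8, 1).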

Performance with different configurations

Here are training results for 4 different TP sizes, written as DP x TP:

  • 1x8: 8-way TP, no data parallelism. --batch-size 18 --tp 8
  • 2x4: 4-way TP with 2-way FSDP. --batch-size 14 --tp 4
  • 4x2: 2-way TP with 4-way FSDP. --batch-size 10 --tp 2
  • 8x1: 8-way FSDP, no TP. --batch-size 7 --tp 1

Note that all of these runs use the same --lr but different batch sizes, which is why the loss curves differ slightly.
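To get a feel for how far apart those batch sizes are, here is a quick calculation of samples per optimizer step. It assumes --batch-size is per GPU and that all GPUs in a TP group consume the same batch, so only the DP dimension multiplies it (global batch = dp * batch_size). That is our reading of the setup, not something the numbers above confirm.

```python
# (dp, tp, --batch-size) for each run listed above.
configs = [(1, 8, 18), (2, 4, 14), (4, 2, 10), (8, 1, 7)]


def global_batch(dp: int, batch_size: int) -> int:
    """Samples per optimizer step, assuming each DP replica (one TP group)
    processes batch_size samples and TP ranks share the same batch."""
    return dp * batch_size


for dp, tp, bs in configs:
    print(f"{dp}x{tp}: {global_batch(dp, bs)} samples per step")
```

Under this assumption the runs see 18, 28, 40 and 56 samples per step respectively, so with a shared --lr some divergence in the loss curves is expected.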