Commit: Unsloth kernels as a thunder executor (Lightning-AI#1174)

carmocca authored Mar 21, 2024
1 parent a09ef68 commit 24d5eba
Showing 9 changed files with 1,487 additions and 16 deletions.
84 changes: 68 additions & 16 deletions extensions/thunder/README.md
@@ -1,4 +1,4 @@
-# Lightning Thunder: a source-to-source compiler for PyTorch.
+# Lightning Thunder: a source-to-source compiler for PyTorch

[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) makes PyTorch programs faster, both on single accelerators and in distributed settings.
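
For example, compiling a function is a one-liner (a minimal usage sketch; `foo` is an illustrative stand-in for your own function or module):

```python
import torch
import thunder

def foo(a, b):
    return a + b

jfoo = thunder.jit(foo)  # thunder traces, optimizes, and regenerates the function
out = jfoo(torch.randn(2, 2), torch.randn(2, 2))  # runs the optimized trace
```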

@@ -575,11 +575,56 @@ import thunder
from thunder.executors.sdpaex import sdpa_ex
from thunder.executors.torch_compile import torch_compile_executor

-model = thunder.jit(model, executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor])
+model = thunder.jit(
+    model,
+    executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor]
+)
```

Notice that `torch.compile` is a valid executor. It registers a few operators with improved performance so that you can use the fastest available implementation of each operator.

### Custom executors

Lightning Thunder provides extension points to integrate fast kernels for operators in your model without having to modify your implementation.

For instance, the [Unsloth project](https://github.com/unslothai/unsloth/) provides several Triton kernels that can be used with LitGPT:
- Cross entropy loss
- SwiGLU (part of `LLaMAMLP`; a reference sketch follows this list)
- RoPE
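
For reference, the SwiGLU computation inside `LLaMAMLP` boils down to the following elementwise pattern, which is what the Triton kernel fuses (a plain-PyTorch sketch, not Unsloth's kernel):

```python
import torch
import torch.nn.functional as F

def swiglu(x_fc_1: torch.Tensor, x_fc_2: torch.Tensor) -> torch.Tensor:
    # SiLU-gated linear unit: the first projection acts as a gate for the second
    return F.silu(x_fc_1) * x_fc_2
```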

The [`unsloth` directory](unsloth) contains a [custom executor](unsloth/executor.py) that registers these operators for LitGPT.
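
At its core, such an executor registers each kernel as a Thunder operator: a meta function describes output shapes and dtypes, and an implementation does the actual computation. Below is a minimal sketch of the pattern, assuming Thunder's `OperatorExecutor` API; the executor and operator names are illustrative, and the plain-PyTorch body stands in for a real Triton kernel launch:

```python
import torch
from thunder.core.proxies import TensorProxy
from thunder.extend import OperatorExecutor, register_executor

my_ex = OperatorExecutor("my_ex", version="0.1")
register_executor(my_ex)

def fast_swiglu_meta(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    # Meta functions run at trace time and only propagate shape/dtype metadata
    return TensorProxy(like=a)

def fast_swiglu_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A real executor would launch a Triton kernel here; this placeholder
    # only shows the semantics the kernel must implement
    return torch.nn.functional.silu(a) * b

fast_swiglu = my_ex.register_operator(
    "fast_swiglu", meta=fast_swiglu_meta, fn=fast_swiglu_impl
)
```

Beyond this, an executor also has to tell Thunder when its operators apply and how to differentiate them; Thunder exposes `register_implementation` with a checker and gradient transform for that purpose.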
We can enable this executor by including it in the list of executors passed to `thunder.jit`. Order matters: we want its custom operators to be claimed before
`NvFuser` creates its fusion regions.

```python
from unsloth.executor import unsloth_ex

model = thunder.jit(
    model,
    executors=[sdpa_ex, unsloth_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor]
)
```

With this executor enabled, the model trace now includes the Unsloth kernel calls:

```python
def augmented_forward_fn(*args):
    ...
    (t121, _, _, _, _, _) = unsloth_apply_rope(t120, t21, t22)
    ...
    (t189, t190) = unsloth_cross_entropy(t187, t188)
    ...

def backward_fn(saved_for_backward, cotangents):
    ...
    t652 = unsloth_cross_entropy_backward(t651, t187, t188, t190)  # t652: "cuda:0 f32[6, 320]"
    ...
    t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4)  # t763: "cuda:0 f32[2, 4, 3, 16]"
```
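
You can inspect these traces yourself after a forward and backward pass has run (a short sketch; `thunder.last_traces` and `thunder.last_backward_traces` return the sequence of progressively transformed traces, the last one being what actually executes):

```python
# assumes `model` was jitted as above and has executed at least one step
fwd_trace = thunder.last_traces(model)[-1]           # final forward execution trace
bwd_trace = thunder.last_backward_traces(model)[-1]  # final backward execution trace
print(fwd_trace)  # should include the unsloth_* calls shown above
```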

We provide a copy of the [pre-training script](unsloth/pretrain.py) that uses this executor.
Given the Unsloth results in the benchmarks below, these hand-written kernels do not seem to be worth it, which showcases the power of automated fusion compilers like [NvFuser](https://github.com/NVIDIA/Fuser).

## Examples and benchmarks

> [!WARNING]
@@ -588,22 +633,24 @@ Notice that `torch.compile` is a valid executor. It registers a few op
We provide a version of the main pre-training script [that integrates Thunder](pretrain.py), using TinyLlama, a 1.1B parameter LLM.

-| Data parallel | Compiler/JIT | Devices | ms/iter @ step 10 | Memory (GB) |
-|---------------|--------------|---------|-------------------|-------------|
-| FSDP Zero 3   | Eager        | 8       | 460.88            | 22.13       |
-| FSDP Zero 3   | Inductor     | 8       | 318.71            | 17.08       |
-| FSDP Zero 3   | Thunder      | 8       | 345.02            | 18.28       |
-|               |              |         |                   |             |
-| Replicated    | Eager        | 8       | 535.28            | 32.05       |
-| Replicated    | Inductor     | 8       | 348.19            | 27.01       |
-| Replicated    | Thunder      | 8       | OOM               | OOM         |
-|               |              |         |                   |             |
-| None          | Eager        | 1       | 449.88            | 29.85       |
-| None          | Inductor     | 1       | 320.22            | 24.81       |
-| None          | Thunder      | 1       | 322.83            | 26.37       |
+| Setting              | Compiler/JIT | Devices | ms/iter @ step 10 | Memory (GB) |
+|----------------------|--------------|---------|-------------------|-------------|
+| Fully-sharded ZeRO 3 | Eager        | 8       | 460.88            | 22.13       |
+| Fully-sharded ZeRO 3 | Inductor     | 8       | 318.71            | 17.08       |
+| Fully-sharded ZeRO 3 | Thunder      | 8       | 345.02            | 18.28       |
+|                      |              |         |                   |             |
+| Replicated           | Eager        | 8       | 535.28            | 32.05       |
+| Replicated           | Inductor     | 8       | 348.19            | 27.01       |
+| Replicated           | Thunder      | 8       | OOM               | OOM         |
+|                      |              |         |                   |             |
+| -                    | Eager        | 1       | 449.88            | 29.85       |
+| -                    | Inductor     | 1       | 320.22            | 24.81       |
+| -                    | Thunder      | 1       | 322.83            | 26.37       |
+|                      |              |         |                   |             |
+| Unsloth              | Thunder      | 1       | 331.93            | 25.19       |

<details>
-<summary>Details</summary>
+<summary>Reproduction details</summary>

Config:

@@ -633,12 +680,17 @@ python extensions/thunder/pretrain.py --config config.yaml --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --devices 1
python extensions/thunder/pretrain.py --config config.yaml --devices 1

python extensions/thunder/unsloth/pretrain.py --config config.yaml --devices 1
```

Gradient accumulation is disabled in the FSDP setting because Thunder does not support skipping the backward synchronization yet.
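
For context, this is the Fabric gradient-accumulation pattern that has to stay disabled (a sketch with illustrative variable names; `fabric.no_backward_sync` is the Lightning Fabric API that skips the inter-device gradient synchronization on non-stepping micro-batches):

```python
is_accumulating = iter_num % gradient_accumulation_iters != 0
# Thunder's FSDP cannot skip the backward all-reduce yet, so the runs above
# effectively use gradient_accumulation_iters == 1 (is_accumulating is always False)
with fabric.no_backward_sync(model, enabled=is_accumulating):
    logits = model(input_ids)
    loss = loss_fn(logits, targets)
    fabric.backward(loss / gradient_accumulation_iters)
if not is_accumulating:
    optimizer.step()
    optimizer.zero_grad()
```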

The CUDA devices are all NVIDIA A100-SXM4-40GB.

The Unsloth example does not support distributed training yet.
It also requires commenting out this line in Lightning Fabric: https://github.com/Lightning-AI/pytorch-lightning/blob/fadd2fc/src/lightning/fabric/wrappers.py#L233

```text
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Is debug build: False
```