
Update to flash-attn2 #149

Open
wants to merge 2 commits into base: main
Changes from 1 commit
Modified scgpt/model/model.py & scgpt/model/multiomic_model.py for flash-attn2. Updated README.md with flash-attn2 installation and notes
Henry Ding authored and Henry Ding committed Jan 23, 2024
commit 0a4c75bf81c18ca47ec4508bf2100a2a39c2a37c
14 changes: 10 additions & 4 deletions README.md
@@ -30,11 +30,19 @@ scGPT works with Python >= 3.7.13 and R >=3.6.1. Please make sure you have the c
scGPT is available on PyPI. To install scGPT, run the following command:

```bash
pip install scgpt "flash-attn<1.0.5" # optional, recommended
pip install scgpt ninja packaging && FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation # optional, recommended
# As of 2023.09, pip install may not run with new versions of the google orbax package, if you encounter related issues, please use the following command instead:
# pip install scgpt "flash-attn<1.0.5" "orbax<0.1.8"
# pip install scgpt ninja packaging "orbax<0.1.8" && FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation
```
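To confirm the optional dependency actually resolved, a quick sanity check (a suggested extra step, not part of the original instructions) is to import the package and print its version:

```python
# Should print a 2.x release if the command above installed flash-attn2 successfully.
import flash_attn

print(flash_attn.__version__)
```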

**Note**:
The `flash-attn` dependency requires `CUDA >= 11.6, PyTorch >= 1.12, Linux`.
If you encounter any issues, please refer to the [flash-attn](https://github.com/HazyResearch/flash-attention/tree/main) repository for installation instructions.
For now, ~~May 2023, we recommend using CUDA 11.7 and flash-attn<1.0.5 due to various issues reported about installing new versions of flash-attn.~~
we use the latest flash-attn2. Please note that support depends on the GPU architecture:
Ampere, Ada, and Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100) are tested and supported,
while Turing GPUs (T4, RTX 2080) are not supported by flash-attn2 but are supported by flash-attn 1.x.
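As a rough, unofficial way to see where your card falls, you can inspect its CUDA compute capability: Ampere, Ada, and Hopper report major version 8 or 9, while Turing reports 7.5. A minimal sketch, assuming PyTorch with CUDA support is already installed:

```python
import torch

# flash-attn2 targets Ampere/Ada/Hopper GPUs (compute capability >= 8.0);
# Turing cards such as the T4 or RTX 2080 (capability 7.5) need flash-attn 1.x instead.
major, minor = torch.cuda.get_device_capability()
if major >= 8:
    print(f"compute capability {major}.{minor}: flash-attn2 should be supported")
else:
    print(f"compute capability {major}.{minor}: consider flash-attn 1.x")
```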

[Optional] We recommend using [wandb](https://wandb.ai/) for logging and visualization.

```bash
@@ -49,8 +57,6 @@ $ cd scGPT
$ poetry install
```

**Note**: The `flash-attn` dependency usually requires specific GPU and CUDA version. If you encounter any issues, please refer to the [flash-attn](https://github.com/HazyResearch/flash-attention/tree/main) repository for installation instructions. For now, May 2023, we recommend using CUDA 11.7 and flash-attn<1.0.5 due to various issues reported about installing new versions of flash-attn.

## Pretrained scGPT Model Zoo

Here is the list of pretrained models. Please find the links for downloading the checkpoint folders. We recommend using the `whole-human` model for most applications by default. If your fine-tuning dataset shares similar cell type context with the training data of the organ-specific models, these models can usually demonstrate competitive performance as well. A paired vocabulary file mapping gene names to ids is provided in each checkpoint folder. If ENSEMBL ids are needed, please find the conversion at [gene_info.csv](https://github.com/bowang-lab/scGPT/files/13243634/gene_info.csv).
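For illustration, loading such a vocabulary and mapping gene names to ids might look like the sketch below; the file name `vocab.json` and its flat name-to-id layout are assumptions for the example, not guarantees about the checkpoint format:

```python
import json

# Hypothetical example: map gene symbols to token ids with the checkpoint's vocabulary file.
with open("vocab.json") as f:
    vocab = json.load(f)  # assumed layout: {gene_name: token_id, ...}

genes = ["TP53", "GAPDH"]
gene_ids = [vocab[g] for g in genes if g in vocab]
print(gene_ids)
```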
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ orbax = "<0.1.8"
pytest = "^5.2"
black = "^22.3.0"
tensorflow = "^2.8.0"
flash-attn = "^1.0.1"
flash-attn = ">=2.4.2"
torch-geometric = "^2.3.0"
dcor = "~0.5.3"
wandb = "^0.12.3"
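Because the constraint above is only enforced at install time, a hedged runtime check of the same bound (reusing the `packaging` helper already listed in the pip instructions) could look like:

```python
# Assert at runtime that the installed flash-attn satisfies the new >=2.4.2 constraint.
# Requires Python >= 3.8 for importlib.metadata; the distribution name "flash_attn" is assumed.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("flash_attn"))
assert installed >= Version("2.4.2"), f"flash-attn {installed} found, >= 2.4.2 required"
```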
8 changes: 4 additions & 4 deletions scgpt/model/model.py
@@ -12,7 +12,7 @@
from tqdm import trange

try:
from flash_attn.flash_attention import FlashMHA
from flash_attn.modules.mha import MHA

flash_attn_available = True
except ImportError:
@@ -634,11 +634,11 @@ def __init__(
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.self_attn = FlashMHA(
self.self_attn = MHA(
embed_dim=d_model,
num_heads=nhead,
batch_first=batch_first,
attention_dropout=dropout,
dropout=dropout,
use_flash_attn=True,
**factory_kwargs,
)
# Version compatibility workaround
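For readers porting similar code, here is a self-contained sketch of the new import-and-construct path. It is not the PR's exact code: the hyperparameters are made up, the `batch_first` keyword that the PR forwards is omitted, and only arguments we are reasonably confident flash-attn 2.x's `MHA` accepts are passed.

```python
import warnings

import torch
import torch.nn as nn

try:
    # flash-attn >= 2.x exposes its fused attention layer here,
    # replacing flash_attn.flash_attention.FlashMHA from 1.x.
    from flash_attn.modules.mha import MHA

    flash_attn_available = True
except ImportError:
    warnings.warn("flash-attn not installed; falling back to torch.nn.MultiheadAttention")
    flash_attn_available = False

d_model, nhead, dropout = 512, 8, 0.1  # illustrative values only
if flash_attn_available:
    # MHA expects (batch, seqlen, embed_dim) inputs and, with use_flash_attn=True,
    # fp16/bf16 tensors on a CUDA device.
    self_attn = MHA(
        embed_dim=d_model,
        num_heads=nhead,
        dropout=dropout,
        use_flash_attn=True,
        device="cuda",
        dtype=torch.float16,
    )
else:
    self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
```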
8 changes: 4 additions & 4 deletions scgpt/model/multiomic_model.py
@@ -12,7 +12,7 @@
from tqdm import trange

try:
from flash_attn.flash_attention import FlashMHA
from flash_attn.modules.mha import MHA
except ImportError:
import warnings

@@ -668,11 +668,11 @@ def __init__(
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.self_attn = FlashMHA(
self.self_attn = MHA(
embed_dim=d_model,
num_heads=nhead,
batch_first=batch_first,
attention_dropout=dropout,
dropout=dropout,
use_flash_attn=True,
**factory_kwargs,
)
# Implementation of Feedforward model
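The multiomic model receives the same swap. One caller-facing difference worth flagging (based on our reading of the flash-attn source, so treat it as an assumption) is that `MHA` returns a single tensor, whereas flash-attn 1.x's `FlashMHA` returned an `(output, attention_weights)` tuple. A minimal forward-pass sketch, assuming a CUDA GPU supported by flash-attn2:

```python
import torch

from flash_attn.modules.mha import MHA

# Illustrative sizes only; real scGPT configurations come from the training scripts.
layer = MHA(embed_dim=512, num_heads=8, dropout=0.1, use_flash_attn=True,
            device="cuda", dtype=torch.float16)
x = torch.randn(4, 128, 512, device="cuda", dtype=torch.float16)  # (batch, seqlen, embed_dim)
out = layer(x)    # self-attention over x; returns a single tensor
print(out.shape)  # expected: torch.Size([4, 128, 512])
```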