diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 00000000..d5cffbd0
Binary files /dev/null and b/.DS_Store differ
diff --git a/README.md b/README.md
index f5fe45ef..2340f38d 100644
--- a/README.md
+++ b/README.md
@@ -30,11 +30,19 @@ scGPT works with Python >= 3.7.13 and R >=3.6.1. Please make sure you have the c
 scGPT is available on PyPI. To install scGPT, run the following command:
 
 ```bash
-pip install scgpt "flash-attn<1.0.5" # optional, recommended
+pip install scgpt ninja packaging && FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation # optional, recommended
 # As of 2023.09, pip install may not run with new versions of the google orbax package, if you encounter related issues, please use the following command instead:
-# pip install scgpt "flash-attn<1.0.5" "orbax<0.1.8"
+# pip install scgpt ninja packaging "orbax<0.1.8" && FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation
 ```
 
+**Note**:
+The `flash-attn` dependency requires `CUDA >= 11.6, PyTorch >= 1.12, Linux`.
+If you encounter any issues, please refer to the [flash-attn](https://github.com/HazyResearch/flash-attention/tree/main) repository for installation instructions.
+For now, ~~May 2023, we recommend using CUDA 11.7 and flash-attn<1.0.5 due to various issues reported about installing new versions of flash-attn.~~
+we use the latest flash-attn 2, but please be aware that support depends on the GPU architecture:
+Ampere, Ada, and Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100) are tested and supported, whereas
+Turing GPUs (T4, RTX 2080) are supported only by flash-attn 1.x, not by flash-attn 2.
+
 [Optional] We recommend using [wandb](https://wandb.ai/) for logging and visualization.
 
 ```bash
@@ -49,8 +57,6 @@ $ cd scGPT
 $ poetry install
 ```
 
-**Note**: The `flash-attn` dependency usually requires specific GPU and CUDA version. If you encounter any issues, please refer to the [flash-attn](https://github.com/HazyResearch/flash-attention/tree/main) repository for installation instructions. For now, May 2023, we recommend using CUDA 11.7 and flash-attn<1.0.5 due to various issues reported about installing new versions of flash-attn.
-
 ## Pretrained scGPT Model Zoo
 
 Here is the list of pretrained models. Please find the links for downloading the checkpoint folders. We recommend using the `whole-human` model for most applications by default. If your fine-tuning dataset shares similar cell type context with the training data of the organ-specific models, these models can usually demonstrate competitive performance as well. A paired vocabulary file mapping gene names to ids is provided in each checkpoint folder. If ENSEMBL ids are needed, please find the conversion at [gene_info.csv](https://github.com/bowang-lab/scGPT/files/13243634/gene_info.csv).
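The two-step install in the updated README skips the flash-attn CUDA build at pip time, so it is easy to end up in an environment that does not meet the note's requirements. Below is an illustrative sanity check, not part of the repository or this diff; the `(8, 0)` compute-capability threshold is an assumption reflecting the Ampere/Ada/Hopper support described above.

```python
# Illustrative check (hypothetical helper, not in the scGPT repo): confirm flash-attn
# imported correctly and that the local GPU generation can run flash-attn 2.x kernels.
import torch

try:
    import flash_attn
    print(f"flash-attn version: {flash_attn.__version__}")
except ImportError:
    # scGPT guards this import as well and can fall back to standard attention.
    print("flash-attn is not installed.")

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # Ampere (sm_80/86), Ada (sm_89), and Hopper (sm_90) are supported by flash-attn 2.x;
    # Turing (sm_75, e.g. T4 / RTX 2080) is only covered by flash-attn 1.x.
    if (major, minor) >= (8, 0):
        print("GPU architecture is supported by flash-attn 2.x")
    else:
        print("GPU predates Ampere; use flash-attn 1.x or a non-flash attention path.")
```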
diff --git a/pyproject.toml b/pyproject.toml
index f402c050..b6471bba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ orbax = "<0.1.8"
 pytest = "^5.2"
 black = "^22.3.0"
 tensorflow = "^2.8.0"
-flash-attn = "^1.0.1"
+flash-attn = ">=2.4.2"
 torch-geometric = "^2.3.0"
 dcor = "~0.5.3"
 wandb = "^0.12.3"
diff --git a/scgpt/model/model.py b/scgpt/model/model.py
index ad0639be..b3550d05 100644
--- a/scgpt/model/model.py
+++ b/scgpt/model/model.py
@@ -12,7 +12,7 @@
 from tqdm import trange
 
 try:
-    from flash_attn.flash_attention import FlashMHA
+    from flash_attn.modules.mha import MHA
 
     flash_attn_available = True
 except ImportError:
@@ -634,11 +634,11 @@ def __init__(
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        self.self_attn = FlashMHA(
+        self.self_attn = MHA(
             embed_dim=d_model,
             num_heads=nhead,
-            batch_first=batch_first,
-            attention_dropout=dropout,
+            dropout=dropout,
+            use_flash_attn=False,
             **factory_kwargs,
         )
         # Version compatibility workaround
diff --git a/scgpt/model/multiomic_model.py b/scgpt/model/multiomic_model.py
index 5128fac7..aded0913 100644
--- a/scgpt/model/multiomic_model.py
+++ b/scgpt/model/multiomic_model.py
@@ -12,7 +12,7 @@
 from tqdm import trange
 
 try:
-    from flash_attn.flash_attention import FlashMHA
+    from flash_attn.modules.mha import MHA
 except ImportError:
     import warnings
 
@@ -668,11 +668,11 @@ def __init__(
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        self.self_attn = FlashMHA(
+        self.self_attn = MHA(
             embed_dim=d_model,
             num_heads=nhead,
-            batch_first=batch_first,
-            attention_dropout=dropout,
+            dropout=dropout,
+            use_flash_attn=False,
             **factory_kwargs,
         )
         # Implementation of Feedforward model
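For context on the hunks above, here is a minimal sketch (not taken from the repository) of how the swapped-in `flash_attn.modules.mha.MHA` layer is constructed with the same keyword arguments the updated encoder layers now pass. It assumes flash-attn >= 2.x is installed; the `d_model`, `nhead`, batch size, and sequence length values are illustrative only.

```python
# Minimal sketch of the FlashMHA -> MHA swap, mirroring the arguments used in the diff.
import torch
from flash_attn.modules.mha import MHA

d_model, nhead, dropout = 512, 8, 0.1  # illustrative hyperparameters

self_attn = MHA(
    embed_dim=d_model,
    num_heads=nhead,
    dropout=dropout,          # replaces FlashMHA's attention_dropout
    use_flash_attn=False,     # standard attention path; flash kernels need Ampere or newer
)

# MHA expects batch-first input of shape (batch, seq_len, embed_dim),
# which is why the old batch_first=True argument of FlashMHA is no longer needed.
x = torch.randn(2, 128, d_model)
out = self_attn(x)
print(out.shape)  # expected: torch.Size([2, 128, 512])
```

With `use_flash_attn=False`, the module runs the non-flash attention implementation, so the layer stays usable on GPUs (or flash-attn builds) that cannot run the 2.x kernels, while the fused path can be enabled where the hardware supports it.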