Commit a22ff36

Support MUSA (Moore Threads GPU) backend in transformers (huggingface#31913)

Add accelerate version check, needs accelerate>=0.33.0

fmo-mt authored Aug 14, 2024
1 parent c135783 · commit a22ff36
Showing 8 changed files with 72 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-ci-docker-images.yml

@@ -74,4 +74,4 @@ jobs:
           slack_channel: "#transformers-ci-circleci-images"
           title: 🤗 New docker images for CircleCI are pushed.
           status: ${{ job.status }}
-          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
\ No newline at end of file
+          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
2 changes: 2 additions & 0 deletions src/transformers/__init__.py

@@ -930,6 +930,7 @@
         "is_tokenizers_available",
         "is_torch_available",
         "is_torch_mlu_available",
+        "is_torch_musa_available",
         "is_torch_neuroncore_available",
         "is_torch_npu_available",
         "is_torch_tpu_available",
@@ -5706,6 +5707,7 @@
     is_tokenizers_available,
     is_torch_available,
     is_torch_mlu_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tpu_available,
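
Both the lazy _import_structure entry and the eager import gain the new symbol, so the helper can be used straight from the package root. A minimal check, assuming a transformers build that includes this commit:

from transformers import is_torch_musa_available

print("MUSA backend usable:", is_torch_musa_available())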
6 changes: 6 additions & 0 deletions src/transformers/pipelines/base.py

@@ -45,6 +45,7 @@
     is_torch_cuda_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
@@ -873,6 +874,8 @@ def __init__(
                 self.device = torch.device("cpu")
             elif is_torch_mlu_available():
                 self.device = torch.device(f"mlu:{device}")
+            elif is_torch_musa_available():
+                self.device = torch.device(f"musa:{device}")
             elif is_torch_cuda_available():
                 self.device = torch.device(f"cuda:{device}")
             elif is_torch_npu_available():
@@ -1042,6 +1045,9 @@ def device_placement(self):
             elif self.device.type == "mlu":
                 with torch.mlu.device(self.device):
                     yield
+            elif self.device.type == "musa":
+                with torch.musa.device(self.device):
+                    yield
             else:
                 yield
 
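
With these two hooks, an integer device index passed to a pipeline resolves to a musa:N device whenever torch_musa is present (MLU still takes precedence, per the elif order above), and device_placement wraps execution in the matching context manager. A minimal sketch, assuming a Moore Threads GPU host with torch_musa installed; the task's default model is downloaded on first use:

from transformers import pipeline
from transformers.utils import is_torch_musa_available

clf = pipeline("sentiment-analysis", device=0)
# On a MUSA-only machine the pipeline lands on musa:0 rather than cuda:0
print(clf.device)  # expected: device(type='musa', index=0)
print(clf("MUSA support just landed!"))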
20 changes: 20 additions & 0 deletions src/transformers/trainer.py

@@ -164,6 +164,7 @@
     is_torch_compile_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_xla_available,
@@ -2894,6 +2895,17 @@ def _load_rng_state(self, checkpoint):
                         f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
                         "\nThis won't yield the same results as if the training had not been interrupted."
                     )
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
+            else:
+                try:
+                    torch.musa.set_rng_state(checkpoint_rng_state["musa"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
 
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2982,6 +2994,12 @@ def _save_rng_state(self, output_dir):
         else:
             rng_states["mlu"] = torch.mlu.random.get_rng_state()
 
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["musa"] = torch.musa.get_rng_state_all()
+            else:
+                rng_states["musa"] = torch.musa.get_rng_state()
+
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
@@ -3351,6 +3369,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
                 torch.xpu.empty_cache()
             elif is_torch_mlu_available():
                 torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
             elif is_torch_npu_available():
                 torch.npu.empty_cache()
             elif is_torch_mps_available(min_version="2.0"):
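
The two Trainer hooks mirror the existing CUDA/MLU handling: _save_rng_state snapshots the MUSA generator state into the checkpoint (the *_all variants cover the distributed case), and _load_rng_state replays it so a resumed run draws the same random numbers. A standalone sketch of that round trip, assuming torch_musa is importable and that the musa generator drives on-device sampling the way CUDA's does:

import torch
import torch_musa  # noqa: F401  # registers the torch.musa namespace

# Save: what _save_rng_state stores under rng_states["musa"]
saved = torch.musa.get_rng_state()
a = torch.randn(4, device="musa")

# Restore: what _load_rng_state replays from checkpoint_rng_state["musa"]
torch.musa.set_rng_state(saved)
b = torch.randn(4, device="musa")

print(torch.equal(a, b))  # expected: True, the resumed draw is identical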
15 changes: 14 additions & 1 deletion src/transformers/trainer_utils.py

@@ -37,6 +37,7 @@
     is_torch_cuda_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
@@ -108,6 +109,8 @@ def set_seed(seed: int, deterministic: bool = False):
            torch.use_deterministic_algorithms(True)
     if is_torch_mlu_available():
         torch.mlu.manual_seed_all(seed)
+    if is_torch_musa_available():
+        torch.musa.manual_seed_all(seed)
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
     if is_torch_xpu_available():
@@ -464,7 +467,7 @@ def __init__(self, skip_memory_metrics=False):
 
         import psutil  # noqa
 
-        if is_torch_cuda_available() or is_torch_mlu_available():
+        if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
             import torch
 
             self.torch = torch
@@ -540,6 +543,9 @@ def start(self):
         elif is_torch_mlu_available():
             self.torch.mlu.reset_peak_memory_stats()
             self.torch.mlu.empty_cache()
+        elif is_torch_musa_available():
+            self.torch.musa.reset_peak_memory_stats()
+            self.torch.musa.empty_cache()
         elif is_torch_xpu_available():
             self.torch.xpu.reset_peak_memory_stats()
             self.torch.xpu.empty_cache()
@@ -555,6 +561,8 @@ def start(self):
             self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
         elif is_torch_mlu_available():
             self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
+        elif is_torch_musa_available():
+            self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
         elif is_torch_xpu_available():
             self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
         elif is_torch_npu_available():
@@ -588,6 +596,8 @@ def stop(self, stage):
             self.torch.cuda.empty_cache()
         elif is_torch_mlu_available():
             self.torch.mlu.empty_cache()
+        elif is_torch_musa_available():
+            self.torch.musa.empty_cache()
         elif is_torch_xpu_available():
             self.torch.xpu.empty_cache()
         elif is_torch_npu_available():
@@ -608,6 +618,9 @@
         elif is_torch_mlu_available():
             self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
             self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
+        elif is_torch_musa_available():
+            self.gpu_mem_used_now = self.torch.musa.memory_allocated()
+            self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
         elif is_torch_xpu_available():
             self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
             self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
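
Together these make seeding and memory metrics backend-agnostic: set_seed now also seeds every MUSA device, and TrainerMemoryTracker drives the musa allocator through the same reset/measure/empty_cache cycle it already uses for CUDA and MLU. A sketch of both, assuming a MUSA host with torch_musa installed:

import torch
import torch_musa  # noqa: F401
from transformers import set_seed

set_seed(42)  # now also calls torch.musa.manual_seed_all(42)

# The same per-stage pattern TrainerMemoryTracker uses:
torch.musa.reset_peak_memory_stats()
x = torch.randn(1024, 1024, device="musa")
y = x @ x
torch.musa.empty_cache()
print("allocated now:", torch.musa.memory_allocated())
print("peak during step:", torch.musa.max_memory_allocated())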
6 changes: 5 additions & 1 deletion src/transformers/training_args.py

@@ -49,6 +49,7 @@
     is_torch_bf16_gpu_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
@@ -1090,7 +1091,7 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "The backend to be used for distributed training",
-            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
+            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"],
         },
     )
     tpu_num_cores: Optional[int] = field(
@@ -2201,6 +2202,9 @@ def _setup_devices(self) -> "torch.device":
         elif is_torch_mlu_available():
             device = torch.device("mlu:0")
             torch.mlu.set_device(device)
+        elif is_torch_musa_available():
+            device = torch.device("musa:0")
+            torch.musa.set_device(device)
         elif is_torch_npu_available():
             device = torch.device("npu:0")
             torch.npu.set_device(device)
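
Two things change here: "mccl" (Moore Threads' collective-communications library, the MUSA analogue of NCCL) becomes a legal ddp_backend choice, and single-process device setup now falls back to musa:0 ahead of NPU. A sketch of opting in, assuming a MUSA host; reading args.device is what triggers _setup_devices:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    ddp_backend="mccl",  # newly accepted by the "choices" list above
)
print(args.device)  # expected on a MUSA machine: musa:0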
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py

@@ -201,6 +201,7 @@
     is_torch_fx_proxy,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
23 changes: 23 additions & 0 deletions src/transformers/utils/import_utils.py

@@ -677,6 +677,29 @@ def is_torch_mlu_available(check_device=False):
     return hasattr(torch, "mlu") and torch.mlu.is_available()
 
 
+@lru_cache()
+def is_torch_musa_available(check_device=False):
+    "Checks if `torch_musa` is installed and potentially if a MUSA device is in the environment"
+    if not _torch_available or importlib.util.find_spec("torch_musa") is None:
+        return False
+
+    import torch
+    import torch_musa  # noqa: F401
+
+    accelerate_min_version = "0.33.0"
+    if _accelerate_available and version.parse(_accelerate_version) < version.parse(accelerate_min_version):
+        return False
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no MUSA device is found
+            _ = torch.musa.device_count()
+            return torch.musa.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "musa") and torch.musa.is_available()
+
+
 def is_torchdynamo_available():
     if not is_torch_available():
         return False
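
Note the layering: a cheap find_spec test by default, an optional live probe when check_device=True (device_count() raises RuntimeError when no MUSA device is present), and a version floor on accelerate rather than on torch_musa itself, matching the commit message's accelerate>=0.33.0 requirement. A typical guard in user code, as a minimal sketch assuming this commit is installed:

import torch
from transformers.utils import is_torch_musa_available

if is_torch_musa_available(check_device=True):
    import torch_musa  # noqa: F401
    device = torch.device("musa:0")
else:
    device = torch.device("cpu")

model = torch.nn.Linear(8, 8).to(device)
print(next(model.parameters()).device)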
