Skip to content

Commit

Permalink
Improve device support and add support for Apple Silicon chipset (`mp…
Browse files Browse the repository at this point in the history
…s`) (AdaptiveMotorControlLab#34)

* add to() method to move cebra models (sklearn API) from devices

* better name of test

* assign self.device_ if it exists only

* modify check_device() to allow GPU id specification

* adapt test given the possibility of specifying GPU ids

* add support for mps device

* add mps to _set_device() in io

* add mps logic when cuda_if_available + fix test for torch versions < 1.12

* fix test when cuda is not available

* fix test when pytorch < 1.12
  • Loading branch information
gonlairo authored Jul 17, 2023
1 parent c95bd5a commit 00601fb
Show file tree
Hide file tree
Showing 5 changed files with 345 additions and 21 deletions.
12 changes: 12 additions & 0 deletions cebra/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np
import numpy.typing as npt
import pkg_resources
import requests
import torch

Expand Down Expand Up @@ -61,6 +62,17 @@ def download_file_from_zip_url(url, file="montblanc_tracks.h5"):
return pathlib.Path(foldername) / "data" / file


def _is_mps_availabe(torch):
available = False
if pkg_resources.parse_version(
torch.__version__) >= pkg_resources.parse_version("1.12"):
if torch.backends.mps.is_available():
if torch.backends.mps.is_built():
available = True

return available


def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
"""Check if the values in ``y`` are :py:class:`int`.
Expand Down
49 changes: 49 additions & 0 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,3 +1256,52 @@ def load(cls,
raise RuntimeError("Model loaded from file is not compatible with "
"the current CEBRA version.")
return model

def to(self, device: Union[str, torch.device]):
"""Moves the cebra model to the specified device.
Args:
device: The device to move the cebra model to. This can be a string representing
the device ('cpu','cuda', cuda:device_id, or 'mps') or a torch.device object.
Returns:
The cebra model instance.
Example:
>>> import cebra
>>> import numpy as np
>>> dataset = np.random.uniform(0, 1, (1000, 30))
>>> cebra_model = cebra.CEBRA(max_iterations=10, device = "cuda_if_available")
>>> cebra_model.fit(dataset)
CEBRA(max_iterations=10)
>>> cebra_model = cebra_model.to("cpu")
"""

if not isinstance(device, (str, torch.device)):
raise TypeError(
"The 'device' parameter must be a string or torch.device object."
)

if (not device == 'cpu') and (not device.startswith('cuda')) and (
not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)

if isinstance(device, str):
device = torch.device(device)

if (not device.type == 'cpu') and (
not device.type.startswith('cuda')) and (not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)

if hasattr(self, "device_"):
self.device_ = device

self.device = device
self.solver_.model.to(device)

return self
42 changes: 39 additions & 3 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import sklearn.utils.validation as sklearn_utils_validation
import torch

import cebra.helper


def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
"""Handle deprecated arguments of a function until they are replaced.
Expand Down Expand Up @@ -114,16 +116,50 @@ def check_device(device: str) -> str:
device: The device to return, if possible.
Returns:
Either cuda or cpu depending on {device} and availability in the environment.
Either cuda, cuda:device_id, mps, or cpu depending on {device} and availability in the environment.
"""

if device == "cuda_if_available":
if torch.cuda.is_available():
return "cuda"
elif cebra.helper._is_mps_availabe(torch):
return "mps"
else:
return "cpu"
elif device in ["cuda", "cpu"]:
elif device.startswith("cuda:") and len(device) > 5:
cuda_device_id = device[5:]
if cuda_device_id.isdigit():
device_count = torch.cuda.device_count()
device_id = int(cuda_device_id)
if device_id < device_count:
return f"cuda:{device_id}"
else:
raise ValueError(
f"CUDA device {device_id} is not available. Available device IDs are 0 to {device_count - 1}."
)
else:
raise ValueError(
f"Invalid CUDA device ID format. Please use 'cuda:device_id' where '{cuda_device_id}' is an integer."
)
elif device == "cuda" and torch.cuda.is_available():
return "cuda:0"
elif device == "cpu":
return device
raise ValueError(f"Device needs to be cuda or cpu, but got {device}.")
elif device == "mps":
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
raise ValueError(
"MPS not available because the current PyTorch install was not "
"built with MPS enabled.")
else:
raise ValueError(
"MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine."
)

return device

raise ValueError(f"Device needs to be cuda, cpu or mps, but got {device}.")


def check_fitted(model: "cebra.models.Model") -> bool:
Expand Down
2 changes: 1 addition & 1 deletion cebra/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _set_device(self, device):
return
if not isinstance(device, str):
device = device.type
if device not in ("cpu", "cuda"):
if device not in ("cpu", "cuda", "mps"):
if device.startswith("cuda"):
_, id_ = device.split(":")
if int(id_) >= torch.cuda.device_count():
Expand Down
Loading

0 comments on commit 00601fb

Please sign in to comment.