Fix type bug in to() method (AdaptiveMotorControlLab#55)
* fix type bug in to() method

* added test for to() method

* update tests

---------

Co-authored-by: Rodrigo <[email protected]>
sofiagilardini and gonlairo authored Sep 12, 2023
1 parent 808099b commit eda4aa7
Showing 2 changed files with 63 additions and 27 deletions.
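
Context for the change: the old to() ran string validation, including device.startswith('cuda'), on every input before branching on its type, so a torch.device argument raised AttributeError instead of being accepted; and a valid device string was converted to torch.device before being stored, so self.device_ silently changed type from str to torch.device. Below is a minimal sketch of both failure modes and of the normalization the fix applies; it is a hypothetical reproduction assuming only torch, not code from the commit:

    import torch

    device = torch.device('cuda:0')
    try:
        device.startswith('cuda')  # old code called str methods on any input
    except AttributeError as err:
        print(err)  # 'torch.device' object has no attribute 'startswith'

    # The fixed code branches on isinstance() and maps a torch.device back
    # to its plain type string, so downstream state stays a str:
    if isinstance(device, torch.device):
        device = device.type
    print(type(device), device)  # <class 'str'> cuda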
26 changes: 13 additions & 13 deletions cebra/integrations/sklearn/cebra.py
@@ -1282,21 +1282,21 @@ def to(self, device: Union[str, torch.device]):
             raise TypeError(
                 "The 'device' parameter must be a string or torch.device object."
             )
 
-        if (not device == 'cpu') and (not device.startswith('cuda')) and (
-                not device == 'mps'):
-            raise ValueError(
-                "The 'device' parameter must be a valid device string or device object."
-            )
-
         if isinstance(device, str):
-            device = torch.device(device)
+            if (not device == 'cpu') and (not device.startswith('cuda')) and (
+                    not device == 'mps'):
+                raise ValueError(
+                    "The 'device' parameter must be a valid device string or device object."
+                )
 
-        if (not device.type == 'cpu') and (
-                not device.type.startswith('cuda')) and (not device == 'mps'):
-            raise ValueError(
-                "The 'device' parameter must be a valid device string or device object."
-            )
+        elif isinstance(device, torch.device):
+            if (not device.type == 'cpu') and (
+                    not device.type.startswith('cuda')) and (not device == 'mps'):
+                raise ValueError(
+                    "The 'device' parameter must be a valid device string or device object."
+                )
+            device = device.type
 
         if hasattr(self, "device_"):
             self.device_ = device
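With the fix in place, to() accepts either spelling and leaves the estimator's device attributes as plain strings. A usage sketch, assuming the package's public CEBRA estimator and a CPU-only run (illustrative, not part of the commit):

    import numpy as np
    import torch
    from cebra import CEBRA

    X = np.random.uniform(0, 1, (10, 5))
    model = CEBRA(model_architecture="offset1-model", max_iterations=5,
                  device="cpu")
    model.fit(X)

    model.to(torch.device("cpu"))  # torch.device input is now accepted
    assert isinstance(model.device_, str) and model.device_ == "cpu"

    model.to("cpu")                # string input behaves the same
    assert isinstance(model.device_, str)
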
64 changes: 50 additions & 14 deletions tests/test_sklearn.py
@@ -856,6 +856,39 @@ def get_ordered_cuda_devices():
 ordered_cuda_devices = get_ordered_cuda_devices() if torch.cuda.is_available(
 ) else []
 
+def test_fit_after_moving_to_device():
+    expected_device = 'cpu'
+    expected_type = type(expected_device)
+
+    X = np.random.uniform(0, 1, (10, 5))
+    cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
+                                            max_iterations=5,
+                                            device=expected_device)
+
+    assert type(cebra_model.device) == expected_type
+    assert cebra_model.device == expected_device
+
+    cebra_model.partial_fit(X)
+    assert type(cebra_model.device) == expected_type
+    assert cebra_model.device == expected_device
+    if hasattr(cebra_model, 'device_'):
+        assert type(cebra_model.device_) == expected_type
+        assert cebra_model.device_ == expected_device
+
+    # Move the model to a device using the to() method
+    cebra_model.to('cpu')
+    assert type(cebra_model.device) == expected_type
+    assert cebra_model.device == expected_device
+    if hasattr(cebra_model, 'device_'):
+        assert type(cebra_model.device_) == expected_type
+        assert cebra_model.device_ == expected_device
+
+    cebra_model.partial_fit(X)
+    assert type(cebra_model.device) == expected_type
+    assert cebra_model.device == expected_device
+    if hasattr(cebra_model, 'device_'):
+        assert type(cebra_model.device_) == expected_type
+        assert cebra_model.device_ == expected_device
 
 @pytest.mark.parametrize("device", ['cpu'] + ordered_cuda_devices)
 def test_move_cpu_to_cuda_device(device):
@@ -875,9 +908,12 @@ def test_move_cpu_to_cuda_device(device):
     new_device = 'cpu' if device.startswith('cuda') else 'cuda:0'
     cebra_model.to(new_device)
 
-    assert cebra_model.device == torch.device(new_device)
-    assert next(cebra_model.solver_.model.parameters()).device == torch.device(
-        new_device)
+    assert cebra_model.device == new_device
+    device_model = next(cebra_model.solver_.model.parameters()).device
+    device_str = str(device_model)
+    if device_model.type == 'cuda':
+        device_str = f'cuda:{device_model.index}'
+    assert device_str == new_device
 
     with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
         cebra_model.save(savefile.name)
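
The rewritten assertions normalize the parameter's device to a 'type:index' string before comparing it with the requested 'cuda:0', since a CUDA parameter's device carries a concrete index. A short illustration of the torch.device fields involved (runs without a GPU, as constructing a torch.device allocates nothing):

    import torch

    d = torch.device('cuda', 0)
    print(d.type, d.index, str(d))  # cuda 0 cuda:0
    print(f'{d.type}:{d.index}')    # cuda:0 -- comparable to new_device

    # 'cpu' and 'mps' devices carry no index, so the mps test below can
    # compare .type directly.
    print(torch.device('cpu').type)  # cpu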
@@ -903,9 +939,10 @@ def test_move_cpu_to_mps_device(device):
     new_device = 'cpu' if device == 'mps' else 'mps'
     cebra_model.to(new_device)
 
-    assert cebra_model.device == torch.device(new_device)
-    assert next(cebra_model.solver_.model.parameters()).device == torch.device(
-        new_device)
+    assert cebra_model.device == new_device
+
+    device_model = next(cebra_model.solver_.model.parameters()).device
+    assert device_model.type == new_device
 
     with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
         cebra_model.save(savefile.name)
@@ -939,9 +976,12 @@ def test_move_mps_to_cuda_device(device):
     new_device = 'mps' if device.startswith('cuda') else 'cuda:0'
     cebra_model.to(new_device)
 
-    assert cebra_model.device == torch.device(new_device)
-    assert next(cebra_model.solver_.model.parameters()).device == torch.device(
-        new_device)
+    assert cebra_model.device == new_device
+    device_model = next(cebra_model.solver_.model.parameters()).device
+    device_str = str(device_model)
+    if device_model.type == 'cuda':
+        device_str = f'cuda:{device_model.index}'
+    assert device_str == new_device
 
     with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
         cebra_model.save(savefile.name)
@@ -963,11 +1003,7 @@ def test_mps():
 
     if torch.backends.mps.is_available() and torch.backends.mps.is_built():
         torch.backends.mps.is_available = lambda: False
-        with pytest.raises(ValueError):
-            cebra_model.fit(X)
-
-        torch.backends.mps.is_available = lambda: True
         torch.backends.mps.is_built = lambda: False
 
         with pytest.raises(ValueError):
             cebra_model.fit(X)

