Skip to content

Commit

Permalink
Merge pull request HeliXonProtein#35 from seoklab/fix-mps
Browse files Browse the repository at this point in the history
Fix for older versions of pytorch
  • Loading branch information
RuiWang1998 authored Sep 18, 2022
2 parents bea3663 + c6351ac commit 9a4209f
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions omegafold/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@

try:
from torch.backends import mps # Compatibility with earlier versions
_mps_is_available = mps.is_available
except ImportError:
mps = None
def _mps_is_available():
return False


# =============================================================================
Expand Down Expand Up @@ -277,7 +279,7 @@ def _get_device(device) -> str:
if device is None:
if torch.cuda.is_available():
return "cuda"
elif mps.is_available():
elif _mps_is_available():
return "mps"
else:
return 'cpu'
Expand All @@ -289,7 +291,7 @@ def _get_device(device) -> str:
else:
raise ValueError(f"Device cuda is not available")
elif device == "mps":
if mps.is_available():
if _mps_is_available():
return device
else:
raise ValueError(f"Device mps is not available")
Expand Down

0 comments on commit 9a4209f

Please sign in to comment.