Skip to content

Commit

Permalink
real_accelerator validation check for both accelerator and deepspeed.…
Browse files Browse the repository at this point in the history
…accelerator path (microsoft#2685)
  • Loading branch information
delock authored Jan 10, 2023
1 parent c702b64 commit 62c071e
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions accelerator/real_accelerator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
from .abstract_accelerator import DeepSpeedAccelerator
try:
from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
except ImportError as e:
dsa1 = None
try:
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
except ImportError as e:
dsa2 = None

ds_accelerator = None


def _validate_accelerator(accel_obj):
assert isinstance(accel_obj, DeepSpeedAccelerator), \
f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'
# because abstract_accelerator has different path during
# build time (accelerator.abstract_accelerator)
# and run time (deepspeed.accelerator.abstract_accelerator)
# and extension would import the
# run time abstract_accelerator/DeepSpeedAccelerator as its base
# class, so we need to compare accel_obj with both base class.
# if accel_obj is instance of DeepSpeedAccelerator in one of
# accelerator.abstractor_accelerator
# or deepspeed.accelerator.abstract_accelerator, consider accel_obj
# is a conforming object
if not ((dsa1 != None and isinstance(accel_obj,
dsa1)) or
(dsa2 != None and isinstance(accel_obj,
dsa2))):
raise AssertionError(
f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'
)

# TODO: turn off is_available test since this breaks tests
#assert accel_obj.is_available(), \
Expand Down

0 comments on commit 62c071e

Please sign in to comment.