diff --git a/fbgemm_gpu/fbgemm_gpu/__init__.py b/fbgemm_gpu/fbgemm_gpu/__init__.py index 043368e52..633d75242 100644 --- a/fbgemm_gpu/fbgemm_gpu/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/__init__.py @@ -17,18 +17,10 @@ def _load_library(filename: str) -> None: torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename)) logging.info(f"Successfully loaded: '{filename}'") except Exception as error: - logging.warning( - f"Could not load the library '{filename}': {error}. This may be expected depending on the FBGEMM_GPU variant." - ) + logging.error(f"Could not load the library '{filename}': {error}") + raise error -for filename in [ - "fbgemm_gpu_py.so", - "fbgemm_gpu_tbe_inference.so", - "experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so", -]: - _load_library(filename) - # Since __init__.py is only used in OSS context, we define `open_source` here # and use its existence to determine whether or not we are in OSS context open_source: bool = True @@ -43,6 +35,28 @@ def _load_library(filename: str) -> None: __variant__: str = "NONE" __version__: str = "NONE" +libraries_to_load = { + "cpu": [ + "fbgemm_gpu_py.so", + "fbgemm_gpu_tbe_inference.so", + ], + "cuda": [ + "fbgemm_gpu_py.so", + "fbgemm_gpu_tbe_inference.so", + "experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so", + ], + "genai": [ + "experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so", + ], + "rocm": [ + "fbgemm_gpu_py.so", + "fbgemm_gpu_tbe_inference.so", + ], +} + +for library in libraries_to_load.get(__variant__, []): + _load_library(library) + try: # Trigger meta operator registrations from . import sparse_ops # noqa: F401, E402