Skip to content

Commit

Permalink
Update .SO loader in OSS (pytorch#3477)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#559

Pull Request resolved: pytorch#3477

- Fix .SO loader in OSS to load only the modules specified by the variant.  This allows for load to fail if .SO file fails to load, and prevents logging false warnings

Reviewed By: spcyppt

Differential Revision: D66906278

fbshipit-source-id: 045adf87764c573ae4e058fa378a18e00a5cdad3
  • Loading branch information
q10 authored and facebook-github-bot committed Dec 7, 2024
1 parent 3bf676c commit 887e5bf
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,10 @@ def _load_library(filename: str) -> None:
torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename))
logging.info(f"Successfully loaded: '{filename}'")
except Exception as error:
logging.warning(
f"Could not load the library '{filename}': {error}. This may be expected depending on the FBGEMM_GPU variant."
)
logging.error(f"Could not load the library '{filename}': {error}")
raise error


for filename in [
"fbgemm_gpu_py.so",
"fbgemm_gpu_tbe_inference.so",
"experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so",
]:
_load_library(filename)

# Since __init__.py is only used in OSS context, we define `open_source` here
# and use its existence to determine whether or not we are in OSS context
open_source: bool = True
Expand All @@ -43,6 +35,28 @@ def _load_library(filename: str) -> None:
__variant__: str = "NONE"
__version__: str = "NONE"

libraries_to_load = {
"cpu": [
"fbgemm_gpu_py.so",
"fbgemm_gpu_tbe_inference.so",
],
"cuda": [
"fbgemm_gpu_py.so",
"fbgemm_gpu_tbe_inference.so",
"experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so",
],
"genai": [
"experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so",
],
"rocm": [
"fbgemm_gpu_py.so",
"fbgemm_gpu_tbe_inference.so",
],
}

for library in libraries_to_load.get(__variant__, []):
_load_library(library)

try:
# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402
Expand Down

0 comments on commit 887e5bf

Please sign in to comment.