Skip to content

Commit

Permalink
Use gpu_library_selector for permute_pooled_embedding_ops_gpu (pytorc…
Browse files Browse the repository at this point in the history
…h#2340)

Summary:
Pull Request resolved: pytorch#2340

X-link: pytorch/torchrec#1716

As title

Reviewed By: jspark1105

Differential Revision: D52531661

fbshipit-source-id: 99d17e01f67b43f22c26ffbbb09393463f301fa6
  • Loading branch information
jianyuh authored and facebook-github-bot committed Feb 20, 2024
1 parent bbf83f1 commit 49cca04
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
if torch.version.hip:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_hip"
)
else:
try:
if torch.version.hip:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_hip"
)
else:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_cuda"
)
except OSError:
# For backward compatibility
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
)
except OSError:
pass


class PermutePooledEmbeddings:
Expand Down

0 comments on commit 49cca04

Please sign in to comment.