Skip to content

Commit

Permalink
protect load (facebookresearch#4779)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored Sep 9, 2022
1 parent 58b6977 commit e217dc4
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions parlai/ops/ngram_repeat_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
Wrapper for ngram_repeat_block cuda extension.
"""
import parlai.utils.logging as logging
import torch
from torch import nn

Expand All @@ -20,13 +21,19 @@
dname = os.path.dirname(abspath)
os.chdir(dname)

ngram_repeat_block_cuda = load(
name='ngram_repeat_block_cuda',
sources=[
'../clib/cuda/ngram_repeat_block_cuda.cpp',
'../clib/cuda/ngram_repeat_block_cuda_kernel.cu',
],
)

try:
ngram_repeat_block_cuda = load(
name='ngram_repeat_block_cuda',
sources=[
'../clib/cuda/ngram_repeat_block_cuda.cpp',
'../clib/cuda/ngram_repeat_block_cuda_kernel.cu',
],
)
except Exception as e:
logging.warning(f"Unable to load ngram blocking on GPU: {e}")
ngram_repeat_block_cuda = None

os.chdir(current)


Expand Down

0 comments on commit e217dc4

Please sign in to comment.