Skip to content

Commit

Permalink
Fix build for pytorch post 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen committed May 15, 2018
1 parent cc8f03c commit 83acda9
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def findcuda():
cuda_headers = find(curdir, lambda file: file.endswith(".cuh"), True)
headers = find(curdir, lambda file: file.endswith(".h"), True)

libaten = find(torch_dir, re.compile("libaten", re.IGNORECASE).search, False)
libaten = list(set(find(torch_dir, re.compile("libaten", re.IGNORECASE).search, True)))
libaten_names = [os.path.splitext(os.path.basename(entry))[0] for entry in libaten]
for i, entry in enumerate(libaten_names):
if entry[:3]=='lib':
libaten_names[i] = entry[3:]

aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False)

include_dirs = [os.path.dirname(os.path.dirname(aten_h))]
Expand All @@ -119,13 +124,13 @@ def findcuda():
assert libaten, "Could not find PyTorch's libATen."
assert aten_h, "Could not find PyTorch's ATen header."

library_dirs.append(os.path.dirname(libaten))
library_dirs.append(os.path.dirname(libaten[0]))

#create some places to collect important things
object_files = []
extra_link_args=[]
main_libraries = []
main_libraries += ['cudart', 'ATen']
main_libraries += ['cudart',]+libaten_names
extra_compile_args = ["--std=c++11",]

#findcuda returns root dir of CUDA
Expand Down

0 comments on commit 83acda9

Please sign in to comment.