Skip to content

Commit

Permalink
add absolute path to cpp files in setup
Browse files Browse the repository at this point in the history
  • Loading branch information
Wenliang Zhao authored and Wenliang Zhao committed Jan 29, 2024
1 parent b6bd5c7 commit 016776e
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
import os
from os.path import join


def readme():
Expand Down Expand Up @@ -91,17 +92,17 @@ def gen_packages_items():
packages = list(gen_packages_items())
return packages

def get_ext_modules():
def get_ext_modules(cur_dir):
# if encounter compilation issues, please refer to https://github.com/HuiZeng/Image-Adaptive-3DLUT?tab=readme-ov-file#build.
if torch.cuda.is_available():
return CUDAExtension('trilinear',
['libcom/image_harmonization/source/trilinear_cpp/src/trilinear_cuda.cpp',
'libcom/image_harmonization/source/trilinear_cpp/src/trilinear_kernel.cu'],
[join(cur_dir, 'libcom/image_harmonization/source/trilinear_cpp/src/trilinear_cuda.cpp'),
join(cur_dir, 'libcom/image_harmonization/source/trilinear_cpp/src/trilinear_kernel.cu')],
)
else:
return CppExtension('trilinear',
['libcom/image_harmonization/source/trilinear_cpp/src/trilinear.cpp'],
include_dirs=['libcom/image_harmonization/source/trilinear_cpp/src']
[join(cur_dir, 'libcom/image_harmonization/source/trilinear_cpp/src/trilinear.cpp')],
include_dirs=[join(cur_dir, 'libcom/image_harmonization/source/trilinear_cpp/src')]
)

if __name__ == '__main__':
Expand All @@ -128,11 +129,11 @@ def get_ext_modules():
setup_requires=parse_requirements('requirements/build.txt'),
tests_require=parse_requirements('requirements/tests.txt'),
install_requires=parse_requirements('requirements/requirements.txt'),
ext_modules=[get_ext_modules()],
ext_modules=[get_ext_modules(cur_dir)],
cmdclass={'build_ext': BuildExtension},
extras_require={
'all': parse_requirements('requirements.txt'),
'build': parse_requirements('requirements/build.txt'),
'test': parse_requirements('requirements.txt'),
},
)
)

0 comments on commit 016776e

Please sign in to comment.