Skip to content

Commit

Permalink
Merge pull request open-mmlab#337 from frankier/fix-setup-py
Browse files Browse the repository at this point in the history
Fix setup py: re-enable building of extensions and specify build dependencies as per PEP 518
  • Loading branch information
yjxiong authored Nov 17, 2020
2 parents 82bf68a + 00e8323 commit b4c076b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[build-system]
requires = ["setuptools", "wheel", "Cython", "torch", "numpy"]
34 changes: 22 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@
import subprocess
import time

from setuptools import find_packages, setup, Extension, dist
from setuptools import find_packages, setup, Extension
from setuptools.command.install import install
dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1', 'torch'])

import numpy as np
from Cython.Build import cythonize # noqa: E402
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


class torch_and_cython_build_ext(BuildExtension):
def finalize_options(self):
if self.distribution.ext_modules:
nthreads = getattr(self, 'parallel', None) # -j option in Py3.5+
nthreads = int(nthreads) if nthreads else None
from Cython.Build.Dependencies import cythonize
self.distribution.ext_modules[:] = cythonize(
self.distribution.ext_modules, nthreads=nthreads, force=self.force)
super().finalize_options()


def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
Expand Down Expand Up @@ -177,16 +187,16 @@ def make_cuda_ext(name, module, sources, include_dirs=[]):
'https://github.com/open-mmlab/mmdetection/tarball/v1.0rc1/#egg=mmdet-v1.0rc1'
],
install_requires=install_requires,
# ext_modules=[
# make_cython_ext(name='cpu_nms',
# module='mmskeleton.ops.nms',
# sources=['cpu_nms.pyx']),
# make_cuda_ext(name='gpu_nms',
# module='mmskeleton.ops.nms',
# sources=['nms_kernel.cu', 'gpu_nms.pyx'],
# include_dirs=[np.get_include()]),
# ],
ext_modules=[
make_cython_ext(name='cpu_nms',
module='mmskeleton.ops.nms',
sources=['cpu_nms.pyx']),
make_cuda_ext(name='gpu_nms',
module='mmskeleton.ops.nms',
sources=['nms_kernel.cu', 'gpu_nms.pyx'],
include_dirs=[np.get_include()]),
],
cmdclass={
'build_ext': BuildExtension,
'build_ext': torch_and_cython_build_ext,
},
zip_safe=False)

0 comments on commit b4c076b

Please sign in to comment.