Skip to content

Commit

Permalink
Support packages (mypyc/mypyc#293)
Browse files Browse the repository at this point in the history
Closes mypyc/mypyc#227.

This adds package support for the `mypyc` command line but not to
`test_run`.
  • Loading branch information
msullivan authored Jul 19, 2018
1 parent 17b70a6 commit d04a9f0
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 42 deletions.
59 changes: 41 additions & 18 deletions mypyc/buildc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile
import sys
from typing import List, Tuple
from mypyc.namegen import exported_name


class BuildError(Exception):
Expand All @@ -23,7 +24,7 @@ def build_c_extension(cpath: str, module_name: str, preserve_setup: bool = False
else:
tempdir = tempfile.mkdtemp()
try:
setup_path = make_setup_py(cpath, tempdir, '', '')
setup_path = make_setup_py(module_name, '', cpath, tempdir, [], [], [])
return run_setup_py_build(setup_path, module_name)
finally:
if not preserve_setup:
Expand All @@ -33,28 +34,44 @@ def build_c_extension(cpath: str, module_name: str, preserve_setup: bool = False
shim_template = """\
#include <Python.h>
PyObject *CPyInit_{modname}(void);
PyObject *CPyInit_{full_modname}(void);
PyMODINIT_FUNC
PyInit_{modname}(void)
{{
return CPyInit_{modname}();
return CPyInit_{full_modname}();
}}
"""


def build_c_extension_shim(module_name: str, shared_lib: str) -> str:
def build_c_extension_shim(full_module_name: str, shared_lib: str, is_package: bool=False) -> str:
module_parts = full_module_name.split('.')
module_name = module_parts[-1]
if is_package:
module_parts.append('__init__')
assert shared_lib.startswith('lib') and shared_lib.endswith('.so')
libname = shared_lib[3:-3]
tempdir = tempfile.mkdtemp()
cpath = os.path.join(tempdir, '%s.c' % module_name)
cpath = os.path.join(tempdir, '%s.c' % full_module_name.replace('.', '___')) # XXX
if '.' in full_module_name:
packages = 'packages=[{}],'.format(repr('.'.join(full_module_name.split('.')[:-1])))
else:
packages = ''
if len(module_parts) > 1:
relative_lib_path = os.path.join(*(['..'] * (len(module_parts) - 1)))
else:
relative_lib_path = '.'
with open(cpath, 'w') as f:
f.write(shim_template.format(modname=module_name))
f.write(shim_template.format(modname=module_name,
full_modname=exported_name(full_module_name)))
try:
setup_path = make_setup_py(cpath,
setup_path = make_setup_py(full_module_name,
packages,
cpath,
tempdir,
libraries=repr(libname),
library_dirs=repr('.'))
libraries=[libname],
library_dirs=['.'],
runtime_library_dirs=[relative_lib_path])
return run_setup_py_build(setup_path, module_name)
finally:
shutil.rmtree(tempdir)
Expand Down Expand Up @@ -83,6 +100,7 @@ def include_dir() -> str:
from distutils.core import setup, Extension
from distutils import sysconfig
import sys
import os
extra_compile_args = ['-Werror', '-Wno-unused-function', '-Wno-unused-label',
'-Wno-unreachable-code', '-Wno-unused-variable']
Expand All @@ -98,8 +116,6 @@ def include_dir() -> str:
# And on Linux, set the rpath to $ORIGIN so they will look for the shared
# library in the directory that they live in.
elif sys.platform == 'linux':
vars['LDSHARED'] += ' -Wl,-rpath,"$ORIGIN"'
# This flag is needed for gcc but does not exist on clang. Currently we only support gcc for
# linux.
# TODO: Add support for clang on linux. Possibly also add support for gcc on Darwin.
Expand All @@ -108,8 +124,11 @@ def include_dir() -> str:
module = Extension('{package_name}',
sources=['{cpath}'],
extra_compile_args=extra_compile_args,
libraries=[{libraries}],
library_dirs=[{library_dirs}])
{packages}
libraries={libraries},
library_dirs={library_dirs},
runtime_library_dirs=[os.path.join("$ORIGIN", s) for s in {rt_library_dirs}],
)
setup(name='{package_name}',
version='1.0',
Expand All @@ -119,18 +138,22 @@ def include_dir() -> str:
"""


def make_setup_py(cpath: str, dirname: str, libraries: str, library_dirs: str) -> str:
def make_setup_py(package_name: str, packages: str,
cpath: str, dirname: str,
libraries: List[str],
library_dirs: List[str],
runtime_library_dirs: List[str]) -> str:
setup_path = os.path.join(dirname, 'setup.py')
basename = os.path.basename(cpath)
package_name = os.path.splitext(basename)[0]
with open(setup_path, 'w') as f:
f.write(
setup_format.format(
package_name=package_name,
cpath=cpath,
packages=packages,
libraries=libraries,
library_dirs=library_dirs,
include_dir=include_dir()
include_dir=include_dir(),
rt_library_dirs=runtime_library_dirs,
)
)
return setup_path
Expand All @@ -141,6 +164,6 @@ def run_setup_py_build(setup_path: str, module_name: str) -> str:
subprocess.check_output(['python', setup_path, 'build'], stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as err:
raise BuildError(err.output)
so_path = glob.glob('build/*/%s.*.so' % module_name)
so_path = glob.glob('build/**/%s.*.so' % module_name, recursive=True)
assert len(so_path) == 1, so_path
return so_path[0]
17 changes: 11 additions & 6 deletions mypyc/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from mypyc.ops import FuncIR, ClassIR, ModuleIR
from mypyc.refcount import insert_ref_count_opcodes
from mypyc.exceptions import insert_exception_handling
from mypyc.emit import EmitterContext, Emitter, HeaderDeclaration
from mypyc.namegen import exported_name


class MarkedDeclaration:
Expand All @@ -29,6 +31,7 @@ def __init__(self, declaration: HeaderDeclaration, mark: bool) -> None:


def compile_modules_to_c(sources: List[BuildSource], module_names: List[str], options: Options,
use_shared_lib: bool,
alt_lib_path: Optional[str] = None) -> str:
"""Compile Python module(s) to C that can be used from Python C extension modules."""
assert options.strict_optional, 'strict_optional must be turned on'
Expand All @@ -52,7 +55,7 @@ def compile_modules_to_c(sources: List[BuildSource], module_names: List[str], op
# Generate C code.
source_paths = {module_name: result.files[module_name].path
for module_name in module_names}
generator = ModuleGenerator(modules, source_paths)
generator = ModuleGenerator(modules, source_paths, use_shared_lib)
return generator.generate_c_for_modules()


Expand Down Expand Up @@ -85,11 +88,13 @@ def encode_bytes_as_c_string(b: bytes) -> Tuple[str, int]:
class ModuleGenerator:
def __init__(self,
modules: List[Tuple[str, ModuleIR]],
source_paths: Dict[str, str]) -> None:
source_paths: Dict[str, str],
use_shared_lib: bool) -> None:
self.modules = modules
self.source_paths = source_paths
self.context = EmitterContext([name for name, _ in modules])
self.names = self.context.names
self.use_shared_lib = use_shared_lib

def generate_c_for_modules(self) -> str:
emitter = Emitter(self.context)
Expand Down Expand Up @@ -181,11 +186,11 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
# generate a shared library for the modules and shims that call into
# the shared library, and in this case we use an internal module
# initialized function that will be called by the shim.
if len(self.modules) == 1:
declaration = 'PyMODINIT_FUNC PyInit_{}(void)'
if not self.use_shared_lib:
declaration = 'PyMODINIT_FUNC PyInit_{}(void)'.format(module_name)
else:
declaration = 'PyObject *CPyInit_{}(void)'
emitter.emit_lines(declaration.format(module_name),
declaration = 'PyObject *CPyInit_{}(void)'.format(exported_name(module_name))
emitter.emit_lines(declaration,
'{')
module_static = self.module_static_name(module_name, emitter)
emitter.emit_lines('if ({} != NULL) {{'.format(module_static),
Expand Down
18 changes: 12 additions & 6 deletions mypyc/test/test_commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
These are slow -- do not add test cases unless you have a very good reason to do so.
"""

import glob
import os
import os.path
import re
Expand Down Expand Up @@ -43,13 +44,18 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
with open(program_path, 'w') as f:
f.write(text)

# Compile program
subprocess.check_call(['%s/scripts/mypyc' % base_path] + args, cwd='tmp')
try:
# Compile program
subprocess.check_call(['%s/scripts/mypyc' % base_path] + args, cwd='tmp')

# Run main program
out = subprocess.check_output(
[python3_path, program],
cwd='tmp')
# Run main program
out = subprocess.check_output(
[python3_path, program],
cwd='tmp')
finally:
so_paths = glob.glob('tmp/**/*.so', recursive=True)
for path in so_paths:
os.remove(path)

# Verify output
actual = out.decode().splitlines()
Expand Down
1 change: 1 addition & 0 deletions mypyc/test/test_emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
sources=[source],
module_names=['prog'],
options=options,
use_shared_lib=False,
alt_lib_path=test_temp_dir)
out = ctext.splitlines()
except CompileError as e:
Expand Down
1 change: 1 addition & 0 deletions mypyc/test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
sources=sources,
module_names=module_names,
options=options,
use_shared_lib=len(module_names) > 1,
alt_lib_path=test_temp_dir)
except CompileError as e:
for line in e.messages:
Expand Down
25 changes: 18 additions & 7 deletions scripts/mypyc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from typing import List, Optional, IO

from mypy.errors import CompileError
from mypy.options import Options
from mypy.build import BuildSource
from mypy.main import process_options
from mypy import build

Expand All @@ -45,7 +46,8 @@ def handle_build_error(err: BuildError, c_path: Optional[str]) -> None:
sys.exit('Internal error: C compilation failed' + extra)


def build_using_shared_lib(fobj: IO[str], ctext: str, module_names: List[str]) -> None:
def build_using_shared_lib(fobj: IO[str], ctext: str, module_names: List[str],
sources: List[BuildSource]) -> None:
common_path = fobj.name
fobj.write(ctext)
fobj.flush()
Expand All @@ -54,10 +56,15 @@ def build_using_shared_lib(fobj: IO[str], ctext: str, module_names: List[str]) -
except BuildError as err:
handle_build_error(err, common_path)

for module in module_names:
so_path = '%s.so' % module
for source in sources:
module = source.module
module_path = module.replace('.', '/')
is_package = source.path is not None and os.path.split(source.path)[1] == '__init__.py'
if is_package:
module_path = os.path.join(module_path, '__init__')
so_path = '%s.so' % module_path
try:
native_lib_path = build_c_extension_shim(module, shared_lib)
native_lib_path = build_c_extension_shim(module, shared_lib, is_package)
except BuildError as err:
handle_build_error(err, None)
shutil.copy(native_lib_path, so_path)
Expand Down Expand Up @@ -91,24 +98,28 @@ def main() -> None:
options.incremental = False

module_names = [source.module for source in sources]
# We generate a shared lib if there are multiple modules or if any
# of the modules are in package. (Because I didn't want to fuss
# around with making the single module code handle packages.)
use_shared_lib = len(module_names) > 1 or any('.' in x for x in module_names)

try:
ctext = emitmodule.compile_modules_to_c(
sources=sources,
module_names=module_names,
options=options)
options=options,
use_shared_lib=use_shared_lib)
except CompileError as e:
for line in e.messages:
print(line)
sys.exit(1)

use_shared_lib = len(module_names) > 1
with tempfile.NamedTemporaryFile(mode='w+',
prefix='mypyc-tmp-',
suffix='.c',
dir='.') as fobj:
if use_shared_lib:
build_using_shared_lib(fobj, ctext, module_names)
build_using_shared_lib(fobj, ctext, module_names, sources)
else:
build_single_module(fobj, ctext, module_names[0])

Expand Down
35 changes: 30 additions & 5 deletions test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
--
-- These are slow -- do not add test cases unless you have a very good reason to do so.

[case testCompileTwoModules]
# cmd: a.py b.py
[case testCompileMypyc]
# cmd: a.py b.py p/__init__.py p/q.py
import os.path
import p
import p.q
import a
import b
print('<main>', b.g(a.A()))
Expand All @@ -13,9 +16,13 @@ except TypeError:
pass
else:
assert False
for x in [a, b, p, p.q]:
assert os.path.splitext(x.__file__)[1] != '.py'
[file z.py]

[file a.py]
import b
import c

print('<a>', ord('A') == 65) # Test full builtins

Expand All @@ -31,21 +38,39 @@ class B:
self.x = x

print('<a>', f(5).x)
print('<c>', c.foo())

[file b.py]
import a
import p.q

class B:
def __init__(self, x: int) -> None:
self.x = x

def g(a: a.A) -> int:
return a.x
def g(z: 'a.A') -> int:
return p.q.foo(z.x)

print('<b>', 'here')

[file c.py]
def foo() -> int:
return 10

[file p/__init__.py]

[file p/q.py]
import p.r
def foo(x: int) -> int:
return x*p.r.foo(x)

[file p/r.py]
def foo(x: int) -> int:
return x

[out]
<b> here
<a> True
<a> 5
<main> 4
<c> 10
<main> 16

0 comments on commit d04a9f0

Please sign in to comment.