Skip to content

Commit

Permalink
[mypyc] Refactor: extract code from IRBuilder (python#8509)
Browse files Browse the repository at this point in the history
Extract methods that seemed out of place in IRBuilder, including file and
import handling, and for loop and comprehension helpers.

This also removes a cyclic import.

Closes mypyc/mypyc#714.
  • Loading branch information
JukkaL authored Mar 7, 2020
1 parent 0a05e61 commit b2edab2
Show file tree
Hide file tree
Showing 7 changed files with 413 additions and 389 deletions.
404 changes: 46 additions & 358 deletions mypyc/irbuild/builder.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from mypyc.primitives.set_ops import new_set_op, set_add_op, set_update_op
from mypyc.irbuild.specialize import specializers
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.for_helpers import translate_list_comprehension, comprehension_helper


# Name and attribute references
Expand Down Expand Up @@ -495,7 +496,7 @@ def _visit_display(builder: IRBuilder,


def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
return builder.translate_list_comprehension(o.generator)
return translate_list_comprehension(builder, o.generator)


def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
Expand All @@ -507,7 +508,7 @@ def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.primitive_op(set_add_op, [set_ops, e], o.line)

builder.comprehension_helper(loop_params, gen_inner_stmts, o.line)
comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
return set_ops


Expand All @@ -520,7 +521,7 @@ def gen_inner_stmts() -> None:
v = builder.accept(o.value)
builder.primitive_op(dict_set_item_op, [d, k, v], o.line)

builder.comprehension_helper(loop_params, gen_inner_stmts, o.line)
comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
return d


Expand All @@ -543,5 +544,5 @@ def get_arg(arg: Optional[Expression]) -> Value:
def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
builder.warning('Treating generator comprehension as list', o.line)
return builder.primitive_op(
iter_op, [builder.translate_list_comprehension(o)], o.line
iter_op, [translate_list_comprehension(builder, o)], o.line
)
247 changes: 234 additions & 13 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,251 @@
"""Helpers for generating for loops.
"""Helpers for generating for loops and comprehensions.
We special case certain kinds for loops such as "for x in range(...)"
for better efficiency. Each for loop generator class below deals one
such special case.
"""

from typing import Union, List
from typing_extensions import TYPE_CHECKING
from typing import Union, List, Optional, Tuple, Callable

from mypy.nodes import Lvalue, Expression
from mypy.nodes import Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS
from mypyc.ir.ops import (
Value, BasicBlock, LoadInt, Branch, Register, AssignmentTarget
)
from mypyc.ir.rtypes import RType, is_short_int_rprimitive, is_list_rprimitive
from mypyc.ir.rtypes import (
RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive
)
from mypyc.primitives.int_ops import unsafe_short_add
from mypyc.primitives.list_ops import list_get_item_unsafe_op
from mypyc.primitives.list_ops import new_list_op, list_append_op, list_get_item_unsafe_op
from mypyc.primitives.misc_ops import iter_op, next_op
from mypyc.primitives.exc_ops import no_err_occurred_op

if TYPE_CHECKING:
import mypyc.irbuild.builder
from mypyc.irbuild.builder import IRBuilder


GenFunc = Callable[[], None]


def for_loop_helper(builder: IRBuilder, index: Lvalue, expr: Expression,
body_insts: GenFunc, else_insts: Optional[GenFunc],
line: int) -> None:
"""Generate IR for a loop.
Args:
index: the loop index Lvalue
expr: the expression to iterate over
body_insts: a function that generates the body of the loop
else_insts: a function that generates the else block instructions
"""
# Body of the loop
body_block = BasicBlock()
# Block that steps to the next item
step_block = BasicBlock()
# Block for the else clause, if we need it
else_block = BasicBlock()
# Block executed after the loop
exit_block = BasicBlock()

# Determine where we want to exit, if our condition check fails.
normal_loop_exit = else_block if else_insts is not None else exit_block

for_gen = make_for_loop_generator(builder, index, expr, body_block, normal_loop_exit, line)

builder.push_loop_stack(step_block, exit_block)
condition_block = BasicBlock()
builder.goto_and_activate(condition_block)

# Add loop condition check.
for_gen.gen_condition()

# Generate loop body.
builder.activate_block(body_block)
for_gen.begin_body()
body_insts()

# We generate a separate step block (which might be empty).
builder.goto_and_activate(step_block)
for_gen.gen_step()
# Go back to loop condition.
builder.goto(condition_block)

for_gen.add_cleanup(normal_loop_exit)
builder.pop_loop_stack()

if else_insts is not None:
builder.activate_block(else_block)
else_insts()
builder.goto(exit_block)

builder.activate_block(exit_block)


def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
list_ops = builder.primitive_op(new_list_op, [], gen.line)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))

def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.primitive_op(list_append_op, [list_ops, e], gen.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
return list_ops


def comprehension_helper(builder: IRBuilder,
loop_params: List[Tuple[Lvalue, Expression, List[Expression]]],
gen_inner_stmts: Callable[[], None],
line: int) -> None:
"""Helper function for list comprehensions.
"loop_params" is a list of (index, expr, [conditions]) tuples defining nested loops:
- "index" is the Lvalue indexing that loop;
- "expr" is the expression for the object to be iterated over;
- "conditions" is a list of conditions, evaluated in order with short-circuiting,
that must all be true for the loop body to be executed
"gen_inner_stmts" is a function to generate the IR for the body of the innermost loop
"""
def handle_loop(loop_params: List[Tuple[Lvalue, Expression, List[Expression]]]) -> None:
"""Generate IR for a loop.
Given a list of (index, expression, [conditions]) tuples, generate IR
for the nested loops the list defines.
"""
index, expr, conds = loop_params[0]
for_loop_helper(builder, index, expr,
lambda: loop_contents(conds, loop_params[1:]),
None, line)

def loop_contents(
conds: List[Expression],
remaining_loop_params: List[Tuple[Lvalue, Expression, List[Expression]]],
) -> None:
"""Generate the body of the loop.
"conds" is a list of conditions to be evaluated (in order, with short circuiting)
to gate the body of the loop.
"remaining_loop_params" is the parameters for any further nested loops; if it's empty
we'll instead evaluate the "gen_inner_stmts" function.
"""
# Check conditions, in order, short circuiting them.
for cond in conds:
cond_val = builder.accept(cond)
cont_block, rest_block = BasicBlock(), BasicBlock()
# If the condition is true we'll skip the continue.
builder.add_bool_branch(cond_val, rest_block, cont_block)
builder.activate_block(cont_block)
builder.nonlocal_control[-1].gen_continue(builder, cond.line)
builder.goto_and_activate(rest_block)

if remaining_loop_params:
# There's another nested level, so the body of this loop is another loop.
return handle_loop(remaining_loop_params)
else:
# We finally reached the actual body of the generator.
# Generate the IR for the inner loop body.
gen_inner_stmts()

handle_loop(loop_params)


def make_for_loop_generator(builder: IRBuilder,
index: Lvalue,
expr: Expression,
body_block: BasicBlock,
loop_exit: BasicBlock,
line: int,
nested: bool = False) -> 'ForGenerator':
"""Return helper object for generating a for loop over an iterable.
If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)".
"""

rtyp = builder.node_type(expr)
if is_sequence_rprimitive(rtyp):
# Special case "for x in <list>".
expr_reg = builder.accept(expr)
target_type = builder.get_sequence_type(expr)

for_list = ForSequence(builder, index, body_block, loop_exit, line, nested)
for_list.init(expr_reg, target_type, reverse=False)
return for_list

if (isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)):
if (expr.callee.fullname == 'builtins.range'
and (len(expr.args) <= 2
or (len(expr.args) == 3
and builder.extract_int(expr.args[2]) is not None))
and set(expr.arg_kinds) == {ARG_POS}):
# Special case "for x in range(...)".
# We support the 3 arg form but only for int literals, since it doesn't
# seem worth the hassle of supporting dynamically determining which
# direction of comparison to do.
if len(expr.args) == 1:
start_reg = builder.add(LoadInt(0))
end_reg = builder.accept(expr.args[0])
else:
start_reg = builder.accept(expr.args[0])
end_reg = builder.accept(expr.args[1])
if len(expr.args) == 3:
step = builder.extract_int(expr.args[2])
assert step is not None
if step == 0:
builder.error("range() step can't be zero", expr.args[2].line)
else:
step = 1

for_range = ForRange(builder, index, body_block, loop_exit, line, nested)
for_range.init(start_reg, end_reg, step)
return for_range

elif (expr.callee.fullname == 'builtins.enumerate'
and len(expr.args) == 1
and expr.arg_kinds == [ARG_POS]
and isinstance(index, TupleExpr)
and len(index.items) == 2):
# Special case "for i, x in enumerate(y)".
lvalue1 = index.items[0]
lvalue2 = index.items[1]
for_enumerate = ForEnumerate(builder, index, body_block, loop_exit, line,
nested)
for_enumerate.init(lvalue1, lvalue2, expr.args[0])
return for_enumerate

elif (expr.callee.fullname == 'builtins.zip'
and len(expr.args) >= 2
and set(expr.arg_kinds) == {ARG_POS}
and isinstance(index, TupleExpr)
and len(index.items) == len(expr.args)):
# Special case "for x, y in zip(a, b)".
for_zip = ForZip(builder, index, body_block, loop_exit, line, nested)
for_zip.init(index.items, expr.args)
return for_zip

if (expr.callee.fullname == 'builtins.reversed'
and len(expr.args) == 1
and expr.arg_kinds == [ARG_POS]
and is_sequence_rprimitive(rtyp)):
# Special case "for x in reversed(<list>)".
expr_reg = builder.accept(expr.args[0])
target_type = builder.get_sequence_type(expr)

for_list = ForSequence(builder, index, body_block, loop_exit, line, nested)
for_list.init(expr_reg, target_type, reverse=True)
return for_list

# Default to a generic for loop.
expr_reg = builder.accept(expr)
for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested)
item_type = builder._analyze_iterable_item_type(expr)
item_rtype = builder.type_to_rtype(item_type)
for_obj.init(expr_reg, item_rtype)
return for_obj


class ForGenerator:
"""Abstract base class for generating for loops."""

def __init__(self,
builder: 'mypyc.irbuild.builder.IRBuilder',
builder: IRBuilder,
index: Lvalue,
body_block: BasicBlock,
loop_exit: BasicBlock,
Expand Down Expand Up @@ -122,7 +341,7 @@ def gen_cleanup(self) -> None:


def unsafe_index(
builder: 'mypyc.irbuild.builder.IRBuilder', target: Value, index: Value, line: int
builder: IRBuilder, target: Value, index: Value, line: int
) -> Value:
"""Emit a potentially unsafe index into a target."""
# This doesn't really fit nicely into any of our data-driven frameworks
Expand Down Expand Up @@ -297,7 +516,8 @@ def init(self, index1: Lvalue, index2: Lvalue, expr: Expression) -> None:
self.line, nested=True)
self.index_gen.init()
# Iterate over the actual iterable.
self.main_gen = self.builder.make_for_loop_generator(
self.main_gen = make_for_loop_generator(
self.builder,
index2,
expr,
self.body_block,
Expand Down Expand Up @@ -336,7 +556,8 @@ def init(self, indexes: List[Lvalue], exprs: List[Expression]) -> None:
self.cond_blocks = [BasicBlock() for _ in range(len(indexes) - 1)] + [self.body_block]
self.gens = [] # type: List[ForGenerator]
for index, expr, next_block in zip(indexes, exprs, self.cond_blocks):
gen = self.builder.make_for_loop_generator(
gen = make_for_loop_generator(
self.builder,
index,
expr,
next_block,
Expand Down
40 changes: 38 additions & 2 deletions mypyc/irbuild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@ def f(x: int) -> int:
from collections import OrderedDict
from typing import List, Dict, Callable, Any, TypeVar, cast

from mypy.nodes import MypyFile, Expression
from mypy.nodes import MypyFile, Expression, ClassDef
from mypy.types import Type
from mypy.state import strict_optional_set
from mypy.build import Graph

from mypyc.common import TOP_LEVEL_NAME
from mypyc.errors import Errors
from mypyc.options import CompilerOptions
from mypyc.ir.rtypes import none_rprimitive
from mypyc.ir.module_ir import ModuleIR, ModuleIRs
from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature
from mypyc.irbuild.prebuildvisitor import PreBuildVisitor
from mypyc.irbuild.vtable import compute_vtable
from mypyc.irbuild.prepare import build_type_map
Expand Down Expand Up @@ -72,7 +75,7 @@ def build_ir(modules: List[MypyFile],
visitor.builder = builder

# Second pass does the bulk of the work.
builder.visit_mypy_file(module)
transform_mypy_file(builder, module)
module_ir = ModuleIR(
module.fullname,
list(builder.imports),
Expand All @@ -89,3 +92,36 @@ def build_ir(modules: List[MypyFile],
compute_vtable(cir)

return result


def transform_mypy_file(builder: IRBuilder, mypyfile: MypyFile) -> None:
if mypyfile.fullname in ('typing', 'abc'):
# These module are special; their contents are currently all
# built-in primitives.
return

builder.set_module(mypyfile.fullname, mypyfile.path)

classes = [node for node in mypyfile.defs if isinstance(node, ClassDef)]

# Collect all classes.
for cls in classes:
ir = builder.mapper.type_to_ir[cls.info]
builder.classes.append(ir)

builder.enter('<top level>')

# Make sure we have a builtins import
builder.gen_import('builtins', -1)

# Generate ops.
for node in mypyfile.defs:
builder.accept(node)
builder.maybe_add_implicit_return()

# Generate special function representing module top level.
blocks, env, ret_type, _ = builder.leave()
sig = FuncSignature([], none_rprimitive)
func_ir = FuncIR(FuncDecl(TOP_LEVEL_NAME, None, builder.module_name, sig), blocks, env,
traceback_name="<module>")
builder.functions.append(func_ir)
Loading

0 comments on commit b2edab2

Please sign in to comment.