diff --git a/ljd/ast/unwarper.py b/ljd/ast/unwarper.py index 5713c79..0758ca5 100644 --- a/ljd/ast/unwarper.py +++ b/ljd/ast/unwarper.py @@ -238,15 +238,16 @@ def _split_by_slot_use(statements, min_i, warp, slot): split_i = min_i for i, statement in enumerate(statements): - sets, _uses = _extract_statement_slots(statement) + if isinstance(statement, nodes.Assignment): + sets = _extract_destination_slots(statement) - if i < min_i: - known_slots |= sets - else: - known_slots -= sets + if i < min_i: + known_slots |= sets + else: + known_slots -= sets - if len(known_slots) == 0: - break + if len(known_slots) == 0: + break split_i = i + 1 @@ -262,33 +263,17 @@ def _split_by_slot_use(statements, min_i, warp, slot): return split_i -def _extract_statement_slots(statement): +def _extract_destination_slots(statement): sets = set() - uses = set() - if isinstance(statement, nodes.Assignment): - for node in statement.destinations.contents: - if isinstance(node, nodes.Identifier): - if node.type == nodes.Identifier.T_SLOT: - sets.add(node.slot) - else: - # Anything else is a use action, not a set action - uses.update(_gather_slots(node)) - - uses.update(_gather_slots(statement.expressions)) - elif isinstance(statement, (nodes.IteratorFor, nodes.NumericFor)): - uses = _gather_slots(statement.expressions) - elif isinstance(statement, nodes.Return): - uses = _gather_slots(statement.returns) - elif isinstance(statement, nodes.BlackHole): - uses = _gather_slots(statement.contents) - elif isinstance(statement, nodes.FunctionCall): - uses = _gather_slots(statement.arguments) - uses.update(_gather_slots(statement.function)) - elif isinstance(statement, nodes.While): - uses = _gather_slots(statement.expression) - - return sets, uses + for node in statement.destinations.contents: + if not isinstance(node, nodes.Identifier): + continue + + if node.type == nodes.Identifier.T_SLOT: + sets.add(node.slot) + + return sets def _gather_slots(node):