Skip to content

Commit

Permalink
Backport CFG simplification pass from XLA backend
Browse files Browse the repository at this point in the history
Right now we don't do too much CFG simplification (especially
since we turned off doing so for constant-folded branches). This
isn't too much of a problem, because LLVM is very good at
cutting down any excess basic blocks we happen to emit.

However, for non-SSA backends, excess CFG can be a
significant problem that the backend may not be able
to optimize away (even if it's trivial at the Julia
IR level). Plus it's annoying for humans to read.
The XLA backend had a simple CFG simplification pass.
Backport this pass to Base, so it can live alongside
the code it depends on (it has a fairly close dependency
on the details of the CFG and IncrementalCompact).
As it stands, I don't think it's useful to have this
pass in the default compiler pipeline (both because
LLVM can handle it easily and because our round-trip
to statement-based representations cleans up some
of this), but I do think it's useful interactively
and for non-standard compiler backends. We should
re-evaluate whether to put this in the standard
compiler pipeline once we re-enable the CFG
transformations of constant folded conditions. If
that ends up leaving a lot of basic block chains
around, it might yet be worth putting this in.

Co-authored-by: Valentin Churavy <[email protected]>
  • Loading branch information
Keno and vchuravy committed Apr 1, 2019
1 parent 0557467 commit f62a4b3
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 32 deletions.
76 changes: 45 additions & 31 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ mutable struct IncrementalCompact
result_flags::Vector{UInt8}
result_bbs::Vector{BasicBlock}
ssa_rename::Vector{Any}
bb_rename::Vector{Int}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}
used_ssas::Vector{Int}
late_fixup::Vector{Int}
# This could be Stateful, but bootstrapping doesn't like that
Expand All @@ -481,7 +482,8 @@ mutable struct IncrementalCompact
result_idx::Int
active_result_bb::Int
renamed_new_nodes::Bool
allow_cfg_transforms::Bool
cfg_transforms_enabled::Bool
fold_constant_branches::Bool
function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false)
# Sort by position, with attach-after nodes placed after regular ones
perm = my_sortperm(Int[(code.new_nodes[i].pos*2 + Int(code.new_nodes[i].attach_after)) for i in 1:length(code.new_nodes)])
Expand Down Expand Up @@ -525,9 +527,9 @@ mutable struct IncrementalCompact
new_new_nodes = NewNode[]
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, used_ssas, late_fixup, perm, 1,
return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, 1, 1, false, allow_cfg_transforms)
1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms)
end

# For inlining
Expand All @@ -542,10 +544,10 @@ mutable struct IncrementalCompact
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags,
parent.result_bbs, ssa_rename, bb_rename, parent.used_ssas,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false)
1, result_offset, parent.active_result_bb, false, false, false)
end
end

Expand Down Expand Up @@ -646,6 +648,18 @@ function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @
end
end

# Append `node` as a new trailing statement of `ir`.
#
# The statement, its type `typ` and its line entry `line` are pushed onto the
# parallel `stmts`/`types`/`lines` arrays, together with a zero flag. The last
# basic block of `ir.cfg` is then replaced by a copy whose statement range is
# extended to cover the new statement (predecessors and successors are reused
# as-is). Returns the `SSAValue` that refers to the appended statement.
function append_node!(ir, @nospecialize(typ), @nospecialize(node), line)
    # Keep the per-statement arrays in lockstep.
    push!(ir.stmts, node)
    push!(ir.types, typ)
    push!(ir.lines, line)
    push!(ir.flags, 0)
    new_idx = length(ir.stmts)
    # Grow the final block so that it ends at the freshly appended statement.
    blocks = ir.cfg.blocks
    tail = blocks[end]
    blocks[end] = BasicBlock(first(tail.stmts):new_idx, tail.preds, tail.succs)
    return SSAValue(new_idx)
end

function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int32, reverse_affinity::Bool=false)
if compact.result_idx > length(compact.result)
@assert compact.result_idx == length(compact.result) + 1
Expand Down Expand Up @@ -823,17 +837,17 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
# Note: We recursively kill as many edges as are obviously dead. However, this
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
# worstcase during codegen).
preds, succs = compact.result_bbs[compact.bb_rename[to]].preds, compact.result_bbs[compact.bb_rename[from]].succs
deleteat!(preds, findfirst(x->x === compact.bb_rename[from], preds)::Int)
deleteat!(succs, findfirst(x->x === compact.bb_rename[to], succs)::Int)
preds, succs = compact.result_bbs[compact.bb_rename_succ[to]].preds, compact.result_bbs[compact.bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x->x === compact.bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x->x === compact.bb_rename_succ[to], succs)::Int)
# Check if the block is now dead
if length(preds) == 0
for succ in copy(compact.result_bbs[compact.bb_rename[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename))
for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred))
end
if to < active_bb
# Kill all statements in the block
stmts = compact.result_bbs[compact.bb_rename[to]].stmts
stmts = compact.result_bbs[compact.bb_rename_succ[to]].stmts
for stmt in stmts
compact.result[stmt] = nothing
end
Expand All @@ -842,12 +856,12 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
else
# We need to remove this edge from any phi nodes
if to < active_bb
idx = first(compact.result_bbs[compact.bb_rename[to]].stmts)
idx = first(compact.result_bbs[compact.bb_rename_succ[to]].stmts)
while idx < length(compact.result)
stmt = compact.result[idx]
stmt === nothing && continue
isa(stmt, PhiNode) || break
i = findfirst(x-> x === compact.bb_rename[from], stmt.edges)
i = findfirst(x-> x === compact.bb_rename_pred[from], stmt.edges)
if i !== nothing
deleteat!(stmt.edges, i)
deleteat!(stmt.values, i)
Expand Down Expand Up @@ -879,34 +893,34 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any},
ssa_rename[idx] = stmt
elseif isa(stmt, OldSSAValue)
ssa_rename[idx] = ssa_rename[stmt.id]
elseif isa(stmt, GotoNode) && compact.allow_cfg_transforms
result[result_idx] = GotoNode(compact.bb_rename[stmt.label])
elseif isa(stmt, GotoNode) && compact.cfg_transforms_enabled
result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.label])
result_idx += 1
elseif isa(stmt, GlobalRef) || isa(stmt, GotoNode)
result[result_idx] = stmt
result_idx += 1
elseif isa(stmt, GotoIfNot) && compact.allow_cfg_transforms
elseif isa(stmt, GotoIfNot) && compact.cfg_transforms_enabled
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot
result[result_idx] = stmt
cond = stmt.cond
if isa(cond, Bool)
if isa(cond, Bool) && compact.fold_constant_branches
if cond
result[result_idx] = nothing
kill_edge!(compact, active_bb, active_bb, stmt.dest)
# Don't increment result_idx => Drop this statement
else
result[result_idx] = GotoNode(compact.bb_rename[stmt.dest])
result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.dest])
kill_edge!(compact, active_bb, active_bb, active_bb+1)
result_idx += 1
end
else
result[result_idx] = GotoIfNot(cond, compact.bb_rename[stmt.dest])
result[result_idx] = GotoIfNot(cond, compact.bb_rename_succ[stmt.dest])
result_idx += 1
end
elseif isa(stmt, Expr)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr
if compact.allow_cfg_transforms && isexpr(stmt, :enter)
stmt.args[1] = compact.bb_rename[stmt.args[1]::Int]
if compact.cfg_transforms_enabled && isexpr(stmt, :enter)
stmt.args[1] = compact.bb_rename_succ[stmt.args[1]::Int]
end
result[result_idx] = stmt
result_idx += 1
Expand Down Expand Up @@ -936,13 +950,13 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any},
elseif isa(stmt, PhiNode)
values = process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)
if length(stmt.edges) == 1 && isassigned(values, 1) &&
length(compact.allow_cfg_transforms ?
compact.result_bbs[compact.bb_rename[active_bb]].preds :
length(compact.cfg_transforms_enabled ?
compact.result_bbs[compact.bb_rename_succ[active_bb]].preds :
compact.ir.cfg.blocks[active_bb].preds) == 1
# There's only one predecessor left - just replace it
ssa_rename[idx] = values[1]
else
edges = compact.allow_cfg_transforms ? map!(i->compact.bb_rename[i], stmt.edges, stmt.edges) : stmt.edges
edges = compact.cfg_transforms_enabled ? map!(i->compact.bb_rename_pred[i], stmt.edges, stmt.edges) : stmt.edges
result[result_idx] = PhiNode(edges, values)
result_idx += 1
end
Expand Down Expand Up @@ -983,14 +997,14 @@ end

function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_idx, unreachable=false)
if compact.active_result_bb > length(compact.result_bbs)
@assert compact.bb_rename[active_bb] == 0
#@assert compact.bb_rename[active_bb] == 0
return true
end
bb = compact.result_bbs[compact.active_result_bb]
# If this was the last statement in the BB and we decided to skip it, insert a
# dummy `nothing` node, to prevent changing the structure of the CFG
skipped = false
if !compact.allow_cfg_transforms || active_bb == 0 || active_bb > length(compact.bb_rename) || compact.bb_rename[active_bb] != 0
if !compact.cfg_transforms_enabled || active_bb == 0 || active_bb > length(compact.bb_rename_succ) || compact.bb_rename_succ[active_bb] != 0
if compact.result_idx == first(bb.stmts)
length(compact.result) < old_result_idx && resize!(compact, old_result_idx)
if unreachable
Expand All @@ -1003,7 +1017,7 @@ function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_id
compact.result_lines[old_result_idx] = 0
compact.result_flags[old_result_idx] = 0x00
compact.result_idx = old_result_idx + 1
elseif compact.allow_cfg_transforms && compact.result_idx - 1 == first(bb.stmts)
elseif compact.cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts)
# Optimization: If this BB consists of only a branch, eliminate this bb
end
compact.result_bbs[compact.active_result_bb] = BasicBlock(bb, StmtRange(first(bb.stmts), compact.result_idx-1))
Expand Down Expand Up @@ -1083,7 +1097,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
resize!(compact, old_result_idx)
end
bb = compact.ir.cfg.blocks[active_bb]
if compact.allow_cfg_transforms && active_bb > 1 && active_bb <= length(compact.bb_rename) && length(bb.preds) == 0
if compact.cfg_transforms_enabled && active_bb > 1 && active_bb <= length(compact.bb_rename_succ) && length(bb.preds) == 0
# No predecessors, kill the entire block.
compact.idx = last(bb.stmts)
# Pop any remaining insertion nodes
Expand Down Expand Up @@ -1274,8 +1288,8 @@ function complete(compact::IncrementalCompact)
return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes)
end

function compact!(code::IRCode)
compact = IncrementalCompact(code)
function compact!(code::IRCode, allow_cfg_transforms=false)
compact = IncrementalCompact(code, allow_cfg_transforms)
# Just run through the iterator without any processing
foreach(x -> nothing, compact) # x isa Pair{Int, Any}
return finish(compact)
Expand Down
128 changes: 128 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,3 +1012,131 @@ function type_lift_pass!(ir::IRCode)
end
ir
end

# Simplify the control flow graph of `ir` by merging straight-line chains of
# basic blocks.
#
# A block is merged into its predecessor when that predecessor has exactly one
# successor and the block itself has exactly one predecessor. The pass then
#   1. assigns new block numbers (`bb_rename_succ`, used to rewrite branch
#      targets, and `bb_rename_pred`, used to rewrite phi edges),
#   2. computes the statement ranges, predecessors and successors of the
#      resulting blocks, and
#   3. replays every statement through an `IncrementalCompact` that has been
#      primed with this pre-computed CFG.
# In the rename tables, `0` means "not assigned" and `-1` marks a block that
# does not receive its own entry (merged away or, for `bb_rename_succ`,
# flagged as unreachable).
#
# Returns the result of `finish(compact)`.
function cfg_simplify!(ir::IRCode)
    bbs = ir.cfg.blocks
    # merge_into[b]  != 0: block `b` is folded into predecessor `merge_into[b]`
    # merged_succ[b] != 0: block `b` absorbs its successor `merged_succ[b]`
    merge_into = zeros(Int, length(bbs))
    merged_succ = zeros(Int, length(bbs))

    # Walk the blocks in order and aggressively combine blocks
    for (idx, bb) in enumerate(bbs)
        if length(bb.succs) == 1
            succ = bb.succs[1]
            if length(bbs[succ].preds) == 1
                merge_into[succ] = idx
                merged_succ[idx] = succ
            end
        end
    end
    max_bb_num = 1
    bb_rename_succ = zeros(Int, length(bbs))
    # Lay out the basic blocks
    for i = 1:length(bbs)
        if merge_into[i] != 0
            bb_rename_succ[i] = -1
            continue
        end
        # Drop unreachable blocks
        if i != 1 && length(ir.cfg.blocks[i].preds) == 0
            bb_rename_succ[i] = -1
        end
        bb_rename_succ[i] != 0 && continue
        curr = i
        while true
            bb_rename_succ[curr] = max_bb_num
            max_bb_num += 1
            # Now walk the chain of blocks we merged.
            # If we end in something that may fall through,
            # we have to schedule that block next
            while merged_succ[curr] != 0
                curr = merged_succ[curr]
            end
            terminator = ir.stmts[ir.cfg.blocks[curr].stmts[end]]
            if isa(terminator, GotoNode) || isa(terminator, ReturnNode)
                break
            end
            curr += 1
        end
    end
    # For phi edges: map each original block that still ends a (merged) chain
    # to the new number of the head of that chain.
    bb_rename_pred = zeros(Int, length(bbs))
    for i = 1:length(bbs)
        if merged_succ[i] != 0
            bb_rename_pred[i] = -1
            continue
        end
        bbnum = i
        while merge_into[bbnum] != 0
            bbnum = merge_into[bbnum]
        end
        bb_rename_pred[i] = bb_rename_succ[bbnum]
    end
    # result_bbs[k] is the original block that heads new block `k`
    result_bbs = Int[findfirst(j->i==j, bb_rename_succ) for i = 1:max_bb_num-1]
    # Each new block holds the statements of its head plus all blocks merged into it
    result_bbs_lengths = zeros(Int, max_bb_num-1)
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            result_bbs_lengths[idx] += length(bbs[ms].stmts)
            ms = merged_succ[ms]
        end
    end
    # Prefix sums: bb_starts[k] is the first statement index of new block `k`;
    # the extra trailing entry is one past the last statement.
    bb_starts = Vector{Int}(undef, 1+length(result_bbs_lengths))
    bb_starts[1] = 1
    for i = 1:length(result_bbs_lengths)
        bb_starts[i+1] = bb_starts[i] + result_bbs_lengths[i]
    end
    # Look at the original successor
    function compute_succs(i)
        orig_bb = result_bbs[i]
        # The successors of a merged chain are those of its last block
        while merged_succ[orig_bb] != 0
            orig_bb = merged_succ[orig_bb]
        end
        map(i->bb_rename_succ[i], bbs[orig_bb].succs)
    end

    function compute_preds(i)
        orig_bb = result_bbs[i]
        preds = bbs[orig_bb].preds
        map(preds) do pred
            # A predecessor that was merged away is represented by its chain head
            while merge_into[pred] != 0
                pred = merge_into[pred]
            end
            bb_rename_succ[pred]
        end
    end
    # `bb_starts` has exactly one more entry than `result_bbs`, so
    # `bb_starts[i+1]` is always in bounds here and no fallback is needed.
    cresult_bbs = BasicBlock[BasicBlock(
        StmtRange(bb_starts[i], bb_starts[i+1]-1),
        compute_preds(i), compute_succs(i)) for i = 1:length(result_bbs)]
    compact = IncrementalCompact(ir, true)
    # We're messing with the CFG. We don't want compaction to do
    # so independently
    compact.fold_constant_branches = false
    compact.bb_rename_succ = bb_rename_succ
    compact.bb_rename_pred = bb_rename_pred
    compact.result_bbs = cresult_bbs
    result_idx = 1
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            for i in bbs[ms].stmts
                stmt = ir.stmts[i]
                compact.result[compact.result_idx] = nothing
                compact.result_types[compact.result_idx] = ir.types[i]
                compact.result_lines[compact.result_idx] = ir.lines[i]
                compact.result_flags[compact.result_idx] = ir.flags[i]
                # If we merged a basic block, we need to remove the trailing GotoNode (if any)
                if isa(stmt, GotoNode) && merged_succ[ms] != 0
                    # Do nothing
                else
                    process_node!(compact, compact.result_idx, stmt, i, i, ms, true)
                end
                # We always increase the result index to ensure a predictable
                # placement of the resulting nodes.
                compact.result_idx += 1
            end
            ms = merged_succ[ms]
        end
    end

    compact.active_result_bb = length(bb_starts)
    return finish(compact)
end
2 changes: 1 addition & 1 deletion base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ function show_ir(io::IO, code::IRCode, expr_type_printer=default_expr_type_print
# Compute BB guard rail
if bb_idx > length(cfg.blocks)
# Even if invariants are violated, try our best to still print
bbrange = (last(cfg.blocks[end].stmts) + 1):typemax(Int)
bbrange = (length(cfg.blocks) == 0 ? 1 : last(cfg.blocks[end].stmts) + 1):typemax(Int)
bb_idx_str = "!"
bb_type = ""
else
Expand Down
1 change: 1 addition & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ module IRShow
using Core.IR
import ..Base
import .Compiler: IRCode, ReturnNode, GotoIfNot, CFG, scan_ssa_use!, Argument, isexpr, compute_basic_blocks, block_for_inst
Base.getindex(r::Compiler.StmtRange, ind::Integer) = Compiler.getindex(r, ind)
Base.size(r::Compiler.StmtRange) = Compiler.size(r)
Base.first(r::Compiler.StmtRange) = Compiler.first(r)
Base.last(r::Compiler.StmtRange) = Compiler.last(r)
Expand Down
Loading

0 comments on commit f62a4b3

Please sign in to comment.