Skip to content

Commit

Permalink
store types of sparams during inference instead of values
Browse files Browse the repository at this point in the history
This is more consistent with how other variables and arguments are
treated, and more general.
  • Loading branch information
JeffBezanson committed Feb 2, 2019
1 parent c7338ea commit c712bb1
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 104 deletions.
25 changes: 5 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -863,21 +863,6 @@ function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
nothing
end

# convert an inferred static parameter value to the inferred type of a static_parameter expression
function sparam_type(@nospecialize(val))
if isa(val, TypeVar)
if Any <: val.ub
# static param bound to typevar
# if the tvar is not known to refer to anything more specific than Any,
# the static param might actually be an integer, symbol, etc.
return Any
else
return UnionAll(val, Type{val})
end
end
return AbstractEvalConstant(val)
end

function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
if isa(e, QuoteNode)
return AbstractEvalConstant((e::QuoteNode).value)
Expand Down Expand Up @@ -940,8 +925,8 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
elseif e.head === :static_parameter
n = e.args[1]
t = Any
if 1 <= n <= length(sv.sp)
t = sparam_type(sv.sp[n])
if 1 <= n <= length(sv.sptypes)
t = sv.sptypes[n]
end
elseif e.head === :method
t = (length(e.args) == 1) ? Any : Nothing
Expand Down Expand Up @@ -975,9 +960,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
end
elseif isa(sym, Expr) && sym.head === :static_parameter
n = sym.args[1]
if 1 <= n <= length(sv.sp)
val = sv.sp[n]
if !isa(val, TypeVar)
if 1 <= n <= length(sv.sptypes)
spty = sv.sptypes[n]
if isa(spty, Const)
t = Const(true)
end
end
Expand Down
59 changes: 30 additions & 29 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ const LineNum = Int
mutable struct InferenceState
params::Params # describes how to compute the result
result::InferenceResult # remember where to put the result
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
sp::SimpleVector # static parameters
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
mod::Module
currpc::LineNum
Expand Down Expand Up @@ -48,7 +48,7 @@ mutable struct InferenceState
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)

sp = spvals_from_meth_instance(linfo::MethodInstance)
sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
Expand Down Expand Up @@ -120,7 +120,7 @@ function InferenceState(result::InferenceResult, cached::Bool, params::Params)
return InferenceState(result, src, cached, params)
end

function spvals_from_meth_instance(linfo::MethodInstance)
function sptypes_from_meth_instance(linfo::MethodInstance)
toplevel = !isa(linfo.def, Method)
if !toplevel && isempty(linfo.sparam_vals) && !isempty(linfo.def.sparam_syms)
# linfo is unspecialized
Expand All @@ -130,35 +130,36 @@ function spvals_from_meth_instance(linfo::MethodInstance)
push!(sp, sig.var)
sig = sig.body
end
sp = svec(sp...)
else
sp = linfo.sparam_vals
if _any(t->isa(t,TypeVar), sp)
sp = collect(Any, sp)
end
sp = collect(Any, linfo.sparam_vals)
end
if !isa(sp, SimpleVector)
for i = 1:length(sp)
v = sp[i]
if v isa TypeVar
ub = v.ub
while ub isa TypeVar
ub = ub.ub
end
if has_free_typevars(ub)
ub = Any
end
lb = v.lb
while lb isa TypeVar
lb = lb.lb
end
if has_free_typevars(lb)
lb = Bottom
end
sp[i] = TypeVar(v.name, lb, ub)
for i = 1:length(sp)
v = sp[i]
if v isa TypeVar
ub = v.ub
while ub isa TypeVar
ub = ub.ub
end
if has_free_typevars(ub)
ub = Any
end
lb = v.lb
while lb isa TypeVar
lb = lb.lb
end
if has_free_typevars(lb)
lb = Bottom
end
if Any <: ub && lb <: Bottom
ty = Any
else
tv = TypeVar(v.name, lb, ub)
ty = UnionAll(tv, Type{tv})
end
else
ty = Const(v)
end
sp = svec(sp...)
sp[i] = ty
end
return sp
end
Expand Down
32 changes: 16 additions & 16 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mutable struct OptimizationState
min_valid::UInt
max_valid::UInt
params::Params
sp::SimpleVector # static parameters
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
function OptimizationState(frame::InferenceState)
Expand All @@ -27,7 +27,7 @@ mutable struct OptimizationState
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
frame.min_valid, frame.max_valid,
frame.params, frame.sp, frame.slottypes, false)
frame.params, frame.sptypes, frame.slottypes, false)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
params::Params)
Expand All @@ -54,7 +54,7 @@ mutable struct OptimizationState
s_edges::Vector{Any},
src, inmodule, nargs,
min_world(linfo), max_world(linfo),
params, spvals_from_meth_instance(linfo), slottypes, false)
params, sptypes_from_meth_instance(linfo), slottypes, false)
end
end

Expand Down Expand Up @@ -135,7 +135,7 @@ function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
end
end
if !inlineable
inlineable = inline_worthy(me.src.code, me.src, me.sp, me.slottypes, me.params, cost_threshold + bonus)
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, me.params, cost_threshold + bonus)
end
return inlineable
end
Expand All @@ -148,7 +148,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return false
end
if isa(stmt, GotoIfNot)
t = argextype(stmt.cond, ir, ir.spvals)
t = argextype(stmt.cond, ir, ir.sptypes)
return !(t Bool)
end
if isa(stmt, Expr)
Expand All @@ -175,7 +175,7 @@ function optimize(opt::OptimizationState, @nospecialize(result))
proven_pure = true
for i in 1:length(ir.stmts)
stmt = ir.stmts[i]
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.spvals)
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.sptypes)
proven_pure = false
break
end
Expand Down Expand Up @@ -268,19 +268,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
# known return type
isknowntype(@nospecialize T) = (T == Union{}) || isconcretetype(T)

function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any}, params::Params)
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::Params)
head = ex.head
if is_meta_expr_head(head)
return 0
elseif head === :call
farg = ex.args[1]
ftyp = argextype(farg, src, spvals, slottypes)
ftyp = argextype(farg, src, sptypes, slottypes)
if ftyp === IntrinsicFunction && farg isa SSAValue
# if this comes from code that was already inlined into another function,
# Consts have been widened. try to recover in simple cases.
farg = src.code[farg.id]
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
ftyp = argextype(farg, src, spvals, slottypes)
ftyp = argextype(farg, src, sptypes, slottypes)
end
end
f = singleton_type(ftyp)
Expand All @@ -302,7 +302,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
return 0
elseif f === Main.Core.arrayref && length(ex.args) >= 3
atyp = argextype(ex.args[3], src, spvals, slottypes)
atyp = argextype(ex.args[3], src, sptypes, slottypes)
return isknowntype(atyp) ? 4 : params.inline_nonleaf_penalty
end
fidx = find_tfunc(f)
Expand All @@ -325,7 +325,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
elseif head === :return
a = ex.args[1]
if a isa Expr
return statement_cost(a, -1, src, spvals, slottypes, params)
return statement_cost(a, -1, src, sptypes, slottypes, params)
end
return 0
elseif head === :(=)
Expand All @@ -336,7 +336,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
end
a = ex.args[2]
if a isa Expr
cost = plus_saturate(cost, statement_cost(a, -1, src, spvals, slottypes, params))
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params))
end
return cost
elseif head === :copyast
Expand All @@ -357,13 +357,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
return 0
end

function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any},
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
params::Params, cost_threshold::Integer=params.inline_cost_threshold)
bodycost::Int = 0
for line = 1:length(body)
stmt = body[line]
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, spvals, slottypes, params)::Int
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params)::Int
elseif stmt isa GotoNode
# loops are generally always expensive
# but assume that forward jumps are already counted for from
Expand All @@ -378,11 +378,11 @@ function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector,
return true
end

function is_known_call(e::Expr, @nospecialize(func), src, spvals::SimpleVector, slottypes::Vector{Any} = empty_slottypes)
function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = empty_slottypes)
if e.head !== :call
return false
end
f = argextype(e.args[1], src, spvals, slottypes)
f = argextype(e.args[1], src, sptypes, slottypes)
return isa(f, Const) && f.val === func
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ function just_construct_ssa(ci::CodeInfo, code::Vector{Any}, nargs::Int, sv::Opt
@timeit "domtree 1" domtree = construct_domtree(cfg)
ir = let code = Any[nothing for _ = 1:length(code)]
argtypes = sv.slottypes[1:(nargs+1)]
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sp)
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sptypes)
end
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sp, sv.slottypes)
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sptypes, sv.slottypes)
return ir
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
isempty(eargs) && continue
arg1 = eargs[1]

ft = argextype(arg1, ir, sv.sp)
ft = argextype(arg1, ir, sv.sptypes)
has_free_typevars(ft) && continue
f = singleton_type(ft)
f === Core.Intrinsics.llvmcall && continue
Expand All @@ -797,7 +797,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
atypes[1] = ft
ok = true
for i = 2:length(stmt.args)
a = argextype(stmt.args[i], ir, sv.sp)
a = argextype(stmt.args[i], ir, sv.sptypes)
(a === Bottom || isvarargtype(a)) && (ok = false; break)
atypes[i] = a
end
Expand Down
10 changes: 5 additions & 5 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,20 +213,20 @@ struct IRCode
lines::Vector{Int32}
flags::Vector{UInt8}
argtypes::Vector{Any}
spvals::SimpleVector
sptypes::Vector{Any}
linetable::Vector{LineInfoNode}
cfg::CFG
new_nodes::Vector{NewNode}
meta::Vector{Any}

function IRCode(stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any},
spvals::SimpleVector)
return new(stmts, types, lines, flags, argtypes, spvals, linetable, cfg, NewNode[], meta)
sptypes::Vector{Any})
return new(stmts, types, lines, flags, argtypes, sptypes, linetable, cfg, NewNode[], meta)
end
function IRCode(ir::IRCode, stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
cfg::CFG, new_nodes::Vector{NewNode})
return new(stmts, types, lines, flags, ir.argtypes, ir.spvals, ir.linetable, cfg, new_nodes, ir.meta)
return new(stmts, types, lines, flags, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta)
end
end
copy(code::IRCode) = IRCode(code, copy(code.stmts), copy(code.types),
Expand Down Expand Up @@ -1143,7 +1143,7 @@ function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing
if compact_exprtype(compact, SSAValue(idx)) === Bottom
effect_free = false
else
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.spvals)
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.sptypes)
end
if effect_free
for ops in userefs(stmt)
Expand Down
10 changes: 5 additions & 5 deletions base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

inflate_ir(ci::CodeInfo) = inflate_ir(ci, Core.svec(), Any[ Any for i = 1:length(ci.slotnames) ])
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ Any for i = 1:length(ci.slotnames) ])

function inflate_ir(ci::CodeInfo, linfo::MethodInstance)
spvals = spvals_from_meth_instance(linfo)
sptypes = sptypes_from_meth_instance(linfo)
if ci.inferred
argtypes, _ = matching_cache_argtypes(linfo, nothing)
else
argtypes = Any[ Any for i = 1:length(ci.slotnames) ]
end
return inflate_ir(ci, spvals, argtypes)
return inflate_ir(ci, sptypes, argtypes)
end

function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
code = copy_exprargs(ci.code)
for i = 1:length(code)
if isa(code[i], Expr)
Expand Down Expand Up @@ -46,7 +46,7 @@ function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
end
ssavaluetypes = ci.ssavaluetypes isa Vector{Any} ? copy(ci.ssavaluetypes) : Any[ Any for i = 1:(ci.ssavaluetypes::Int) ]
ir = IRCode(code, ssavaluetypes, copy(ci.codelocs), copy(ci.ssaflags), cfg, collect(LineInfoNode, ci.linetable),
argtypes, Any[], spvals)
argtypes, Any[], sptypes)
return ir
end

Expand Down
Loading

0 comments on commit c712bb1

Please sign in to comment.