Skip to content

Commit

Permalink
Better performance
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed May 23, 2024
1 parent 150d8a0 commit 07ff177
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 77 deletions.
198 changes: 127 additions & 71 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,37 @@ end

showerror(io::IO, e::NotImplementedError) = print(io, "NotImplementedError: " * e.message)

"""
Type stable iterator over NamedTuples. Only supports `Base.foreach`.
```jldoctest
julia> named_tuple = (a = 1, b = 2)
julia> foreach(GraphPPL.NamedTupleIterator(named_tuple)) do key, value
println(key, " ", value)
end
a 1
b 2
```
"""
struct NamedTupleIterator{T}
namedtuple::T
end

function Base.foreach(f::F, iterator::NamedTupleIterator) where {F}
return named_tuple_iterator_foreach(f, keys(iterator.namedtuple), values(iterator.namedtuple))
end

function named_tuple_iterator_foreach(f::F, keys::K, values::V) where {F, K, V}
if length(keys) === 0
return nothing
end
first_key, remaining_keys = Base.first(keys), Base.tail(keys)
first_value, remaining_values = Base.first(values), Base.tail(values)
f(first_key, first_value)
return named_tuple_iterator_foreach(f, remaining_keys, remaining_values)
end

"""
FunctionalIndex
Expand Down Expand Up @@ -136,6 +167,56 @@ Base.:(==)(left::IndexedVariable, right::IndexedVariable) = (left.name == right.
Base.show(io::IO, variable::IndexedVariable{Nothing}) = print(io, variable.name)
Base.show(io::IO, variable::IndexedVariable) = print(io, variable.name, "[", variable.index, "]")

"""
NodeType
Abstract type representing either `Composite` or `Atomic` trait for a given object. By default is `Atomic` unless specified otherwise.
"""
abstract type NodeType end

"""
Composite
`Composite` object used as a trait of structs and functions that are composed of multiple nodes and therefore implement `make_node!`.
"""
struct Composite <: NodeType end

"""
Atomic
`Atomic` object used as a trait of structs and functions that are composed of a single node and are therefore materialized as a single node in the factor graph.
"""
struct Atomic <: NodeType end

NodeType(backend, fform) = error("Backend $backend must implement a method for `NodeType` for `$(fform)`.")

"""
NodeBehaviour
Abstract type representing either `Deterministic` or `Stochastic` for a given object. By default is `Deterministic` unless specified otherwise.
"""
abstract type NodeBehaviour end

"""
Stochastic
`Stochastic` object used to parametrize factor node object with stochastic type of relationship between variables.
"""
struct Stochastic <: NodeBehaviour end

"""
Deterministic
`Deterministic` object used to parametrize factor node object with determinstic type of relationship between variables.
"""
struct Deterministic <: NodeBehaviour end

"""
NodeBehaviour(backend, fform)
Returns a `NodeBehaviour` object for a given `backend` and `fform`.
"""
NodeBehaviour(backend, fform) = error("Backend $backend must implement a method for `NodeBehaviour` for `$(fform)`.")

"""
FactorID(fform, index)
Expand Down Expand Up @@ -184,6 +265,9 @@ setcounter!(model::Model, value) = model.counter[] = value
Graphs.savegraph(file::AbstractString, model::GraphPPL.Model) = save(file, "__model__", model)
Graphs.loadgraph(file::AbstractString, ::Type{GraphPPL.Model}) = load(file, "__model__")

NodeType(model::Model, fform::F) where {F} = NodeType(getbackend(model), fform)
NodeBehaviour(model::Model, fform::F) where {F} = NodeBehaviour(getbackend(model), fform)

"""
NodeLabel(name, global_counter::Int64)
Expand Down Expand Up @@ -834,27 +918,50 @@ internal_collection_typeof(::Type{VariableRef{M, C, O, I, E, L}}) where {M, C, O
external_collection(ref::VariableRef) = ref.external_collection
internal_collection(ref::VariableRef) = ref.internal_collection

function VariableRef(model::Model, context::Context, name::Symbol, index, external_collection = nothing)
return VariableRef(model, context, NodeCreationOptions(), name, index, external_collection)
end

Base.show(io::IO, ref::VariableRef) = variable_ref_show(io, ref.name, ref.index)
variable_ref_show(io::IO, name::Symbol, index::Nothing) = print(io, name)
variable_ref_show(io::IO, name::Symbol, index::Tuple{Nothing}) = print(io, name)
variable_ref_show(io::IO, name::Symbol, index::Tuple) = print(io, name, "[", join(index, ","), "]")
variable_ref_show(io::IO, name::Symbol, index::Any) = print(io, name, "[", index, "]")

function VariableRef(
model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple, external_collection = nothing
)
function makevarref(fform::F, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple) where {F}
return makevarref(NodeType(model, fform), model, context, options, name, index)
end

function makevarref(::Atomic, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple)
# In the case of `Atomic` variable reference, we always create the variable
# (unless the index is empty, which may happen during broadcasting)
internal_collection = isempty(index) ? nothing : getorcreate!(model, context, name, index...)
return VariableRef(model, context, options, name, index, nothing, internal_collection)
end

function makevarref(::Composite, model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index::Tuple)
# In the case of `Composite` variable reference, we create it immediatelly only when the variable is instantiated
# with indexing operation
internal_collection = if !all(isnothing, index)
getorcreate!(model, context, name, index...)
elseif haskey(context, name)
context[name]
else
nothing
end
return VariableRef(model, context, options, name, index, external_collection, internal_collection)
return VariableRef(model, context, options, name, index, nothing, internal_collection)
end

function VariableRef(
model::Model,
context::Context,
options::NodeCreationOptions,
name::Symbol,
index::Tuple,
external_collection = nothing,
internal_collection = nothing
)
M = typeof(model)
C = typeof(context)
O = typeof(options)
I = typeof(index)
E = typeof(external_collection)
L = typeof(internal_collection)
return VariableRef{M, C, O, I, E, L}(model, context, options, name, index, external_collection, internal_collection)
end

function unroll(p::ProxyLabel, ref::VariableRef, index, maycreate, liftedindex)
Expand Down Expand Up @@ -1196,58 +1303,6 @@ end
Base.getindex(context::Context, ivar::IndexedVariable{Nothing}) = context[getname(ivar)]
Base.getindex(context::Context, ivar::IndexedVariable) = context[getname(ivar)][index(ivar)]

"""
NodeType
Abstract type representing either `Composite` or `Atomic` trait for a given object. By default is `Atomic` unless specified otherwise.
"""
abstract type NodeType end

"""
Composite
`Composite` object used as a trait of structs and functions that are composed of multiple nodes and therefore implement `make_node!`.
"""
struct Composite <: NodeType end

"""
Atomic
`Atomic` object used as a trait of structs and functions that are composed of a single node and are therefore materialized as a single node in the factor graph.
"""
struct Atomic <: NodeType end

NodeType(backend, fform) = error("Backend $backend must implement a method for `NodeType` for `$(fform)`.")
NodeType(model::Model, fform::F) where {F} = NodeType(getbackend(model), fform)

"""
NodeBehaviour
Abstract type representing either `Deterministic` or `Stochastic` for a given object. By default is `Deterministic` unless specified otherwise.
"""
abstract type NodeBehaviour end

"""
Stochastic
`Stochastic` object used to parametrize factor node object with stochastic type of relationship between variables.
"""
struct Stochastic <: NodeBehaviour end

"""
Deterministic
`Deterministic` object used to parametrize factor node object with determinstic type of relationship between variables.
"""
struct Deterministic <: NodeBehaviour end

"""
NodeBehaviour(backend, fform)
Returns a `NodeBehaviour` object for a given `backend` and `fform`.
"""
NodeBehaviour(backend, fform) = error("Backend $backend must implement a method for `NodeBehaviour` for `$(fform)`.")
NodeBehaviour(model::Model, fform::F) where {F} = NodeBehaviour(getbackend(model), fform)

"""
aliases(backend, fform)
Expand Down Expand Up @@ -1290,7 +1345,7 @@ This function copies the variables in the Markov blanket of the parent context s
- `interfaces::NamedTuple`: A named tuple that maps child variable names to parent variable names.
"""
function copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple)
for (name_in_child, object_in_parent) in zip(keys(interfaces), values(interfaces))
foreach(NamedTupleIterator(interfaces)) do name_in_child, object_in_parent
add_to_child_context(child_context, name_in_child, object_in_parent)
end
end
Expand Down Expand Up @@ -1732,10 +1787,12 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_
return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...))
end

materialize_interface(interface) = interface
function materialize_interface(model, context, interface)
return getifcreated(model, context, unroll(interface))
end

function materialze_interfaces(interfaces::NamedTuple)
return map(materialize_interface, interfaces)
function materialze_interfaces(model, context, interfaces)
return map(interface -> materialize_interface(model, context, interface), interfaces)
end

"""
Expand Down Expand Up @@ -1970,8 +2027,9 @@ function make_node!(
NamedTuple, interface_aliases(model, fform, StaticInterfaces(keys(rhs_interfaces))), values(rhs_interfaces)
)
aliased_fform = factor_alias(model, fform, StaticInterfaces(keys(aliased_rhs_interfaces)))
interfaces = materialze_interfaces(prepare_interfaces(model, aliased_fform, lhs_interface, aliased_rhs_interfaces))
nodeid, _, _ = materialize_factor_node!(model, context, options, aliased_fform, interfaces)
interfaces = materialze_interfaces(model, context, prepare_interfaces(model, aliased_fform, lhs_interface, aliased_rhs_interfaces))
sorted_interfaces = sort_interfaces(model, aliased_fform, interfaces)
nodeid, _, _ = materialize_factor_node!(model, context, options, aliased_fform, sorted_interfaces)
return nodeid, unroll(lhs_interface)
end

Expand All @@ -1984,10 +2042,8 @@ function sort_interfaces(::StaticInterfaces{I}, defined_interfaces::NamedTuple)
end

function materialize_factor_node!(model::Model, context::Context, options::NodeCreationOptions, fform::F, interfaces::NamedTuple) where {F}
interfaces = sort_interfaces(model, fform, interfaces)
interfaces = map(interface -> getifcreated(model, context, unroll(interface)), interfaces)
factor_node_id, factor_node_data, factor_node_properties = add_atomic_factor_node!(model, context, options, fform)
for (interface_name, interface) in zip(keys(interfaces), values(interfaces))
foreach(NamedTupleIterator(interfaces)) do interface_name, interface
add_edge!(model, factor_node_id, factor_node_properties, interface, interface_name)
end
return factor_node_id, factor_node_data, factor_node_properties
Expand Down
13 changes: 7 additions & 6 deletions src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ function add_get_or_create_expression(e::Expr)
if @capture(e, (lhs_ ~ rhs_ where {options__}))
@capture(lhs, (var_[index__]) | (var_))
return quote
$(generate_get_or_create(var, index))
$(generate_get_or_create(var, index, rhs))
$e
end
end
Expand All @@ -371,7 +371,7 @@ Generates code to get or create a variable in the graph. This function is used t
# Returns
A `quote` block with the code to get or create the variable in the graph.
"""
generate_get_or_create(s::Symbol, index::Nothing) = generate_get_or_create(s, :((nothing,)))
generate_get_or_create(s::Symbol, index::Nothing, rhs) = generate_get_or_create(s, :((nothing,)), rhs)

"""
generate_get_or_create(s::Symbol, lhs::Expr, index::AbstractArray)
Expand All @@ -385,12 +385,13 @@ Generates code to get or create a variable in the graph. This function is used t
# Returns
A `quote` block with the code to get or create the variable in the graph.
"""
generate_get_or_create(s::Symbol, index::AbstractArray) = generate_get_or_create(s, :(($(index...),)))
generate_get_or_create(s::Symbol, index::AbstractArray, rhs) = generate_get_or_create(s, :(($(index...),)), rhs)

function generate_get_or_create(s::Symbol, index::Expr)
function generate_get_or_create(s::Symbol, index::Expr, rhs)
type = @capture(rhs, (f_()) | (f_(args__) | (f_(; kwargs__)) | (f_(args__; kwargs__)))) ? f : :(GraphPPL.Composite())
return quote
$s = if !@isdefined($s)
GraphPPL.VariableRef(__model__, __context__, $(QuoteNode(s)), $(index))
GraphPPL.makevarref($type, __model__, __context__, GraphPPL.NodeCreationOptions(), $(QuoteNode(s)), $(index))
else
$s
end
Expand Down Expand Up @@ -597,7 +598,7 @@ function convert_tilde_expression(e::Expr)
combinable_args = kwargs === nothing ? args : vcat(args, [kwarg.args[2] for kwarg in kwargs])
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
combinablesym = gensym()
getorcreate_lhs = generate_get_or_create(var, :(GraphPPL.__combine_axes($combinablesym...)))
getorcreate_lhs = generate_get_or_create(var, :(GraphPPL.__combine_axes($combinablesym...)), :(($fform)()))
returnvalsym = gensym()
return quote
$combinablesym = ($(combinable_args...),)
Expand Down
34 changes: 34 additions & 0 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
@testitem "NamedTupleIterator" begin
values = Dict()

foreach(GraphPPL.NamedTupleIterator((a = 1, b = 2.0))) do key, value
values[key] = value
end

@test values[:a] === 1
@test values[:b] === 2.0

foreach(GraphPPL.NamedTupleIterator((c = 1, d = 2.0))) do key, value
values[key] = value
end

@test values[:a] === 1
@test values[:b] === 2.0
@test values[:c] === 1
@test values[:d] === 2.0
end

@testitem "IndexedVariable" begin
import GraphPPL: IndexedVariable, CombinedRange, SplittedRange, getname, index

Expand Down Expand Up @@ -707,6 +727,7 @@ end
@testitem "`VariableRef` in combination with `ProxyLabel` should create variables in the model" begin
import GraphPPL:
VariableRef,
makevarref,
getcontext,
getifcreated,
unroll,
Expand All @@ -721,6 +742,8 @@ end
MissingCollection,
getorcreate!

using Distributions

include("testutils.jl")

@testset "Individual variable creation" begin
Expand Down Expand Up @@ -797,6 +820,17 @@ end
@test ctx[:x] === unroll(proxylabel(:x, xref, nothing, False()))
@test getifcreated(model, ctx, xref) === ctx[:x]
end

@testset "Variable should be created if the `Atomic` fform is used as a first argument with `makevarref`" begin
model = create_test_model()
ctx = getcontext(model)
# `x` is not created here, but `makevarref` takes into account the `Atomic/Composite`
# we always create a variable when used with `Atomic`
xref = makevarref(Normal, model, ctx, NodeCreationOptions(), :x, (nothing,))
# `@inferred` here is important for simple use cases like `x ~ Normal(0, 1)`, so
# `x` can be inferred properly
@test ctx[:x] === @inferred(unroll(proxylabel(:x, xref, nothing, False())))
end
end

@testitem "NodeLabel properties" begin
Expand Down

0 comments on commit 07ff177

Please sign in to comment.