Skip to content

Commit

Permalink
Merge pull request ReactiveBayes#227 from ReactiveBayes/duplicate_args
Browse files Browse the repository at this point in the history
Bug fix for duplicate arguments
  • Loading branch information
wouterwln authored Apr 18, 2024
2 parents c9a744a + 5b1ae25 commit 4f8613b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 12 deletions.
39 changes: 34 additions & 5 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,26 @@ Returns a collection of aliases for `fform` depending on the `backend`.
aliases(backend, fform) = error("Backend $backend must implement a method for `aliases` for `$(fform)`.")
aliases(model::Model, fform::F) where {F} = aliases(getbackend(model), fform)

function add_vertex!(model::Model, label, data)
# This is an unsafe procedure that implements behaviour from `MetaGraphsNext`.
code = nv(model) + 1
model.graph.vertex_labels[code] = label
model.graph.vertex_properties[label] = (code, data)
Graphs.add_vertex!(model.graph.graph)
end

function add_edge!(model::Model, src, dst, data)
# This is an unsafe procedure that implements behaviour from `MetaGraphsNext`.
code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst)
model.graph.edge_data[(src, dst)] = data
return Graphs.add_edge!(model.graph.graph, code_src, code_dst)
end

function has_edge(model::Model, src, dst)
code_src, code_dst = MetaGraphsNext.code_for(model.graph, src), MetaGraphsNext.code_for(model.graph, dst)
return Graphs.has_edge(model.graph.graph, code_src, code_dst)
end

"""
copy_markov_blanket_to_child_context(child_context::Context, interfaces::NamedTuple)
Expand Down Expand Up @@ -1307,10 +1327,8 @@ function add_variable_node!(model::Model, context::Context, options::NodeCreatio
label, nodedata = preprocess_plugins(
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
)

context[name, index] = label
model[label] = nodedata

add_vertex!(model, label, nodedata)
return label
end

Expand Down Expand Up @@ -1428,7 +1446,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
)

model[label] = nodedata
add_vertex!(model, label, nodedata)
context[factornode_id] = label

return label, nodedata, convert(FactorNodeProperties, getproperties(nodedata))
Expand Down Expand Up @@ -1500,7 +1518,18 @@ function add_edge!(
label = EdgeLabel(interface_name, index)
neighbor_node_label = unroll(variable_node_id)
addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label])
model.graph[unroll(variable_node_id), factor_node_id] = label
edge_added = add_edge!(model, neighbor_node_label, factor_node_id, label)
if !edge_added
# Double check if the edge has already been added
if has_edge(model, neighbor_node_label, factor_node_id)
error(
lazy"Trying to create duplicate edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id). Make sure that all the arguments to the `~` operator are unique (both left hand side and right hand side)."
)
else
error(lazy"Cannot create an edge $(label) between variable $(neighbor_node_label) and factor node $(factor_node_id).")
end
end
return label
end

function add_edge!(
Expand Down
102 changes: 97 additions & 5 deletions test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1500,21 +1500,113 @@ end
end
end

@testitem "LazyIndex should support empty indices if array is passed" begin
@testitem "LazyIndex should support empty indices if array is passed" begin
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex

include("testutils.jl")

@model function foo(y)
@model function foo(y)
x ~ MvNormal([1, 1], [1 0.0; 0.0 1.0])
y ~ MvNormal(x, [1.0 0.0; 0.0 1.0])
end

model = create_model(foo()) do model, ctx
return (; y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :y, LazyIndex([ 1.0, 1.0 ])))
model = create_model(foo()) do model, ctx
return (; y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :y, LazyIndex([1.0, 1.0])))
end

@test length(collect(filter(as_node(MvNormal), model))) == 2
@test length(collect(filter(as_variable(:x), model))) == 1
@test length(collect(filter(as_variable(:y), model))) == 1
end
end

@testitem "Node arguments must be unique" begin
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex

include("testutils.jl")

@model function simple_model_duplicate_1()
x ~ Normal(0.0, 1.0)
y ~ x + x
end

@model function simple_model_duplicate_2()
x ~ Normal(0.0, 1.0)
y ~ x + x + x
end

@model function simple_model_duplicate_3()
x ~ Normal(0.0, 1.0)
y ~ Normal(x, x)
end

@model function simple_model_duplicate_4()
x ~ Normal(0.0, 1.0)
hide_x = x
y ~ Normal(hide_x, x)
end

@model function simple_model_duplicate_5()
x ~ Normal(0.0, 1.0)
x ~ Normal(x, 1)
end

@model function simple_model_duplicate_6()
x ~ Normal(0.0, 1.0)
hide_x = x
hide_x ~ Normal(x, 1)
end

for modelfn in [
simple_model_duplicate_1,
simple_model_duplicate_2,
simple_model_duplicate_3,
simple_model_duplicate_4,
simple_model_duplicate_5,
simple_model_duplicate_6
]
@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
modelfn()
)
end

@model function my_model(obs, N, sigma)
local x
for i in 1:N
x[i] ~ Bernoulli(0.5)
end
local C
# This model creation is not allowed since `C` is used twice in the `~` operator
for i in 1:N
C ~ C + x[i]
end
obs ~ NormalMeanVariance(C, sigma^2)
end

@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
my_model(N = 3, sigma = 1.0)
) do model, ctx
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
return (obs = obs,)
end

@model function my_model(obs, N, sigma)
local x
for i in 1:N
x[i] ~ Bernoulli(0.5)
end
accum_C = x[1]
for i in 2:N
# Here `next_C` will be used twice on the second iteration
next_C ~ accum_C + x[i]
accum_C = next_C
end
obs ~ NormalMeanVariance(accum_C, sigma^2)
end

@test_throws r"Trying to create duplicate edge.*Make sure that all the arguments to the `~` operator are unique.*" create_model(
my_model(N = 3, sigma = 1.0)
) do model, ctx
obs = getorcreate!(model, ctx, NodeCreationOptions(kind = :data, factorized = true), :obs, LazyIndex(0.0))
return (obs = obs,)
end
end
14 changes: 12 additions & 2 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ end
EdgeLabel,
getname,
add_edge!,
has_edge,
getproperties

include("testutils.jl")
Expand All @@ -770,12 +771,21 @@ end
b = NodeLabel(:b, 2)
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum))
@test !has_edge(model, a, b)
@test !has_edge(model, b, a)
add_edge!(model, b, getproperties(model[b]), a, :edge, 1)
@test has_edge(model, a, b)
@test has_edge(model, b, a)
@test length(edges(model)) == 1

c = NodeLabel(:c, 2)
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum))
@test !has_edge(model, a, c)
@test !has_edge(model, c, a)
add_edge!(model, c, getproperties(model[c]), a, :edge, 2)
@test has_edge(model, a, c)
@test has_edge(model, c, a)

@test length(edges(model)) == 2

# Test 2: Test getting all edges from a model with a specific node
Expand Down Expand Up @@ -1842,13 +1852,13 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
x, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum)
y = getorcreate!(model, ctx, :y, nothing)

variable_nodes = [getorcreate!(model, ctx, i, nothing) for i in [:a, :b, :c]]
x, xdata, xproperties = GraphPPL.add_atomic_factor_node!(model, ctx, options, sum)
add_edge!(model, x, xproperties, variable_nodes, :interface)

@test ne(model) == 3 && model[x, variable_nodes[1]] == EdgeLabel(:interface, 1)
@test ne(model) == 3 && model[variable_nodes[1], x] == EdgeLabel(:interface, 1)
end

@testitem "default_parametrization" begin
Expand Down

0 comments on commit 4f8613b

Please sign in to comment.