Skip to content

Commit

Permalink
massive upgrade and update to ngc-learn, v0.2.0 about to released
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed May 9, 2022
1 parent c31a283 commit e151d54
Showing 1 changed file with 55 additions and 47 deletions.
102 changes: 55 additions & 47 deletions tests/test_fun_dynamics.py → tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,28 +89,27 @@
integrate_kernel=integrate_cfg)

# create cable wiring scheme relating nodes to one another
dcable_cfg = {"type": "dense", "has_bias": True,
"init" : ("diagonal",1), "seed" : seed} #classic_glorot
dcable_cfg = {"type": "dense", "init" : ("diagonal",1), "seed" : seed}
pos_scable_cfg = {"type": "simple", "coeff": 1.0}
neg_scable_cfg = {"type": "simple", "coeff": -1.0}

z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z3_mu2,"symm_tied"))
e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
z3_mu2 = z3.wire_to(mu2, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu2.wire_to(e2, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z2.wire_to(e2, src_comp="z", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e2.wire_to(z3, src_comp="phi(z)", dest_comp="dz_bu", mirror_path_kernel=(z3_mu2,"symm_tied"))
e2.wire_to(z2, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z2_mu1,"symm_tied"))
e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
z2_mu1 = z2.wire_to(mu1, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu1.wire_to(e1, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z1.wire_to(e1, src_comp="z", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e1.wire_to(z2, src_comp="phi(z)", dest_comp="dz_bu", mirror_path_kernel=(z2_mu1,"symm_tied"))
e1.wire_to(z1, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z1_mu0,"symm_tied"))
e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
z1_mu0 = z1.wire_to(mu0, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu0.wire_to(e0, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z0.wire_to(e0, src_comp="phi(z)", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e0.wire_to(z1, src_comp="phi(z)", dest_comp="dz_bu", mirror_path_kernel=(z1_mu0,"symm_tied"))
e0.wire_to(z0, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

# set up update rules and make relevant edges aware of these
z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)"))
Expand All @@ -125,6 +124,7 @@
ngc_model.set_cycle(nodes=[mu2,mu1,mu0])
ngc_model.set_cycle(nodes=[e2,e1,e0])
ngc_model.apply_constraints()
info = ngc_model.compile()

# build this NGC model's sampling graph
z3_dim = ngc_model.getNode("z3").dim
Expand All @@ -136,11 +136,12 @@
s2 = FNode(name="s2", dim=z2_dim, act_fx="relu")
s1 = FNode(name="s1", dim=z1_dim, act_fx="relu")
s0 = FNode(name="s0", dim=z0_dim, act_fx="identity")
s3_s2 = s3.wire_to(s2, src_var="phi(z)", dest_var="dz", point_to_path=z3_mu2)
s2_s1 = s2.wire_to(s1, src_var="phi(z)", dest_var="dz", point_to_path=z2_mu1)
s1_s0 = s1.wire_to(s0, src_var="phi(z)", dest_var="dz", point_to_path=z1_mu0)
s3_s2 = s3.wire_to(s2, src_comp="phi(z)", dest_comp="dz", mirror_path_kernel=(z3_mu2,"tied"))
s2_s1 = s2.wire_to(s1, src_comp="phi(z)", dest_comp="dz", mirror_path_kernel=(z2_mu1,"tied"))
s1_s0 = s1.wire_to(s0, src_comp="phi(z)", dest_comp="dz", mirror_path_kernel=(z1_mu0,"tied"))
sampler = ProjectionGraph()
sampler.set_cycle(nodes=[s3,s2,s1,s0])
sampler_info = sampler.compile()

# test projection graph
print("----------------")
Expand All @@ -151,6 +152,7 @@
)
x_sample = readouts[0][2]


print(" => Test for: x = 1 = s2.z = s2.phi(z)")
s3_z = sampler.extract("s3","z")
s3_phi = sampler.extract("s3","phi(z)")
Expand Down Expand Up @@ -182,10 +184,10 @@
# test NGC simulation object
print("----------------")
print(" > Checking NGC simulation object:")
readouts = ngc_model.settle(
clamped_vars=[("z3","z",x),("z0","z",x)],
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
)
readouts, delta = ngc_model.settle(
clamped_vars=[("z3","z",x),("z0","z",x)],
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
)
x_hat = readouts[0][2]

print(" => Test for: 0 = e2.z = e2.phi(z)")
Expand Down Expand Up @@ -219,7 +221,7 @@
print(" PASS!")

print(" => Test for update calculation: all dx should be = 0")
delta = ngc_model.calc_updates()
#delta = ngc_model.calc_updates()
for i in range(len(delta)):
target_dx = ngc_model.theta[i] * 0
dx = delta[i]
Expand All @@ -228,6 +230,7 @@

print("#######################################################################")


print("#######################################################################")
print(" > Testing a proxy NGC graph w/ untied weights")
# set up system nodes
Expand All @@ -247,29 +250,34 @@
integrate_kernel=integrate_cfg)

# create cable wiring scheme relating nodes to one another
z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)

z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)

z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
z3_mu2 = z3.wire_to(mu2, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu2.wire_to(e2, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z2.wire_to(e2, src_comp="z", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e2_z3 = e2.wire_to(z3, src_comp="phi(z)", dest_comp="dz_bu", cable_kernel=dcable_cfg)
e2.wire_to(z2, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

z2_mu1 = z2.wire_to(mu1, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu1.wire_to(e1, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z1.wire_to(e1, src_comp="z", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e1_z2 = e1.wire_to(z2, src_comp="phi(z)", dest_comp="dz_bu", cable_kernel=dcable_cfg)
e1.wire_to(z1, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

z1_mu0 = z1.wire_to(mu0, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=dcable_cfg)
mu0.wire_to(e0, src_comp="phi(z)", dest_comp="pred_mu", cable_kernel=pos_scable_cfg)
z0.wire_to(e0, src_comp="phi(z)", dest_comp="pred_targ", cable_kernel=pos_scable_cfg)
e0_z1 = e0.wire_to(z1, src_comp="phi(z)", dest_comp="dz_bu", cable_kernel=dcable_cfg)
e0.wire_to(z0, src_comp="phi(z)", dest_comp="dz_td", cable_kernel=neg_scable_cfg)

# set up update rules and make relevant edges aware of these
z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)"))
z2_mu1.set_update_rule(preact=(z2,"phi(z)"), postact=(e1,"phi(z)"))
z1_mu0.set_update_rule(preact=(z1,"phi(z)"), postact=(e0,"phi(z)"))

e2_z3.set_update_rule(preact=(e2,"phi(z)"), postact=(z3,"phi(z)"))
e1_z2.set_update_rule(preact=(e1,"phi(z)"), postact=(z2,"phi(z)"))
e0_z1.set_update_rule(preact=(e0,"phi(z)"), postact=(z1,"phi(z)"))


# Set up graph - execution cycle/order
ngc_model = NGCGraph(K=10, name="gncn_t1_ffm")
ngc_model.proj_update_mag = -1.0
Expand All @@ -281,10 +289,10 @@

print("----------------")
print(" > Checking NGC simulation object:")
readouts = ngc_model.settle(
clamped_vars=[("z3","z",x),("z0","z",x)],
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
)
readouts, delta = ngc_model.settle(
clamped_vars=[("z3","z",x),("z0","z",x)],
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
)
x_hat = readouts[0][2]

print(" => Test for: 0 = e2.z = e2.phi(z)")
Expand Down Expand Up @@ -318,7 +326,7 @@
print(" PASS!")

print(" => Test for update calculation: all dx should be = 0")
delta = ngc_model.calc_updates()
#delta = ngc_model.calc_updates()
for i in range(len(delta)):
target_dx = ngc_model.theta[i] * 0
dx = delta[i]
Expand Down

0 comments on commit e151d54

Please sign in to comment.