Skip to content

Commit

Permalink
Merge pull request #69 from matthew-brett/fix-example
Browse files Browse the repository at this point in the history
MRG: fix example and port to Python 3
  • Loading branch information
matthew-brett committed Jun 8, 2016
2 parents 8168679 + a74780e commit d599a26
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
7 changes: 6 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,15 @@ script:
- if [ "${DOCTESTS}" == "1" ]; then
DOCTEST_ARGS="--with-doctest";
fi
# Run unit tests
- nosetests $COVER_ARGS $DOCTEST_ARGS datarray
# Run example to check for errors
- pip install networkx
- python ../examples/inference_algs.py
# Run doc doctests
- if [ "${DOC_DOCTEST}" == "1" ]; then
pip install sphinx;
cd ../doc && make doctest;
(cd ../doc && make doctest);
fi

after_success:
Expand Down
47 changes: 28 additions & 19 deletions examples/inference_algs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import division
from __future__ import division, print_function

import sys
if sys.version_info[0] < 3: # Use range iterator for Python 2
range = xrange
from functools import reduce

import operator
from itertools import combinations

import networkx as nx
import numpy as np

import networkx as nx, numpy as np,itertools as it, operator as op
from datarray import DataArray

from numpy.testing import assert_almost_equal


def test_pearl_network():
""" From Russell and Norvig, "Artificial Intelligence, A Modern Approach,"
Section 15.1 originally from Pearl.
Expand Down Expand Up @@ -72,8 +80,8 @@ def test_pearl_network():
assert_almost_equal(marg,margs1["burglary"])
assert_almost_equal(lik,lik1)

print "p(burglary) = %s" % margs1["burglary"].__array__()
print "likelihood of observations = %.3f" % lik1
print("p(burglary) = %s" % margs1["burglary"].__array__())
print("likelihood of observations = %.3f" % lik1)

####### DataArray utilities ################

Expand All @@ -98,8 +106,8 @@ def match_shape(x,yshape,axes):
if isinstance(axes,int): axes = [axes]
assert len(x.shape) == len(axes)
assert all(xsize == yshape[yax] for xsize,yax in zip(x.shape,axes))
strides = np.zeros(len(yshape))
for yax,xstride in zip(axes,x.strides):
strides = np.zeros(len(yshape), dtype=np.intp)
for yax,xstride in zip(axes,x.strides):
strides[yax] = xstride
return np.ndarray.__new__(np.ndarray, strides=strides, shape=yshape, buffer=x, dtype=x.dtype)

Expand Down Expand Up @@ -133,7 +141,7 @@ def multiply_potentials(*DAs):
if len(DAs) == 0: return 1

full_names, full_shape = [],[]
for axis,size in zip(_sum(DA.axes for DA in DAs), _sum(DA.shape for DA in DAs)):
for axis,size in zip(_sum(list(DA.axes) for DA in DAs), _sum(DA.shape for DA in DAs)):
if axis.name not in full_names:
full_names.append(axis.name)
full_shape.append(size)
Expand Down Expand Up @@ -164,8 +172,8 @@ def sum_over_other_axes(DA, kept_axis_name):
return sum_over_axes(DA,
[axname for axname in DA.names if axname != kept_axis_name])

def _sum(seq): return reduce(op.add, seq)
def _prod(seq): return reduce(op.mul, seq)
def _sum(seq): return reduce(operator.add, seq)
def _prod(seq): return reduce(operator.mul, seq)

####### Simple marginalization #############

Expand All @@ -187,7 +195,7 @@ def calc_marginals_simple(cpts,evidence):
likelihood : likelihood of observations in the model
"""
joint_dist = multiply_potentials(*cpts)
joint_dist = joint_dist.axis.johncalls[evidence['johncalls']].axis.marycalls[evidence['marycalls']]
joint_dist = joint_dist.axes.johncalls[evidence['johncalls']].axes.marycalls[evidence['marycalls']]
return (dict((ax.name, normalize(sum_over_other_axes(joint_dist, ax.name)))
for ax in joint_dist.axes),
joint_dist.sum())
Expand Down Expand Up @@ -221,15 +229,15 @@ def digraph_eliminate(cpts,evidence,query_list):
rvs_elim = [rv for rv in rvs if rv not in query_list] + query_list
for rv in rvs_elim:
# find potentials that reference that node
pots_here = filter(lambda cpt: rv in cpt.names, cpts)
pots_here = [cpt for cpt in cpts if rv in cpt.names]
# remove them from cpts
cpts = filter(lambda cpt: rv not in cpt.names, cpts)
cpts = [cpt for cpt in cpts if rv not in cpt.names]
# Find joint probability distribution of this variable and the ones coupled to it
product_pot = multiply_potentials(*pots_here)
# if node is in query set, we don't sum over it
if rv not in query_list:
# if node is in evidence set, take slice
if rv in evidence: product_pot = product_pot.axis[rv][evidence[rv]]
if rv in evidence: product_pot = product_pot.axes(rv)[evidence[rv]]
# otherwise, sum over it
else: product_pot = product_pot.sum(axis=rv)

Expand All @@ -249,8 +257,9 @@ def cpts2digraph(cpts):
"""
G = nx.DiGraph()
for cpt in cpts:
sources,targ = cpt.axes[:-1],cpt.axes[-1]
G.add_edges_from([(src.name,targ.name) for src in sources])
names = [ax.name for ax in cpt.axes]
target = names[-1]
G.add_edges_from((source, target) for source in names[:-1])
return G

############# Sum-product #############
Expand Down Expand Up @@ -346,7 +355,7 @@ def dfs_edges(G):
(source,target) for edges in directed spanning tree resulting from depth
first search
"""
DG = nx.dfs_tree(G)
DG = nx.dfs_tree(G, source=None)
return [(src,targ) for targ in nx.dfs_postorder_nodes(DG) for src in DG.predecessors(targ)]


Expand Down Expand Up @@ -421,7 +430,7 @@ def triangulate_min_fill(G):
nodes,degrees = zip(*G_elim.degree().items())
min_deg_node = nodes[np.argmin(degrees)]
new_edges = [(n1,n2) for (n1,n2) in
it.combinations(G_elim.neighbors(min_deg_node),2) if not
combinations(G_elim.neighbors(min_deg_node),2) if not
G_elim.has_edge(n1,n2)]
added_edges.extend(new_edges)
G_elim.remove_node(min_deg_node)
Expand All @@ -438,7 +447,7 @@ def make_jtree_from_tri_graph(G):
# (i.e., it satisfies running intersection property)
# where weight is the size of the intersection between adjacent cliques.
CG.add_weighted_edges_from((tuple(c1),tuple(c2),-c1c2)
for (c1,c2) in it.combinations(nx.find_cliques(G),2)
for (c1,c2) in combinations(nx.find_cliques(G),2)
for c1c2 in [len(set(c1).intersection(set(c2)))] if c1c2 > 0)
JT = nx.Graph(nx.mst(CG)) # Minimal weight spanning tree for CliqueGraph
for src,targ in JT.edges():
Expand Down Expand Up @@ -480,7 +489,7 @@ def make_jtree_from_factors(factors):
def moral_graph_from_factors(factors):
G = nx.Graph()
for factor in factors:
for label1,label2 in it.combinations(factor.names, 2):
for label1,label2 in combinations(factor.names, 2):
G.add_edge(label1,label2)

return G
Expand Down

0 comments on commit d599a26

Please sign in to comment.