Skip to content

Commit

Permalink
Use Parameterized Test (TeamGraphix#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
king-p3nguin authored Jan 10, 2024
1 parent cc33ca9 commit 9bfb228
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 385 deletions.
12 changes: 12 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Style
black==22.8.0

# Tests
pytest
parameterized
tox

# Optional dependencies
qiskit
qiskit-aer
rustworkx
6 changes: 4 additions & 2 deletions tests/random_circuit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from copy import deepcopy

import numpy as np

from graphix.transpiler import Circuit

GLOBAL_SEED = None
Expand All @@ -16,7 +18,7 @@ def get_rng(seed=None):
elif seed is None and GLOBAL_SEED is not None:
return np.random.default_rng(GLOBAL_SEED)
else:
np.random.default_rng()
return np.random.default_rng()


def first_rotation(circuit, nqubits, rng):
Expand Down
134 changes: 16 additions & 118 deletions tests/test_extraction.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,20 @@
import sys
import unittest

from parameterized import parameterized_class

import graphix
from graphix import extraction


@parameterized_class([{"use_rustworkx": False}, {"use_rustworkx": True}])
class TestExtraction(unittest.TestCase):
def test_cluster_extraction_one_ghz_cluster(self):
gs = graphix.GraphState()
nodes = [0, 1, 2, 3, 4]
edges = [(0, 1), (0, 2), (0, 3), (0, 4)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)

self.assertEqual(len(clusters), 1)
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.GHZ, graph=gs), True)

# we consider everything smaller than 4, a GHZ
def test_cluster_extraction_small_ghz_cluster_1(self):
gs = graphix.GraphState()
nodes = [0, 1, 2]
edges = [(0, 1), (1, 2)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)

self.assertEqual(len(clusters), 1)
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.GHZ, graph=gs), True)

# we consider everything smaller than 4, a GHZ
def test_cluster_extraction_small_ghz_cluster_2(self):
gs = graphix.GraphState()
nodes = [0, 1]
edges = [(0, 1)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)

self.assertEqual(len(clusters), 1)
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.GHZ, graph=gs), True)

def test_cluster_extraction_one_linear_cluster(self):
gs = graphix.GraphState()
nodes = [0, 1, 2, 3, 4, 5, 6]
edges = [(0, 1), (1, 2), (2, 3), (5, 4), (4, 6), (6, 0)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)

self.assertEqual(len(clusters), 1)
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.LINEAR, graph=gs), True)

def test_cluster_extraction_one_ghz_one_linear(self):
gs = graphix.GraphState()
nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
edges = [(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)
self.assertEqual(len(clusters), 2)

clusters_expected = []
lin_cluster = graphix.GraphState()
lin_cluster.add_nodes_from([4, 5, 6, 7, 8, 9])
lin_cluster.add_edges_from([(4, 5), (5, 6), (6, 7), (7, 8), (8, 9)])
clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.LINEAR, lin_cluster))
ghz_cluster = graphix.GraphState()
ghz_cluster.add_nodes_from([0, 1, 2, 3, 4])
ghz_cluster.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.GHZ, ghz_cluster))

self.assertEqual(
(clusters[0] == clusters_expected[0] and clusters[1] == clusters_expected[1])
or (clusters[0] == clusters_expected[1] and clusters[1] == clusters_expected[0]),
True,
)

def test_cluster_extraction_pentagonal_cluster(self):
gs = graphix.GraphState()
nodes = [0, 1, 2, 3, 4]
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)
self.assertEqual(len(clusters), 2)
self.assertEqual(
(clusters[0].type == extraction.ResourceType.GHZ and clusters[1].type == extraction.ResourceType.LINEAR)
or (clusters[0].type == extraction.ResourceType.LINEAR and clusters[1].type == extraction.ResourceType.GHZ),
True,
)
self.assertEqual(
(len(clusters[0].graph.nodes) == 3 and len(clusters[1].graph.nodes) == 4)
or (len(clusters[0].graph.nodes) == 4 and len(clusters[1].graph.nodes) == 3),
True,
)

def test_cluster_extraction_one_plus_two(self):
gs = graphix.GraphState()
nodes = [0, 1, 2]
edges = [(0, 1)]
gs.add_nodes_from(nodes)
gs.add_edges_from(edges)
clusters = extraction.get_fusion_network_from_graph(gs)
self.assertEqual(len(clusters), 2)
self.assertEqual(
(clusters[0].type == extraction.ResourceType.GHZ and clusters[1].type == extraction.ResourceType.GHZ),
True,
)
self.assertEqual(
(len(clusters[0].graph.nodes) == 2 and len(clusters[1].graph.nodes) == 1)
or (len(clusters[0].graph.nodes) == 1 and len(clusters[1].graph.nodes) == 2),
True,
)

def setUp(self):
if sys.modules.get("rustworkx") is None and self.use_rustworkx is True:
self.skipTest("rustworkx not installed")

class TestExtractionWithRustworkX(unittest.TestCase):
def test_cluster_extraction_one_ghz_cluster(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2, 3, 4]
edges = [(0, 1), (0, 2), (0, 3), (0, 4)]
gs.add_nodes_from(nodes)
Expand All @@ -128,7 +26,7 @@ def test_cluster_extraction_one_ghz_cluster(self):

# we consider everything smaller than 4, a GHZ
def test_cluster_extraction_small_ghz_cluster_1(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2]
edges = [(0, 1), (1, 2)]
gs.add_nodes_from(nodes)
Expand All @@ -140,7 +38,7 @@ def test_cluster_extraction_small_ghz_cluster_1(self):

# we consider everything smaller than 4, a GHZ
def test_cluster_extraction_small_ghz_cluster_2(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1]
edges = [(0, 1)]
gs.add_nodes_from(nodes)
Expand All @@ -151,7 +49,7 @@ def test_cluster_extraction_small_ghz_cluster_2(self):
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.GHZ, graph=gs), True)

def test_cluster_extraction_one_linear_cluster(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2, 3, 4, 5, 6]
edges = [(0, 1), (1, 2), (2, 3), (5, 4), (4, 6), (6, 0)]
gs.add_nodes_from(nodes)
Expand All @@ -162,7 +60,7 @@ def test_cluster_extraction_one_linear_cluster(self):
self.assertEqual(clusters[0] == extraction.ResourceGraph(type=extraction.ResourceType.LINEAR, graph=gs), True)

def test_cluster_extraction_one_ghz_one_linear(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
edges = [(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]
gs.add_nodes_from(nodes)
Expand All @@ -171,11 +69,11 @@ def test_cluster_extraction_one_ghz_one_linear(self):
self.assertEqual(len(clusters), 2)

clusters_expected = []
lin_cluster = graphix.GraphState(use_rustworkx=True)
lin_cluster = graphix.GraphState(use_rustworkx=self.use_rustworkx)
lin_cluster.add_nodes_from([4, 5, 6, 7, 8, 9])
lin_cluster.add_edges_from([(4, 5), (5, 6), (6, 7), (7, 8), (8, 9)])
clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.LINEAR, lin_cluster))
ghz_cluster = graphix.GraphState(use_rustworkx=True)
ghz_cluster = graphix.GraphState(use_rustworkx=self.use_rustworkx)
ghz_cluster.add_nodes_from([0, 1, 2, 3, 4])
ghz_cluster.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.GHZ, ghz_cluster))
Expand All @@ -187,7 +85,7 @@ def test_cluster_extraction_one_ghz_one_linear(self):
)

def test_cluster_extraction_pentagonal_cluster(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2, 3, 4]
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
gs.add_nodes_from(nodes)
Expand All @@ -206,7 +104,7 @@ def test_cluster_extraction_pentagonal_cluster(self):
)

def test_cluster_extraction_one_plus_two(self):
gs = graphix.GraphState(use_rustworkx=True)
gs = graphix.GraphState(use_rustworkx=self.use_rustworkx)
nodes = [0, 1, 2]
edges = [(0, 1)]
gs.add_nodes_from(nodes)
Expand Down
106 changes: 16 additions & 90 deletions tests/test_graphsim.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import sys
import unittest

import numpy as np
from networkx import Graph
from networkx.utils import graphs_equal
from rustworkx import PyGraph
from parameterized import parameterized_class

try:
from rustworkx import PyGraph
except ModuleNotFoundError:
pass

from graphix.graphsim.graphstate import GraphState
from graphix.graphsim.utils import convert_rustworkx_to_networkx, is_graphs_equal
Expand Down Expand Up @@ -31,100 +37,19 @@ def get_state(g):
return gstate


@parameterized_class([{"use_rustworkx": False}, {"use_rustworkx": True}])
class TestGraphSim(unittest.TestCase):
def test_fig2(self):
"""Example of three single-qubit measurements
presented in Fig.2 of M. Elliot et al (2010)
"""
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges)
gstate = get_state(g)
g.measure_x(0)
gstate.evolve_single(meas_op(0), [0]) # x meas
gstate.normalize()
gstate.remove_qubit(0)
gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)

g.measure_y(1, choice=0)
gstate.evolve_single(meas_op(0.5 * np.pi), [0]) # y meas
gstate.normalize()
gstate.remove_qubit(0)
gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)

g.measure_z(3)
gstate.evolve_single(meas_op(0.5 * np.pi, plane="YZ"), 1) # z meas
gstate.normalize()
gstate.remove_qubit(1)
gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)

def test_E2(self):
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges)
g.h(3)
gstate = get_state(g)

g.equivalent_graph_E2(3, 4)
gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)

g.equivalent_graph_E2(4, 0)
gstate3 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate3.flatten())), 1)

g.equivalent_graph_E2(4, 5)
gstate4 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate4.flatten())), 1)

g.equivalent_graph_E2(0, 3)
gstate5 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate5.flatten())), 1)

g.equivalent_graph_E2(0, 3)
gstate6 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate6.flatten())), 1)

def test_E1(self):
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges)
g.nodes[3]["loop"] = True
gstate = get_state(g)
g.equivalent_graph_E1(3)

gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)
g.z(4)
gstate = get_state(g)
g.equivalent_graph_E1(4)
gstate2 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate2.flatten())), 1)
g.equivalent_graph_E1(4)
gstate3 = get_state(g)
np.testing.assert_almost_equal(np.abs(np.dot(gstate.flatten().conjugate(), gstate3.flatten())), 1)

def test_local_complement(self):
nqubit = 6
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
exp_edges = [(0, 1), (1, 2), (0, 2), (2, 3), (3, 4), (4, 0)]
g = GraphState(nodes=np.arange(nqubit), edges=edges)
g.local_complement(1)
exp_g = GraphState(nodes=np.arange(nqubit), edges=exp_edges)
self.assertTrue(is_graphs_equal(g, exp_g))

def setUp(self):
if sys.modules.get("rustworkx") is None and self.use_rustworkx is True:
self.skipTest("rustworkx not installed")

class TestGraphSimWithRustworkX(unittest.TestCase):
def test_fig2(self):
"""Example of three single-qubit measurements
presented in Fig.2 of M. Elliot et al (2010)
"""
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=True)
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=self.use_rustworkx)
gstate = get_state(g)
g.measure_x(0)
gstate.evolve_single(meas_op(0), [0]) # x meas
Expand All @@ -150,7 +75,7 @@ def test_fig2(self):
def test_E2(self):
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=True)
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=self.use_rustworkx)
g.h(3)
gstate = get_state(g)

Expand All @@ -177,7 +102,7 @@ def test_E2(self):
def test_E1(self):
nqubit = 6
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (0, 3), (1, 4), (2, 5)]
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=True)
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=self.use_rustworkx)
g.nodes[3]["loop"] = True
gstate = get_state(g)
g.equivalent_graph_E1(3)
Expand All @@ -197,12 +122,13 @@ def test_local_complement(self):
nqubit = 6
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
exp_edges = [(0, 1), (1, 2), (0, 2), (2, 3), (3, 4), (4, 0)]
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=True)
g = GraphState(nodes=np.arange(nqubit), edges=edges, use_rustworkx=self.use_rustworkx)
g.local_complement(1)
exp_g = GraphState(nodes=np.arange(nqubit), edges=exp_edges)
self.assertTrue(is_graphs_equal(g, exp_g))


@unittest.skipIf(sys.modules.get("rustworkx") is None, "rustworkx not installed")
class TestGraphSimUtils(unittest.TestCase):
def test_is_graphs_equal_nx_nx(self):
nnode = 6
Expand Down
Loading

0 comments on commit 9bfb228

Please sign in to comment.