Skip to content

Commit

Permalink
Added typing hints to various smaller files.
Browse files Browse the repository at this point in the history
Right now mypy type checking is disabled for optimize.py due to the overwhelming amount of false positive '"Gate" has no attribute..." errors
  • Loading branch information
jvdwetering committed Apr 23, 2020
1 parent f215db1 commit 0709025
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 127 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[mypy]

[mypy-pyzx.optimize]
ignore_errors = True

[mypy-numpy.*]
ignore_missing_imports = True

Expand Down
100 changes: 56 additions & 44 deletions pyzx/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
from fractions import Fraction
import traceback

from typing import Callable, Optional, List, Tuple, Set, Dict, Any, Union

from .utils import EdgeType, VertexType, toggle_edge, vertex_is_zx, toggle_vertex
from .utils import settings, phase_to_s
from .utils import settings, phase_to_s, FloatInt, FractionLike
from .graph.base import BaseGraph, VT, ET
from .graph.graph import GraphS
from . import rules

try:
if settings.mode == 'notebook':
import ipywidgets as widgets # type: ignore
from traitlets import Unicode, validate, Bool, Int, Float # type: ignore
from IPython.display import display, HTML # type: ignore
in_notebook = True
except ImportError:
in_notebook = False
else:
# Make some dummy classes to prevent errors with the definition
# of ZXEditorWidget
class DOMWidget(object):
Expand All @@ -46,8 +49,6 @@ class widgets(object): # type: ignore
register = lambda x: x
DOMWidget = DOMWidget

from . import rules

__all__ = ['edit', 'help']

HELP_STRING = """To create an editor, call `e = zx.editor.edit(g)` on a graph g.
Expand Down Expand Up @@ -91,8 +92,8 @@ def help():
Run %%pip install ipywidgets in a cell in your notebook to install the correct package.
"""

def load_js():
if not in_notebook:
def load_js() -> None:
if settings.mode != 'notebook':
raise Exception(ERROR_STRING)
with open(os.path.join(settings.javascript_location,"zx_editor_widget.js")) as f:
data1 = f.read()
Expand All @@ -107,7 +108,7 @@ def load_js():
</script>""".format(settings.d3_load_string,data1,data2)
display(HTML(text))

def s_to_phase(s, t=1):
def s_to_phase(s: str, t:VertexType.Type=VertexType.Z) -> Fraction:
if not s:
if t!= VertexType.H_BOX: return Fraction(0)
else: return Fraction(1)
Expand All @@ -119,22 +120,25 @@ def s_to_phase(s, t=1):
if not s: return Fraction(1)
return Fraction(int(s))

def graph_to_json(g, scale):
nodes = [{'name': int(v),
def graph_to_json(g: BaseGraph[VT,ET], scale:FloatInt) -> str:
nodes = [{'name': int(v), # type: ignore
'x': (g.row(v) + 1) * scale,
'y': (g.qubit(v) + 2) * scale,
't': g.type(v),
'phase': phase_to_s(g.phase(v),g.type(v)) }
for v in g.vertices()]
links = [{'source': int(g.edge_s(e)),
'target': int(g.edge_t(e)),
links = [{'source': int(g.edge_s(e)), # type: ignore
'target': int(g.edge_t(e)), # type: ignore
't': g.edge_type(e) } for e in g.edges()]
return json.dumps({'nodes': nodes, 'links': links})



def colour_change_matcher(g, vertexf):
if vertexf != None: candidates = set([v for v in g.vertices() if vertexf(v)])
def colour_change_matcher(
g: BaseGraph[VT,ET],
vertexf: Optional[Callable[[VT],bool]] = None
) -> List[VT]:
if vertexf is not None: candidates = set([v for v in g.vertices() if vertexf(v)])
else: candidates = g.vertex_set()
types = g.types()

Expand All @@ -146,16 +150,21 @@ def colour_change_matcher(g, vertexf):

return m

def colour_change(g, matches):
def colour_change(g: BaseGraph[VT,ET], matches: List[VT]) -> rules.RewriteOutputType[ET,VT]:
for v in matches:
g.set_type(v, VertexType.Z)
for e in g.incident_edges(v):
et = g.edge_type(e)
g.set_edge_type(e, toggle_edge(et))
return ({}, [],[],False)

def copy_matcher(g, vertexf=None):
if vertexf != None: candidates = set([v for v in g.vertices() if vertexf(v)])
MatchCopyType = Tuple[VT,VT,EdgeType.Type,FractionLike,FractionLike,List[VT]]

def copy_matcher(
g: BaseGraph[VT,ET],
vertexf:Optional[Callable[[VT],bool]]=None
) -> List[MatchCopyType[VT]]:
if vertexf is not None: candidates = set([v for v in g.vertices() if vertexf(v)])
else: candidates = g.vertex_set()
phases = g.phases()
types = g.types()
Expand All @@ -178,7 +187,10 @@ def copy_matcher(g, vertexf=None):

return m

def apply_copy(g, matches):
def apply_copy(
g: BaseGraph[VT,ET],
matches: List[MatchCopyType[VT]]
) -> rules.RewriteOutputType[ET,VT]:
rem = []
types = g.types()
for v,w,t,a,alpha, neigh in matches:
Expand All @@ -193,7 +205,7 @@ def apply_copy(g, matches):
u = g.add_vertex(vt, g.qubit(n)-0.8, r, a)
e = g.edge(n,w)
et = g.edge_type(e)
g.add_edge((n,u), et)
g.add_edge(g.edge(n,u), et)
return ({}, rem, [], True)

MATCHES_VERTICES = 1
Expand Down Expand Up @@ -233,7 +245,7 @@ def apply_copy(g, matches):
}


def operations_to_js():
def operations_to_js() -> str:
global operations
return json.dumps({k:{"active":False, "text":v["text"], "tooltip":v["tooltip"]} for k,v in operations.items()})

Expand All @@ -259,23 +271,23 @@ class ZXEditorWidget(widgets.DOMWidget):
last_operation = Unicode('').tag(sync=True)
action = Unicode('').tag(sync=True)

def __init__(self, graph, *args, **kwargs):
def __init__(self, graph: GraphS, *args, **kwargs) -> None:
super().__init__(*args,**kwargs)
self.observe(self._handle_graph_change, 'graph_json')
self.observe(self._selection_changed, 'graph_selected')
self.observe(self._apply_operation, 'button_clicked')
self.observe(self._perform_action, 'action')
self.graph = graph
self.undo_stack = [('initial',str(self.graph_json))]
self.undo_position = 1
self.halt_callbacks = False
self.msg = []
self.undo_stack: List[Tuple[str,str]] = [('initial',str(self.graph_json))]
self.undo_position: int = 1
self.halt_callbacks: bool = False
self.msg: List[str] = []
self.output = widgets.Output()

def update(self):
self.graph_json = graph_to_json(self.graph, self.graph.scale)
def update(self) -> None:
self.graph_json = graph_to_json(self.graph, self.graph.scale) # type: ignore

def _parse_selection(self):
def _parse_selection(self) -> Tuple[Set[VT],Set[ET]]:
"""Helper function for `_selection_changed` and `_apply_operation`."""
selection = json.loads(self.graph_selected)
g = self.graph
Expand Down Expand Up @@ -343,13 +355,13 @@ def _perform_action(self, change):
with self.output: print(traceback.format_exc())


def undo_stack_add(self, description, js):
def undo_stack_add(self, description: str, js: str) -> None:
self.undo_stack = self.undo_stack[:len(self.undo_stack)-self.undo_position+1]
self.undo_position = 1
self.undo_stack.append((description,js))
self.msg.append("Adding to undo stack: " + description)

def undo(self):
def undo(self) -> None:
if self.undo_position == len(self.undo_stack): return
self.undo_position += 1
description, js = self.undo_stack[len(self.undo_stack)-self.undo_position]
Expand All @@ -360,7 +372,7 @@ def undo(self):
self.update()
self.halt_callbacks = False

def redo(self):
def redo(self) -> None:
if self.undo_position == 1: return
self.undo_position -= 1
description, js = self.undo_stack[len(self.undo_stack)-self.undo_position]
Expand All @@ -371,23 +383,23 @@ def redo(self):
self.update()
self.halt_callbacks = False

def graph_from_json(self, js):
def graph_from_json(self, js: Dict[str,Any]) -> None:
try:
scale = self.graph.scale
marked = self.graph.vertex_set()
scale = self.graph.scale # type: ignore
marked: Union[Set[int],Set[Tuple[int,int]]] = self.graph.vertex_set()
for n in js["nodes"]:
v = n["name"]
r = float(n["x"])/scale -1
q = float(n["y"])/scale -2
t = int(n["t"])
phase = s_to_phase(n["phase"], t)
phase = s_to_phase(n["phase"], t) # type: ignore
if v not in marked:
self.graph.add_vertex_indexed(v)
self.graph.add_vertex_indexed(v) # type: ignore
else:
marked.remove(v)
self.graph.set_position(v, q, r)
self.graph.set_phase(v, phase)
self.graph.set_type(v, t)
self.graph.set_type(v, t) # type: ignore
self.graph.remove_vertices(marked)
marked = self.graph.edge_set()
for e in js["links"]:
Expand All @@ -399,7 +411,7 @@ def graph_from_json(self, js):
marked.remove(f)
self.graph.set_edge_type(f, et)
else:
self.graph.add_edge((s,t),et)
self.graph.add_edge((s,t),et) # type: ignore
self.graph.remove_edges(marked)
except Exception as e:
with self.output: print(traceback.format_exc())
Expand All @@ -416,25 +428,25 @@ def _handle_graph_change(self, change):
with self.output: print(traceback.format_exc())


def to_graph(self, zh=True):
def to_graph(self, zh:bool=True) -> GraphS:
return self.graph



_d3_editor_id = 0

def edit(g, scale=None):
def edit(g: GraphS, scale:Optional[FloatInt]=None) -> ZXEditorWidget:
load_js()
global _d3_editor_id
_d3_editor_id += 1
seq = _d3_editor_id

if scale == None:
if scale is None:
scale = 800 / (g.depth() + 2)
if scale > 50: scale = 50
if scale < 20: scale = 20

g.scale = scale
g.scale = scale # type: ignore

node_size = 0.2 * scale
if node_size < 2: node_size = 2
Expand Down
Loading

0 comments on commit 0709025

Please sign in to comment.