Skip to content

Commit

Permalink
Add groupby raw node result option (aws#253)
Browse files Browse the repository at this point in the history
* Add --group-by-raw query option

* Add unit tests

Co-authored-by: Michael Chin <[email protected]>
  • Loading branch information
michaelnchin and michaelnchin authored Feb 3, 2022
1 parent bf8f49d commit 76c3c04
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 22 deletions.
11 changes: 10 additions & 1 deletion src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
choices=['text/csv', 'text/html'])
parser.add_argument('-g', '--group-by', type=str, default='',
help='Property used to group nodes.')
parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False,
help="Group nodes by the raw binding")
parser.add_argument('-d', '--display-property', type=str, default='',
help='Property to display the value of on each node.')
parser.add_argument('-de', '--edge-display-property', type=str, default='',
Expand Down Expand Up @@ -357,7 +359,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
label_max_length=args.label_max_length,
edge_label_max_length=args.edge_label_max_length,
ignore_groups=args.ignore_groups,
expand_all=args.expand_all)
expand_all=args.expand_all,
group_by_raw=args.group_by_raw)

sn.extract_prefix_declarations_from_query(cell)
try:
Expand Down Expand Up @@ -470,6 +473,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
parser.add_argument('-p', '--path-pattern', default='', help='path pattern')
parser.add_argument('-g', '--group-by', type=str, default='T.label',
help='Property used to group nodes (e.g. code, T.region) default is T.label')
parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False,
help="Group nodes by the raw result")
parser.add_argument('-d', '--display-property', type=str, default='T.label',
help='Property to display the value of on each node, default is T.label')
parser.add_argument('-de', '--edge-display-property', type=str, default='T.label',
Expand Down Expand Up @@ -572,6 +577,7 @@ def gremlin(self, line, cell, local_ns: dict = None):
logger.debug(f'label_max_length: {args.label_max_length}')
logger.debug(f'ignore_groups: {args.ignore_groups}')
gn = GremlinNetwork(group_by_property=args.group_by, display_property=args.display_property,
group_by_raw=args.group_by_raw,
edge_display_property=args.edge_display_property,
tooltip_property=args.tooltip_property,
edge_tooltip_property=args.edge_tooltip_property,
Expand Down Expand Up @@ -1659,6 +1665,8 @@ def handle_opencypher_query(self, line, cell, local_ns):
parser = argparse.ArgumentParser()
parser.add_argument('-g', '--group-by', type=str, default='~labels',
help='Property used to group nodes (e.g. code, ~id) default is ~labels')
parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False,
help="Group nodes by the raw result")
parser.add_argument('mode', nargs='?', default='query', help='query mode [query|bolt]',
choices=['query', 'bolt'])
parser.add_argument('-d', '--display-property', type=str, default='~labels',
Expand Down Expand Up @@ -1712,6 +1720,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
query_time=query_time)
try:
gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property,
group_by_raw=args.group_by_raw,
edge_display_property=args.edge_display_property,
tooltip_property=args.tooltip_property,
edge_tooltip_property=args.edge_tooltip_property,
Expand Down
14 changes: 9 additions & 5 deletions src/graph_notebook/network/EventfulNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
EVENT_ADD_EDGE = 'add_edge'
EVENT_ADD_EDGE_DATA = 'add_edge_data'
DEFAULT_GRP = 'DEFAULT_GROUP'
DEFAULT_RAW_GRP_KEY = '__RAW_RESULT__'
DEFAULT_LABEL_MAX_LENGTH = 10

VALID_EVENTS = [EVENT_ADD_NODE, EVENT_ADD_NODE_DATA, EVENT_ADD_NODE_PROPERTY, EVENT_ADD_EDGE, EVENT_ADD_EDGE_DATA]
Expand All @@ -41,16 +42,19 @@ def __init__(self, graph: MultiDiGraph = None, callbacks: dict = None,
edge_label_max_length: int = DEFAULT_LABEL_MAX_LENGTH, group_by_property: str = '',
display_property: str = '', edge_display_property: str = '',
tooltip_property: str = '', edge_tooltip_property: str = '',
ignore_groups=False):
ignore_groups=False, group_by_raw=False):
if callbacks is None:
callbacks = defaultdict(list)
self.callbacks = callbacks
self.label_max_length = 3 if label_max_length < 3 else label_max_length
self.edge_label_max_length = 3 if edge_label_max_length < 3 else edge_label_max_length
try:
self.group_by_property = json.loads(group_by_property)
except (TypeError, ValueError) as e:
self.group_by_property = group_by_property
if group_by_raw:
self.group_by_property = DEFAULT_RAW_GRP_KEY
else:
try:
self.group_by_property = json.loads(group_by_property)
except (TypeError, ValueError) as e:
self.group_by_property = group_by_property
self.display_property = self.convert_property_name(display_property)
self.edge_display_property = self.convert_property_name(edge_display_property)
self.tooltip_property = self.convert_property_name(tooltip_property) if tooltip_property \
Expand Down
30 changes: 23 additions & 7 deletions src/graph_notebook/network/gremlin/GremlinNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
from enum import Enum

from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP
from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP, DEFAULT_RAW_GRP_KEY
from gremlin_python.process.traversal import T, Direction
from gremlin_python.structure.graph import Path, Vertex, Edge
from networkx import MultiDiGraph
Expand Down Expand Up @@ -99,12 +99,13 @@ class GremlinNetwork(EventfulNetwork):

def __init__(self, graph: MultiDiGraph = None, callbacks=None, label_max_length=DEFAULT_LABEL_MAX_LENGTH,
edge_label_max_length=DEFAULT_LABEL_MAX_LENGTH, group_by_property=T_LABEL, display_property=T_LABEL,
edge_display_property=T_LABEL, tooltip_property=None, edge_tooltip_property=None, ignore_groups=False):
edge_display_property=T_LABEL, tooltip_property=None, edge_tooltip_property=None, ignore_groups=False,
group_by_raw=False):
if graph is None:
graph = MultiDiGraph()
super().__init__(graph, callbacks, label_max_length, edge_label_max_length, group_by_property,
display_property, edge_display_property, tooltip_property, edge_tooltip_property,
ignore_groups)
ignore_groups, group_by_raw)

def get_dict_element_property_value(self, element, k, temp_label, custom_property):
property_value = None
Expand Down Expand Up @@ -347,7 +348,9 @@ def add_vertex(self, v):
using_custom_tooltip = True
vertex_dict = v.__dict__
if not isinstance(self.group_by_property, dict): # Handle string format group_by
if str(self.group_by_property) in [T_LABEL, 'label']: # this handles if it's just a string
if str(self.group_by_property) == DEFAULT_RAW_GRP_KEY:
group = str(v)
elif str(self.group_by_property) in [T_LABEL, 'label']: # this handles if it's just a string
# This sets the group key to the label if either "label" is passed in or
# T.label is set in order to handle the default case of grouping by label
# when no explicit key is specified
Expand All @@ -359,7 +362,9 @@ def add_vertex(self, v):
else: # handle dict format group_by
try:
if str(v.label) in self.group_by_property:
if self.group_by_property[str(v.label)] in [T_LABEL, 'label']:
if self.group_by_property[str(v.label)] == DEFAULT_RAW_GRP_KEY:
group = str(v)
elif self.group_by_property[str(v.label)] in [T_LABEL, 'label']:
group = v.label
elif self.group_by_property[str(v.label)] in [T_ID, 'id']:
group = v.id
Expand Down Expand Up @@ -416,14 +421,21 @@ def add_vertex(self, v):
else:
title_plc = ''
group = DEFAULT_GRP

if str(self.group_by_property) == DEFAULT_RAW_GRP_KEY:
group = str(v)
group_is_set = True
for k in v:
if str(k) == T_ID:
node_id = str(v[k])
properties[k] = str(v[k]) if isinstance(v[k], dict) else v[k]
if not group_is_set:
if isinstance(self.group_by_property, dict):
try:
if str(k) == self.group_by_property[title_plc]:
if self.group_by_property[title_plc] == DEFAULT_RAW_GRP_KEY:
group = str(v)
group_is_set = True
elif str(k) == self.group_by_property[title_plc]:
group = str(v[k])
group_is_set = True
except KeyError:
Expand Down Expand Up @@ -468,8 +480,12 @@ def add_vertex(self, v):
else:
node_id = str(v)
title = str(v)
if str(self.group_by_property) == DEFAULT_RAW_GRP_KEY:
group = str(v)
else:
group = DEFAULT_GRP
label = title if len(title) <= self.label_max_length else title[:self.label_max_length - 3] + '...'
data = {'title': title, 'label': label, 'group': DEFAULT_GRP}
data = {'title': title, 'label': label, 'group': group}

if self.ignore_groups:
data['group'] = DEFAULT_GRP
Expand Down
14 changes: 9 additions & 5 deletions src/graph_notebook/network/opencypher/OCNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging

from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP
from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP, DEFAULT_RAW_GRP_KEY
from networkx import MultiDiGraph

logging.basicConfig()
Expand Down Expand Up @@ -35,12 +35,12 @@ def __init__(self, graph: MultiDiGraph = None, callbacks=None, label_max_length=
edge_label_max_length=DEFAULT_LABEL_MAX_LENGTH, group_by_property=LABEL_KEY,
display_property=LABEL_KEY, edge_display_property=EDGE_TYPE_KEY,
tooltip_property=None, edge_tooltip_property=None,
ignore_groups=False):
ignore_groups=False, group_by_raw=False):
if graph is None:
graph = MultiDiGraph()
super().__init__(graph, callbacks, label_max_length, edge_label_max_length, group_by_property,
display_property, edge_display_property, tooltip_property, edge_tooltip_property,
ignore_groups)
ignore_groups, group_by_raw)

def get_node_property_value(self, node: dict, props: dict, title, custom_property):
try:
Expand Down Expand Up @@ -127,7 +127,9 @@ def parse_node(self, node: dict, path_index: int = -2):

if not isinstance(self.group_by_property, dict): # Handle string format group_by
try:
if self.group_by_property in [LABEL_KEY, 'labels'] and len(node[LABEL_KEY]) > 0:
if self.group_by_property == DEFAULT_RAW_GRP_KEY:
group = str(node)
elif self.group_by_property in [LABEL_KEY, 'labels'] and len(node[LABEL_KEY]) > 0:
group = node[LABEL_KEY][0]
elif self.group_by_property in [ID_KEY, 'id']:
group = node[ID_KEY]
Expand All @@ -143,7 +145,9 @@ def parse_node(self, node: dict, path_index: int = -2):
try:
if str(node[LABEL_KEY][0]) in self.group_by_property and len(node[LABEL_KEY]) > 0:
key = node[LABEL_KEY][0]
if self.group_by_property[key] in [LABEL_KEY, 'labels']:
if self.group_by_property[key] == DEFAULT_RAW_GRP_KEY:
group = str(node)
elif self.group_by_property[key] in [LABEL_KEY, 'labels']:
group = node[LABEL_KEY][0]
elif self.group_by_property[key] in [ID_KEY, 'id']:
group = node[ID_KEY]
Expand Down
14 changes: 10 additions & 4 deletions src/graph_notebook/network/sparql/SPARQLNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from networkx import MultiDiGraph
from rdflib.namespace import RDF, RDFS, OWL, XSD, SKOS, DOAP, FOAF, DC, DCTERMS, VOID

from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP
from graph_notebook.network.EventfulNetwork import EventfulNetwork, DEFAULT_GRP, DEFAULT_RAW_GRP_KEY

NAMESPACE_RDFS = str(RDFS.uri)
NAMESPACE_RDF = str(RDF.uri)
Expand Down Expand Up @@ -58,14 +58,15 @@ def __init__(self,
tooltip_property='',
edge_tooltip_property='',
ignore_groups=False,
expand_all: bool = False):
expand_all: bool = False,
group_by_raw=False):
if graph is None:
graph = MultiDiGraph()
self.expand_all = expand_all

super().__init__(graph, callbacks, label_max_length, edge_label_max_length, group_by_property,
display_property, edge_display_property, tooltip_property, edge_tooltip_property,
ignore_groups)
ignore_groups, group_by_raw)
self.namespace_to_prefix = { # http://foo/bar/ -> bar
NAMESPACE_RDFS: PREFIX_RDFS,
NAMESPACE_RDF: PREFIX_RDF,
Expand Down Expand Up @@ -168,11 +169,16 @@ def parse_node(self, node_id: str, node_binding: dict = None, data: dict = None)
data = self.generate_node_label_title(node_id=node_id, data=data)
if self.ignore_groups or not node_binding:
data['group'] = DEFAULT_GRP
elif str(self.group_by_property) == DEFAULT_RAW_GRP_KEY:
data['group'] = str(node_binding)
else:
if isinstance(self.group_by_property, dict):
try:
if str(node_binding["type"]) in self.group_by_property:
data['group'] = node_binding[self.group_by_property[node_binding["type"]]]
if self.group_by_property[node_binding["type"]] == DEFAULT_RAW_GRP_KEY:
data['group'] = str(node_binding)
else:
data['group'] = node_binding[self.group_by_property[node_binding["type"]]]
else:
data['group'] = node_binding["type"]
except KeyError:
Expand Down
61 changes: 61 additions & 0 deletions test/unit/network/gremlin/test_gremlin_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,22 @@ def test_group_with_groupby_multiple_labels_with_same_property(self):
self.assertEqual(node1['group'], 'Seattle-Tacoma International Airport')
self.assertEqual(node2['group'], 'Austria')

def test_group_with_groupby_raw_string(self):
vertex = {
T.id: '1234',
T.label: 'airport',
'type': 'Airport',
'runways': '4',
'code': 'SEA'
}

expected = "{<T.id: 1>: '1234', <T.label: 4>: 'airport', 'type': 'Airport', 'runways': '4', 'code': 'SEA'}"

gn = GremlinNetwork(group_by_property='__RAW_RESULT__')
gn.add_vertex(vertex)
node = gn.graph.nodes.get(vertex[T.id])
self.assertEqual(node['group'], expected)

def test_group_with_groupby_properties_json_single_label(self):
vertex1 = {
T.id: '1234',
Expand Down Expand Up @@ -1487,6 +1503,21 @@ def test_group_with_groupby_properties_json_multiple_labels(self):
self.assertEqual(node1['group'], 'SEA')
self.assertEqual(node2['group'], 'Europe')

def test_group_with_groupby_raw_json(self):
vertex = {
T.id: '1234',
T.label: 'airport',
'type': 'Airport',
'runways': '4',
'code': 'SEA'
}

expected = "{<T.id: 1>: '1234', <T.label: 4>: 'airport', 'type': 'Airport', 'runways': '4', 'code': 'SEA'}"
gn = GremlinNetwork(group_by_property='{"airport":"__RAW_RESULT__"}')
gn.add_vertex(vertex)
node = gn.graph.nodes.get(vertex[T.id])
self.assertEqual(node['group'], expected)

def test_group_with_groupby_invalid_json_single_label(self):
vertex1 = {
T.id: '1234',
Expand Down Expand Up @@ -1950,6 +1981,36 @@ def test_group_returnvertex_groupby_id_properties_json(self):
node = gn.graph.nodes.get('1')
self.assertEqual(node['group'], '1')

def test_group_by_raw_explicit(self):
vertex = {
T.id: '1234',
T.label: 'airport',
'type': 'Airport',
'runways': '4',
'code': 'SEA'
}

expected = "{<T.id: 1>: '1234', <T.label: 4>: 'airport', 'type': 'Airport', 'runways': '4', 'code': 'SEA'}"
gn = GremlinNetwork(group_by_raw=True)
gn.add_vertex(vertex)
node = gn.graph.nodes.get(vertex[T.id])
self.assertEqual(node['group'], expected)

def test_group_by_raw_explicit_overrule_gbp(self):
vertex = {
T.id: '1234',
T.label: 'airport',
'type': 'Airport',
'runways': '4',
'code': 'SEA'
}

expected = "{<T.id: 1>: '1234', <T.label: 4>: 'airport', 'type': 'Airport', 'runways': '4', 'code': 'SEA'}"
gn = GremlinNetwork(group_by_raw=True, group_by_property='{"airport":"code"}')
gn.add_vertex(vertex)
node = gn.graph.nodes.get(vertex[T.id])
self.assertEqual(node['group'], expected)

def test_add_elementmap_edge(self):
edge_map = {
T.id: '5298',
Expand Down
Loading

0 comments on commit 76c3c04

Please sign in to comment.