Skip to content

Commit

Permalink
Make Gremlin traversal source configurable (aws#221)
Browse files Browse the repository at this point in the history
* Add Gremlin traversal source option to %%graph_notebook_config

* Update Changelog

* Prevent traversal source from being configured for Neptune hosts

* Catch GremlinServerError for invalid TraversalSource, and add more descriptive error message

Co-authored-by: Michael Chin <[email protected]>
  • Loading branch information
michaelnchin and michaelnchin authored Nov 3, 2021
1 parent 82ac40f commit de3475c
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 12 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.

## Upcoming
- Added support for specifying the Gremlin traversal source ([Link to PR](https://github.com/aws/graph-notebook/pull/221))
- Added edge tooltips, and options for specifying edge label length ([Link to PR](https://github.com/aws/graph-notebook/pull/218))

## Release 3.0.7 (October 25, 2021)
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,17 @@ python -m graph_notebook.start_notebook --notebooks-dir ~/notebook/destination/d

### Gremlin Server

In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. For a local Gremlin server (HTTP or WebSockets), you can use the following command:
In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value. For a local Gremlin server (HTTP or WebSockets), you can use the following command:

```
%%graph_notebook_config
{
"host": "localhost",
"port": 8182,
"ssl": false
"ssl": false,
"gremlin": {
"traversal_source": "g"
}
}
```

Expand Down
30 changes: 27 additions & 3 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,31 @@ def to_dict(self):
return self.__dict__


class GremlinSection(object):
    """
    Used for gremlin-specific settings in a notebook's configuration
    """

    def __init__(self, traversal_source: str = ''):
        """
        :param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
                                 connected to an endpoint that can access multiple graphs. An empty or missing (None)
                                 value falls back to the default traversal source 'g'.
        """

        # Treat every falsy value as "not configured", not only the empty string: a JSON `null` in the notebook
        # configuration arrives here as None and must not be stored as the traversal source name.
        if not traversal_source:
            traversal_source = 'g'

        self.traversal_source = traversal_source

    def to_dict(self):
        """
        :return: the section's settings as a dict; this is the instance's own attribute dictionary, not a copy.
        """
        return self.__dict__


class Configuration(object):
def __init__(self, host: str, port: int,
auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT,
load_from_s3_arn='', ssl: bool = True, aws_region: str = 'us-east-1',
sparql_section: SparqlSection = None):
sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None):
self.host = host
self.port = port
self.ssl = ssl
Expand All @@ -56,8 +76,10 @@ def __init__(self, host: str, port: int,
self.auth_mode = auth_mode
self.load_from_s3_arn = load_from_s3_arn
self.aws_region = aws_region
self.gremlin = GremlinSection()
else:
self.is_neptune_config = False
self.gremlin = gremlin_section if gremlin_section is not None else GremlinSection()

def to_dict(self) -> dict:
if self.is_neptune_config:
Expand All @@ -68,14 +90,16 @@ def to_dict(self) -> dict:
'load_from_s3_arn': self.load_from_s3_arn,
'ssl': self.ssl,
'aws_region': self.aws_region,
'sparql': self.sparql.to_dict()
'sparql': self.sparql.to_dict(),
'gremlin': self.gremlin.to_dict()
}
else:
return {
'host': self.host,
'port': self.port,
'ssl': self.ssl,
'sparql': self.sparql.to_dict()
'sparql': self.sparql.to_dict(),
'gremlin': self.gremlin.to_dict()
}

def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):
Expand Down
11 changes: 8 additions & 3 deletions src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
import json

from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \
SparqlSection
SparqlSection, GremlinSection


def get_config_from_dict(data: dict) -> Configuration:
sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('')
if "amazonaws.com" in data['host']:
if gremlin_section.to_dict()['traversal_source'] != 'g':
print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n')
config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']),
ssl=data['ssl'], load_from_s3_arn=data['load_from_s3_arn'],
aws_region=data['aws_region'], sparql_section=sparql_section)
aws_region=data['aws_region'], sparql_section=sparql_section,
gremlin_section=gremlin_section)
else:
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section)
config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], sparql_section=sparql_section,
gremlin_section=gremlin_section)
return config


Expand Down
3 changes: 2 additions & 1 deletion src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def _generate_client_from_config(self, config: Configuration):
.with_host(config.host) \
.with_port(config.port) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)

self.client = builder.build()

Expand Down
17 changes: 15 additions & 2 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from gremlin_python.driver import client
from gremlin_python.driver.protocol import GremlinServerError
from neo4j import GraphDatabase
import nest_asyncio

Expand Down Expand Up @@ -79,11 +80,12 @@

class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION,
sparql_path: str = '/sparql', auth=None, session: Session = None):
sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None):
self.host = host
self.port = port
self.ssl = ssl
self.sparql_path = sparql_path
self.gremlin_traversal_source = gremlin_traversal_source
self.region = region
self._auth = auth
self._session = session
Expand Down Expand Up @@ -174,7 +176,9 @@ def get_gremlin_connection(self) -> client.Client:
request = self._prepare_request('GET', uri)

ws_url = f'{self._ws_protocol}://{self.host}:{self.port}/gremlin'
return client.Client(ws_url, 'g', headers=dict(request.headers))

traversal_source = 'g' if "neptune.amazonaws.com" in self.host else self.gremlin_traversal_source
return client.Client(ws_url, traversal_source, headers=dict(request.headers))

def gremlin_query(self, query, bindings=None):
c = self.get_gremlin_connection()
Expand All @@ -185,6 +189,11 @@ def gremlin_query(self, query, bindings=None):
c.close()
return results
except Exception as e:
if isinstance(e, GremlinServerError):
if e.status_code == 499:
print("Error returned by the Gremlin Server for the traversal_source specified in notebook "
"configuration. Please ensure that your graph database endpoint supports re-naming of "
"GraphTraversalSource from the default of 'g' in Gremlin Server.")
c.close()
raise e

Expand Down Expand Up @@ -667,6 +676,10 @@ def with_sparql_path(self, path: str):
self.args['sparql_path'] = path
return ClientBuilder(self.args)

def with_gremlin_traversal_source(self, traversal_source: str):
    """
    Store the Gremlin traversal source name in the builder arguments.

    :param traversal_source: value saved under the 'gremlin_traversal_source' key of the builder's args.
    :return: a ClientBuilder constructed from this builder's args object.
    """
    builder_args = self.args
    builder_args['gremlin_traversal_source'] = traversal_source
    return ClientBuilder(builder_args)

def with_tls(self, tls: bool):
    """
    Store the TLS flag in the builder arguments.

    :param tls: value saved under the 'ssl' key of the builder's args.
    :return: a ClientBuilder constructed from this builder's args object.
    """
    builder_args = self.args
    builder_args['ssl'] = tls
    return ClientBuilder(builder_args)
Expand Down
3 changes: 2 additions & 1 deletion test/integration/IntegrationTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
.with_port(config.port) \
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path)
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)

if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
Expand Down
2 changes: 2 additions & 0 deletions test/integration/iam/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ def setup_iam_client(config: Configuration) -> Client:
.with_region(config.aws_region) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_iam(get_session()) \
.build()

assert client.host == config.host
assert client.port == config.port
assert client.region == config.aws_region
assert client.sparql_path == config.sparql.path
assert client.gremlin_traversal_source == config.gremlin.traversal_source
assert client.ssl is config.ssl
return client

0 comments on commit de3475c

Please sign in to comment.