Skip to content

Commit

Permalink
Add %load --nopoll option (aws#180)
Browse files (browse the repository at this point in the history)
* Add %load --nopoll option

* Update logic and widget description for clarity

Co-authored-by: Michael Chin <[email protected]>
  • Loading branch information
michaelnchin authored Aug 17, 2021
1 parent 9b943a4 commit dc01de1
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 87 deletions.
204 changes: 120 additions & 84 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ def load(self, line='', local_ns: dict = None):
parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO)
parser.add_argument('-q', '--queue-request', action='store_true', default=False)
parser.add_argument('-d', '--dependencies', action='append', default=[])
parser.add_argument('-e', '--edge-ids', action='store_true', default=False)
parser.add_argument('-n', '--nopoll', action='store_true', default=False)

args = parser.parse_args(line.split())
region = self.graph_notebook_config.aws_region
Expand All @@ -739,7 +741,7 @@ def load(self, line='', local_ns: dict = None):
value=args.loader_arn,
placeholder='Type something',
disabled=False,
layout = widgets.Layout(width=widget_width)
layout=widgets.Layout(width=widget_width)
)

source_format = widgets.Dropdown(
Expand Down Expand Up @@ -786,7 +788,7 @@ def load(self, line='', local_ns: dict = None):

user_provided_edge_ids = widgets.Dropdown(
options=['TRUE', 'FALSE'],
value=str(args.queue_request).upper(),
value=str(args.edge_ids).upper(),
disabled=False,
layout=widgets.Layout(width=widget_width)
)
Expand All @@ -805,73 +807,93 @@ def load(self, line='', local_ns: dict = None):
layout=widgets.Layout(width=widget_width)
)

poll_status = widgets.Dropdown(
options=['TRUE', 'FALSE'],
value=str(not args.nopoll).upper(),
disabled=False,
layout=widgets.Layout(width=widget_width)
)

# Create a series of HBox containers that will hold the widgets and labels
# that make up the %load form. Some of the labels and widgets are created
# in two parts to support the validation steps that come later. In the case
# of validation errors this allows additional text to easily be added to an
# HBox describing the issue.
source_hbox_label = widgets.Label('Source:',
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end"))
display="flex",
justify_content="flex-end"))

source_hbox = widgets.HBox([source_hbox_label,source])
source_hbox = widgets.HBox([source_hbox_label, source])

format_hbox_label = widgets.Label('Format:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end"))
display="flex",
justify_content="flex-end"))

source_format_hbox = widgets.HBox([format_hbox_label,source_format])
source_format_hbox = widgets.HBox([format_hbox_label, source_format])

region_hbox = widgets.HBox([widgets.Label('Region:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end")),
region_box])

arn_hbox_label = widgets.Label('Load ARN:',
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end"))
display="flex",
justify_content="flex-end"))

arn_hbox = widgets.HBox([arn_hbox_label, arn])

mode_hbox = widgets.HBox([widgets.Label('Mode:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end")),
mode])

fail_hbox = widgets.HBox([widgets.Label('Fail on Error:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end")),
fail_on_error])

parallelism_hbox = widgets.HBox([widgets.Label('Parallelism:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end")),
parallelism])


cardinality_hbox = widgets.HBox([widgets.Label('Update Single Cardinality:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end")),
update_single_cardinality])

queue_hbox = widgets.HBox([widgets.Label('Queue Request:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end")),
queue_request])

dep_hbox_label = widgets.Label('Dependencies:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end"))
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end"))

dep_hbox = widgets.HBox([dep_hbox_label, dependencies])

ids_hbox_label = widgets.Label('User Provided Edge Ids:',
layout=widgets.Layout(width=label_width,
display="flex", justify_content="flex-end"))
display="flex",
justify_content="flex-end"))

ids_hbox = widgets.HBox([ids_hbox_label,user_provided_edge_ids])
ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids])

poll_status_label = widgets.Label('Poll Load Status:',
layout=widgets.Layout(width=label_width,
display="flex",
justify_content="flex-end"))

poll_status_hbox = widgets.HBox([poll_status_label, poll_status])

display(source_hbox,
source_format_hbox,
Expand All @@ -883,15 +905,16 @@ def load(self, line='', local_ns: dict = None):
cardinality_hbox,
queue_hbox,
dep_hbox,
ids_hbox,
ids_hbox,
poll_status_hbox,
button,
output)

def on_button_clicked(b):
source_hbox.children = (source_hbox_label,source,)
arn_hbox.children = (arn_hbox_label,arn,)
source_format_hbox.children = (format_hbox_label,source_format,)
dep_hbox.children = (dep_hbox_label,dependencies,)
source_hbox.children = (source_hbox_label, source,)
arn_hbox.children = (arn_hbox_label, arn,)
source_format_hbox.children = (format_hbox_label, source_format,)
dep_hbox.children = (dep_hbox_label, dependencies,)

dependencies_list = list(filter(None, dependencies.value.split('\n')))

Expand Down Expand Up @@ -941,9 +964,13 @@ def on_button_clicked(b):
if dependencies:
kwargs['dependencies'] = dependencies_list

if source_format.value.lower()=='opencypher':
kwargs['userProvidedEdgeIds']= user_provided_edge_ids.value
load_res = self.client.load(source.value, source_format.value, arn.value, **kwargs)
if source_format.value.lower() == 'opencypher':
kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value

if source.value.startswith("s3://"):
load_res = self.client.load(source.value, source_format.value, arn.value, **kwargs)
else:
load_res = self.client.load(source.value, source_format.value, **kwargs)
load_res.raise_for_status()
load_result = load_res.json()
store_to_ns(args.store_to, load_result, local_ns)
Expand All @@ -959,67 +986,76 @@ def on_button_clicked(b):
queue_hbox.close()
dep_hbox.close()
ids_hbox.close()
poll_status_hbox.close()
button.close()
output.close()


if 'status' not in load_result or load_result['status'] != '200 OK':
with output:
print('Something went wrong.')
print(load_result)
logger.error(load_result)
return

load_id_label = widgets.Label(f'Load ID: {load_result["payload"]["loadId"]}')
poll_interval = 5
interval_output = widgets.Output()
job_status_output = widgets.Output()

load_id_hbox = widgets.HBox([load_id_label])
status_hbox = widgets.HBox([interval_output])
vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
display(vbox)

last_poll_time = time.time()
while True:
time_elapsed = int(time.time() - last_poll_time)
time_remaining = poll_interval - time_elapsed
interval_output.clear_output()
if time_elapsed > poll_interval:
with interval_output:
print('checking status...')
job_status_output.clear_output()
with job_status_output:
display_html(HTML(loading_wheel_html))
try:
load_status_res = self.client.load_status(load_result['payload']['loadId'])
load_status_res.raise_for_status()
interval_check_response = load_status_res.json()
except Exception as e:
logger.error(e)
if poll_status.value == 'FALSE':
start_msg_label = widgets.Label(f'Load started successfully!')
polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
f'in another cell to check the current status of your bulk load.')
start_msg_hbox = widgets.HBox([start_msg_label])
polling_msg_hbox = widgets.HBox([polling_msg_label])
vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
display(vbox)
else:
poll_interval = 5
load_id_label = widgets.Label(f'Load ID: {load_result["payload"]["loadId"]}')
interval_output = widgets.Output()
job_status_output = widgets.Output()
load_id_hbox = widgets.HBox([load_id_label])
status_hbox = widgets.HBox([interval_output])
vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
display(vbox)

last_poll_time = time.time()
while True:
time_elapsed = int(time.time() - last_poll_time)
time_remaining = poll_interval - time_elapsed
interval_output.clear_output()
if time_elapsed > poll_interval:
with interval_output:
print('checking status...')
job_status_output.clear_output()
with job_status_output:
print('Something went wrong updating job status. Ending.')
return
job_status_output.clear_output()
with job_status_output:
print(f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}')
if interval_check_response["payload"]["overallStatus"]["status"] in FINAL_LOAD_STATUSES:
execution_time = interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
if execution_time == 0:
execution_time_statement = '<1 second'
elif execution_time > 59:
execution_time_statement = str(datetime.timedelta(seconds=execution_time))
else:
execution_time_statement = f'{execution_time} seconds'
print('Total execution time: ' + execution_time_statement)
interval_output.close()
print('Done.')
return
last_poll_time = time.time()
else:
with interval_output:
print(f'checking status in {time_remaining} seconds')
time.sleep(1)
display_html(HTML(loading_wheel_html))
try:
load_status_res = self.client.load_status(load_result['payload']['loadId'])
load_status_res.raise_for_status()
interval_check_response = load_status_res.json()
except Exception as e:
logger.error(e)
with job_status_output:
print('Something went wrong updating job status. Ending.')
return
job_status_output.clear_output()
with job_status_output:
print(f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}')
if interval_check_response["payload"]["overallStatus"]["status"] in FINAL_LOAD_STATUSES:
execution_time = interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
if execution_time == 0:
execution_time_statement = '<1 second'
elif execution_time > 59:
execution_time_statement = str(datetime.timedelta(seconds=execution_time))
else:
execution_time_statement = f'{execution_time} seconds'
print('Total execution time: ' + execution_time_statement)
interval_output.close()
print('Done.')
return
last_poll_time = time.time()
else:
with interval_output:
print(f'checking status in {time_remaining} seconds')
time.sleep(1)

except HTTPError as httpEx:
output.clear_output()
with output:
Expand Down
8 changes: 5 additions & 3 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def status(self) -> requests.Response:
res = self._http_session.send(req)
return res

def load(self, source: str, source_format: str, iam_role_arn: str, **kwargs) -> requests.Response:
def load(self, source: str, source_format: str, iam_role_arn: str = None, **kwargs) -> requests.Response:
"""
For a full list of allowed parameters, see aws documentation on the Neptune loader
endpoint: https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html
Expand All @@ -271,10 +271,12 @@ def load(self, source: str, source_format: str, iam_role_arn: str, **kwargs) ->
payload = {
'source': source,
'format': source_format,
'region': self.region,
'iamRoleArn': iam_role_arn
'region': self.region
}

if iam_role_arn:
payload['iamRoleArn'] = iam_role_arn

for key, value in kwargs.items():
payload[key] = value

Expand Down

0 comments on commit dc01de1

Please sign in to comment.