-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathgraph_sequence.py
57 lines (41 loc) · 1.52 KB
/
graph_sequence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import keras
import numpy as np
import more_itertools
import json
from neo4j.v1 import GraphDatabase, Driver
class GraphSequence(keras.utils.Sequence):
    """Keras ``Sequence`` that serves review-score batches loaded from Neo4j.

    Every matching row is fetched once, inside ``__init__``, and kept in
    memory as a list of ``(x_batch, y_batch)`` tuples; ``__getitem__`` just
    indexes into that list.
    """

    def __init__(self, args, batch_size=32, test=False,
                 dataset_name="article_0", settings_path="./settings.json"):
        """Query Neo4j for (features, score) rows and split them into batches.

        Args:
            args: Object with a ``database`` attribute naming which entry of
                the settings file to use (presumably an ``argparse.Namespace``
                -- confirm against the caller).
            batch_size: Maximum number of rows per batch; the last batch may
                be smaller.
            test: Value matched against the ``test`` property of ``REVIEW``
                nodes (presumably marks held-out reviews -- confirm).
            dataset_name: Value matched against the ``dataset_name`` property
                of ``REVIEW`` nodes. Defaults to the previously hard-coded
                ``"article_0"``.
            settings_path: JSON file whose ``args.database`` entry must
                provide ``neo4j_url``, ``neo4j_user`` and ``neo4j_password``.
        """
        self.batch_size = batch_size
        # `$name` placeholders work on Neo4j 3.0+; the older `{name}` form is
        # deprecated and was removed in Neo4j 4.0.
        self.query = """
            MATCH p=
                (person:PERSON)
                    -[:WROTE]->
                (review:REVIEW {dataset_name:$dataset_name, test:$test})
                    -[:OF]->
                (product:PRODUCT)
            RETURN person.style_preference + product.style as x, review.score as y
        """
        self.query_params = {
            "dataset_name": dataset_name,
            "test": test,
        }

        with open(settings_path, encoding="utf-8") as f:
            self.settings = json.load(f)[args.database]

        rows = self._fetch_rows()
        self.data = self._make_batches(rows, self.batch_size)

    def _fetch_rows(self):
        """Run ``self.query`` and return the result rows as ``(x, y)`` tuples.

        ``x`` is the query's ``x`` column converted with ``np.array``; ``y`` is
        the ``review.score`` value exactly as the driver returns it.
        """
        driver = GraphDatabase.driver(
            self.settings["neo4j_url"],
            auth=(self.settings["neo4j_user"], self.settings["neo4j_password"]))
        try:
            with driver.session() as session:
                # Parameters go in as one dict (not **kwargs) so no parameter
                # name can clash with run()'s own keyword arguments.
                records = session.run(self.query, self.query_params).data()
        finally:
            # The driver owns a connection pool and is only needed for this
            # single query, so release it even if the query fails.
            driver.close()
        return [(np.array(record["x"]), record["y"]) for record in records]

    @staticmethod
    def _make_batches(rows, batch_size):
        """Group ``(x, y)`` rows into the ``(x_batch, y_batch)`` tuples Keras expects.

        ``x`` is ``person.style_preference + product.style`` (Cypher list
        concatenation), presumably 12 values per row -- TODO confirm against
        the data. When every ``x`` has the same length, ``x_batch`` is a 2-D
        array of shape ``(len(batch), len(x))``. ``y_batch`` is a 1-D array
        of shape ``(len(batch),)`` holding the review scores.
        """
        return [
            (np.array([x for x, _ in batch]), np.array([y for _, y in batch]))
            for batch in more_itertools.chunked(rows, batch_size)
        ]

    def __len__(self):
        """Return the number of batches."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x_batch, y_batch)`` tuple at position ``idx``."""
        return self.data[idx]