Skip to content

Commit

Permalink
feat: phase 0 of networkx query reader, pagerank algo
Browse files Browse the repository at this point in the history
partially implement: #28
  • Loading branch information
wey-gu committed Mar 24, 2023
1 parent d97d17f commit 375adf1
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 205 deletions.
28 changes: 24 additions & 4 deletions ng_ai/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,35 @@ class NebulaEngine(BaseEngine):
def __init__(self, config=None):
self.type = "nebula"
self.config = config

# let's make all nx related import here
import networkx as nx
import ng_nx
from ng_nx import NebulaReader as NxReader
from ng_nx import NxScanReader, NxWriter
from ng_nx.utils import NxConfig, result_to_df

self.nx = nx
self.ng_nx = ng_nx
self.nx_reader = NxReader
self.nx_writer = NxWriter
self.nx_scan_reader = NxScanReader
self._nx_config = NxConfig

self.result_to_df = result_to_df

self.nx_config = None
self.parse_config()

def __str__(self):
return f"NebulaEngine: {self.config}"
return (
f"NebulaEngine(NetworkX): {self.config}, "
f"nx version: {self.nx.__version__}, "
f"ng_nx version: {self.ng_nx.__version__}"
)

def parse_config(self):
"""parse and validate config"""
if self.config is None:
return

def prepare(self):
pass
self.nx_config = self._nx_config(**self.config.__dict__)
21 changes: 18 additions & 3 deletions ng_ai/nebula_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ class NebulaGraphAlgorithm:
def __init__(self, graph):
self.graph = graph
self.algorithms = []
self.engine = graph.engine

def register_algo(self, func):
self.algorithms.append(func.__name__)
Expand All @@ -384,14 +385,28 @@ def check_engine(self):
For spark, we need to convert the NebulaGraphObject
to NebulaDataFrameObject
"""
if self.graph.engine.type == "spark":
if self.engine.type == "spark":
raise Exception(
"For NebulaGraphObject in spark engine,"
"Plz transform it to NebulaDataFrameObject to run algorithm",
"For example: df = nebula_graph.to_df; df.algo.pagerank()",
)
if self.engine.type == "networkx":
return True
else:
raise Exception("Unsupported engine type")

@algo
def pagerank(self, reset_prob=0.15, max_iter=10):
def pagerank(self, reset_prob=0.15, max_iter=10, **kwargs):
self.check_engine()
pass
g = self.graph._graph
weight = kwargs.get("weight", None)
assert type(weight) in [str, type(None)], "weight must be str or None"
assert type(reset_prob) == float, "reset_prob must be float"
assert type(max_iter) == int, "max_iter must be int"
tol = kwargs.get("tol", 1e-06)
assert type(tol) == float, "tol must be float"

return self.engine.nx.pagerank(
g, alpha=1 - reset_prob, max_iter=max_iter, tol=tol, weight=weight
)
42 changes: 29 additions & 13 deletions ng_ai/nebula_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ def get_engine(self):

@property
def algo(self):
from ng_ai.nebula_algo import NebulaAlgorithm as NebulaAlgorithmImpl
if self.engine.type == "spark":
print(
"NebulaGraphObject.algo is not supported in spark engine, "
"please use NebulaDataFrameObject.algo instead"
)
raise NotImplementedError
if self.engine.type == "nebula":
from ng_ai.nebula_algo import NebulaAlgorithm as NebulaAlgorithmImpl

return NebulaAlgorithmImpl(self)
return NebulaAlgorithmImpl(self)

def get_nx_graph(self):
if self.engine.type == "nebula":
Expand Down Expand Up @@ -81,9 +88,16 @@ def get_engine(self):

@property
def algo(self):
from ng_ai.nebula_algo import NebulaAlgorithm as NebulaAlgorithmImpl
if self.engine.type == "spark":
from ng_ai.nebula_algo import NebulaAlgorithm as NebulaAlgorithmImpl

return NebulaAlgorithmImpl(self)
return NebulaAlgorithmImpl(self)
else:
print(
"NebulaDataFrameObject.algo is not supported in nebula engine, "
"please use NebulaGraphObject.algo instead"
)
raise NotImplementedError

def to_spark_df(self):
if self.engine.type == "spark":
Expand All @@ -103,14 +117,11 @@ def to_pandas_df(self):
raise NotImplementedError

def to_networkx(self):
if self.engine.type == "nebula":
return nx.from_pandas_edgelist(self.data, "src", "dst")
else:
# for now the else case will be spark, to networkx is not supported
raise Exception(
"For NebulaDataFrameObject in spark engine,"
"convert to networkx graph is not supported",
)
# for now the else case will be spark, to networkx is not supported
raise Exception(
"For NebulaDataFrameObject in spark engine,"
"convert to networkx graph is not supported",
)

def to_graphx(self):
if self.engine.type == "spark":
Expand All @@ -128,4 +139,9 @@ def to_graph(self):
return NebulaGraphObject(self)

def show(self, *keywords, **kwargs):
return self.data.show(*keywords, **kwargs)
if self.engine.type == "spark":
self.data.show(*keywords, **kwargs)
elif self.engine.type == "nebula":
print(self.data)
else:
raise NotImplementedError
40 changes: 36 additions & 4 deletions ng_ai/nebula_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ng_ai.config import NebulaGraphConfig
from ng_ai.nebula_data import NebulaDataFrameObject

DEFAULT_NEBULA_QUERY_LIMIT = 1000


class NebulaReaderBase(object):
def __init__(self, engine=None, config=None, **kwargs):
Expand Down Expand Up @@ -54,22 +56,52 @@ def __init__(self, config: NebulaGraphConfig, **kwargs):
self.engine = NebulaEngine(config)
self.raw_df = None
self.df = None
self.reader = None

def scan(self, **kwargs):
# Implement the scan method specific to Nebula engine
raise NotImplementedError

def query(self, **kwargs):
# Implement the query method specific to Nebula engine
raise NotImplementedError
limit = kwargs.get("limit", DEFAULT_NEBULA_QUERY_LIMIT)
assert type(limit) == int, "limit should be an integer"
assert "space" in kwargs, "space is required"
space = kwargs["space"]
assert "edges" in kwargs, "edges is required"
edges = kwargs["edges"]
assert type(edges) == list, "edges should be a list"
length_of_edges = len(edges)
props = kwargs.get("props", [[]] * length_of_edges)
assert type(props) == list, "props should be a list"
assert (
len(props) == length_of_edges
), "length of props should be equal to length of edges"
for prop in props:
assert type(prop) == list, "props should be a list of list"
for item in prop:
assert type(item) == str, "props should be a list of list of string"

self.reader = NebulaReader(
space=space,
edges=edges,
properties=props,
nebula_config=self.engine._nx_config,
limit=limit,
)

return self.reader

def load(self, **kwargs):
# Implement the load method specific to Nebula engine
raise NotImplementedError

def read(self, **kwargs):
# Implement the read method specific to Nebula engine
raise NotImplementedError
if self.reader is None:
raise Exception(
"reader is not initialized, please call query or scan first"
)
self._graph = self.reader.read()
return self._graph

def show(self, **kwargs):
# Implement the show method specific to Nebula engine
Expand Down
Loading

0 comments on commit 375adf1

Please sign in to comment.