Skip to content

Commit

Permalink
Reimplementation of Graph Neural Networks with Scatter/Gather layers (L…
Browse files Browse the repository at this point in the history
…LNL#1933)

* Added lbann data generator

* Remove extraneous data generator file

* Updated GNN implementations to use 2D Scatter Gather
- Added new ChannelwiseGRU implementation

* Added new implementation of PROTEINS dataset for Sparse Scatter-Gather based GNNs
- Adds new implementation of the sparse graph trainer for new GNN modules with 2D scatter-gather based message passing
- PROTEINS dataset ready for unit testing of sparse GNNs (GCN, GIN, GatedGraph, and Graph conv modules)

* Updated graph trainer with new graph neural network calls
- Added sample graph data slice function

* - Add 2D-scatter-gather based implementation of graph kernel
- Add distconv enabled graph kernel for channelwise fully connected layer based kernel

* - Fixed typos and added documentation for Graph modules
- Updated integration test
- Integration tests in application/graph/GNN/test passing
- Minor change in Reshape.hpp to include layer name when throwing error
- Updated trainer code according to new Sparse data format

* Removed MNIST_Superpixel dataset for PyTorch dependency
 - Updated integration tests on applications/graph/GNN/test/ for stable testing of GNNs
 - Added complete distconv support to NNConv

* Updated README for test directions
- Added ChannelwiseGRU class documentation

* Updated NNConvModel to include distconv through command line interface

* Fix Channelwise-RNN implementation with correct Slice impl

- Add layer info to error message in Reshape
- Fix typos in Python documentation of implementation

Co-authored-by: Tim Moon <[email protected]>

* Added Channelwise GRUCell unit cell
- Fixed link to point to correct PyTorch doc for GRUCell

* Move lbann import in Bamboo test inside test function

Co-authored-by: Tim Moon <[email protected]>
  • Loading branch information
szaman19 and Tim Moon authored Aug 18, 2021
1 parent 0ec9831 commit 1faa1a4
Show file tree
Hide file tree
Showing 33 changed files with 1,227 additions and 876 deletions.
25 changes: 4 additions & 21 deletions applications/graph/GNN/Dense_Graph_Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def DGraph_Layer(feature_matrix,adj_matrix, node_features):
def make_model(num_vertices = None,
node_features = None,
num_classes = None,
dataset = None,
kernel_type = 'GCN',
callbacks = None,
num_epochs = 1):
Expand All @@ -80,9 +79,6 @@ def make_model(num_vertices = None,
num_vertices (int): Number of vertices of each graph (default: None)
node_features (int): Number of features per noded (default: None)
num_classes (int): Number of classes as targets (default: None)
dataset (str): Preset data set to use. Either a datset parameter has to be
supplied or all of num_vertices, node_features, and
num_classes have to be supplied. (default: None)
kernel_type (str): Graph Kernel to use in model. Expected one of
GCN, or Graph (deafult: GCN)
callbacks (list): Callbacks for the model. If set to None the model description,
Expand All @@ -94,22 +90,9 @@ def make_model(num_vertices = None,
presets, and graph kernels.
'''

assert num_vertices != dataset #Ensure atleast one of the values is set

if dataset is not None:
assert num_vertices is None

if dataset == 'MNIST':
num_vertices = 75
num_classes = 10
node_features = 1

elif dataset == 'PROTEINS':
num_vertices = 100
num_classes = 2
node_features = 3
else:
raise Exception("Unkown Dataset")
num_vertices = 100
num_classes = 2
node_features = 3

assert num_vertices is not None
assert num_classes is not None
Expand All @@ -120,7 +103,7 @@ def make_model(num_vertices = None,
# Reshape and Slice Input Tensor
#----------------------------------

input_ = lbann.Input(target_mode = 'classification')
input_ = lbann.Input(target_mode='N/A')

# Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes )
# input should have atleast two children since the target is classification
Expand Down
Loading

0 comments on commit 1faa1a4

Please sign in to comment.