This example trains a Graph Neural Network to classify molecules on the basis of their biological activities.
We use Jraph, a JAX library for Graph Neural Networks, to define models which are trained on the ogbg-molpcba dataset, part of the Open Graph Benchmark.
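Jraph represents a batch of molecules as one large disjoint graph: node and edge arrays are concatenated, edge endpoints are stored as integer `senders`/`receivers` indices, and per-graph counts live in `n_node`/`n_edge`. Jraph provides `jraph.batch` for this. The plain-NumPy sketch below only re-implements the idea for illustration. The `GraphsTuple` here is a hypothetical stand-in that mirrors the field names of `jraph.GraphsTuple`; it is not the real class.

```python
from typing import NamedTuple, Optional

import numpy as np


class GraphsTuple(NamedTuple):
    """Hypothetical stand-in mirroring the fields of jraph.GraphsTuple."""
    nodes: np.ndarray              # [total_nodes, node_feature_dim]
    edges: np.ndarray              # [total_edges, edge_feature_dim]
    receivers: np.ndarray          # [total_edges] indices into nodes
    senders: np.ndarray            # [total_edges] indices into nodes
    globals: Optional[np.ndarray]  # [num_graphs, ...] or None
    n_node: np.ndarray             # [num_graphs] nodes per graph
    n_edge: np.ndarray             # [num_graphs] edges per graph


def batch_graphs(graphs):
    """Merges graphs into one disjoint union, shifting edge indices per graph.

    Simplified sketch: globals are concatenated only if every graph has them.
    """
    # Node offset of each graph = total number of nodes in the graphs before it.
    offsets = np.cumsum([0] + [int(g.n_node.sum()) for g in graphs[:-1]])
    has_globals = all(g.globals is not None for g in graphs)
    return GraphsTuple(
        nodes=np.concatenate([g.nodes for g in graphs]),
        edges=np.concatenate([g.edges for g in graphs]),
        receivers=np.concatenate([g.receivers + o for g, o in zip(graphs, offsets)]),
        senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
        globals=np.concatenate([g.globals for g in graphs]) if has_globals else None,
        n_node=np.concatenate([g.n_node for g in graphs]),
        n_edge=np.concatenate([g.n_edge for g in graphs]),
    )
```

Because every graph keeps its own node range, message passing over the batched graph never mixes information between different molecules.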
You can run this code and even modify it directly in Google Colab, with no installation required. The Colab notebook can also create visualizations of model predictions.
We depend on TensorFlow Datasets for ogbg-molpcba.
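TensorFlow Datasets yields each molecule as a dictionary of arrays, which then has to be turned into a graph. The sketch below shows one plausible conversion on a synthetic example. The feature keys (`node_feat`, `edge_feat`, `edge_index`, `num_nodes`, `num_edges`, `labels`) are assumed from the TFDS `ogbg_molpcba` catalog entry and should be checked against the installed version; storing the labels in `globals` and the function name `example_to_graph` are illustrative choices, not necessarily what this example does.

```python
from typing import NamedTuple, Optional

import numpy as np


class GraphsTuple(NamedTuple):
    """Hypothetical stand-in mirroring the fields of jraph.GraphsTuple."""
    nodes: np.ndarray
    edges: np.ndarray
    receivers: np.ndarray
    senders: np.ndarray
    globals: Optional[np.ndarray]
    n_node: np.ndarray
    n_edge: np.ndarray


def example_to_graph(example):
    """Converts one ogbg-molpcba example (dict of NumPy arrays) to a GraphsTuple.

    Assumes edge_index has shape [num_edges, 2] with (sender, receiver) columns
    and that labels hold one value per task, with NaN marking missing labels.
    """
    edge_index = np.asarray(example["edge_index"])
    return GraphsTuple(
        nodes=np.asarray(example["node_feat"]),
        edges=np.asarray(example["edge_feat"]),
        receivers=edge_index[:, 1],
        senders=edge_index[:, 0],
        globals=np.asarray(example["labels"])[None, :],  # [1, num_tasks]
        n_node=np.atleast_1d(np.asarray(example["num_nodes"])),
        n_edge=np.atleast_1d(np.asarray(example["num_edges"])),
    )
```

In practice the dictionaries would come from the TFDS pipeline (for instance `tfds.load('ogbg_molpcba', split='train')` iterated as NumPy) rather than being written by hand.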
To run with the default configuration:
python main.py --workdir=./ogbg_molpcba --config=configs/default.py
Since the configuration is defined using config_flags, you can override hyperparameters. For example, to change the number of epochs and the batch size:
python main.py --workdir=./ogbg_molpcba --config=configs/default.py \
--config.num_training_epochs=10 --config.batch_size=50
For more extensive changes, you can directly edit the default configuration file or even add your own.
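A custom configuration can reuse the defaults and override only what changes. The fragment below is a sketch that assumes `configs/default.py` defines a `get_config()` function returning an `ml_collections.ConfigDict` (the convention `config_flags` relies on) and that the `configs` package is importable from the working directory. The file name `configs/short_run.py` is hypothetical; the two overridden fields are the ones used in the command-line example above.

```python
# configs/short_run.py (hypothetical file name).
"""Default configuration with a shorter schedule and smaller batches."""

from configs import default as default_lib


def get_config():
  """Returns the default config with a few hyperparameters overridden."""
  config = default_lib.get_config()
  config.num_training_epochs = 10
  config.batch_size = 50
  return config
```

It would then be selected with `python main.py --workdir=./ogbg_molpcba --config=configs/short_run.py`, and individual fields can still be overridden on the command line as before.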
This example supports only single-device training. The model should also run with other configurations and hardware, but it was explicitly tested only on the following setup:
Hardware | Batch size | Training time | Test mean AP | Validation mean AP | Metrics |
---|---|---|---|---|---|
1x V100 | 256 | 3h20m | 0.244 | 0.252 | 2021-08-03 |
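ogbg-molpcba is a multi-task problem with many labels missing per molecule, so "mean AP" averages a per-task average precision. The sketch below follows our understanding of the OGB evaluator convention: compute AP per task over labeled molecules only, skip tasks lacking either a positive or a negative label, and average the rest. It is an illustrative NumPy re-implementation, not the code used to produce the numbers above; ties in scores are broken by sort order, so it can differ slightly from `sklearn.metrics.average_precision_score` when scores tie.

```python
import numpy as np


def average_precision(y_true, y_score):
    """AP for one binary task: mean of precision@k at the rank k of each positive."""
    order = np.argsort(-y_score, kind="stable")
    hits = y_true[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return precision_at_k[hits == 1].sum() / hits.sum()


def mean_average_precision(labels, scores):
    """Mean AP over tasks. labels: [num_molecules, num_tasks], NaN = unlabeled.

    Tasks without at least one positive and one negative label are skipped.
    """
    aps = []
    for task in range(labels.shape[1]):
        labeled = ~np.isnan(labels[:, task])
        y_true = labels[labeled, task]
        if (y_true == 1).any() and (y_true == 0).any():
            aps.append(average_precision(y_true, scores[labeled, task]))
    return float(np.mean(aps))
```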
The metrics reported above were obtained at the end of training. We observed slightly higher metrics when using early stopping based on the validation mean AP:
Hardware | Batch size | Training time | Test mean AP | Validation mean AP | Metrics |
---|---|---|---|---|---|
1x V100 | 256 | 2h55m | 0.249 | 0.257 | 2021-08-03 |
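Early stopping here means choosing the evaluation step with the best validation mean AP and reporting the test mean AP at that same step; the test set never influences the choice. A minimal sketch of that selection, assuming evaluations are logged as hypothetical `EvalRecord` entries (this record type is illustrative, not part of the example's code):

```python
from typing import List, NamedTuple


class EvalRecord(NamedTuple):
    """Hypothetical per-evaluation log entry."""
    step: int
    validation_mean_ap: float
    test_mean_ap: float


def select_early_stopping(records: List[EvalRecord]) -> EvalRecord:
    """Returns the record with the highest validation mean AP.

    Ties are broken in favor of the earliest step. The test mean AP is only
    read off the chosen record, never used to choose it.
    """
    return max(records, key=lambda r: (r.validation_mean_ap, -r.step))
```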
The default configuration corresponds to a Graph Convolutional Network model with 695,936 parameters.
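For reference, a single graph convolution in the style of Kipf and Welling (2016) can be written directly on an edge list: each edge $j \to i$ sends $h_j W$ scaled by $1/\sqrt{d_j d_i}$, and messages are summed at the receiver. The NumPy sketch below is illustrative only: it omits the nonlinearity, biases, and the MLPs the actual model may use, and the function name `gcn_layer` is hypothetical. With self-loops and edges in both directions, it matches the dense form $D^{-1/2}(A + I)D^{-1/2} H W$.

```python
import numpy as np


def gcn_layer(nodes, senders, receivers, weight):
    """One symmetric-normalized graph convolution on an edge list (no activation).

    nodes: [num_nodes, in_dim]; senders/receivers: [num_edges] int indices;
    weight: [in_dim, out_dim].
    """
    num_nodes = nodes.shape[0]
    transformed = nodes @ weight
    # Degrees counted from the edge list (self-loops included if present).
    sender_deg = np.bincount(senders, minlength=num_nodes).astype(nodes.dtype)
    receiver_deg = np.bincount(receivers, minlength=num_nodes).astype(nodes.dtype)
    # Guard against division by zero for isolated nodes.
    norm = 1.0 / np.sqrt(
        np.maximum(sender_deg[senders], 1.0) * np.maximum(receiver_deg[receivers], 1.0)
    )
    messages = transformed[senders] * norm[:, None]
    # Sum incoming messages at each receiver.
    out = np.zeros((num_nodes, weight.shape[1]), dtype=messages.dtype)
    np.add.at(out, receivers, messages)
    return out
```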
We noticed diminishing gains when training for longer. Adding self-loops and undirected edges significantly improved performance. Skip connections across message-passing steps, together with LayerNorm, gave minor improvements. In contrast, adding virtual nodes, which are connected to all nodes in each graph, did not improve performance.
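Both edge augmentations are simple operations on the `senders`/`receivers` index arrays. The sketch below assumes global node indices of a single (possibly batched) graph; the function names are hypothetical, and updating per-graph `n_edge` counts in a batch is left out. Reverse edges are added before self-loops so that the loops are not duplicated.

```python
import numpy as np


def add_undirected_edges(senders, receivers, edges=None):
    """Adds the reverse of every edge, so s -> r also yields r -> s.

    Edge features, if given, are copied for the reversed edges.
    """
    new_senders = np.concatenate([senders, receivers])
    new_receivers = np.concatenate([receivers, senders])
    new_edges = None if edges is None else np.concatenate([edges, edges])
    return new_senders, new_receivers, new_edges


def add_self_loops(senders, receivers, num_nodes):
    """Appends one edge i -> i for every node i."""
    loops = np.arange(num_nodes, dtype=senders.dtype)
    return np.concatenate([senders, loops]), np.concatenate([receivers, loops])
```

Self-loops let each atom keep its own representation in the aggregation, and reverse edges let information flow in both directions along every bond.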
- Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta and Jure Leskovec (2020). Open Graph Benchmark: Datasets for Machine Learning on Graphs. In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
- Thomas N. Kipf and Max Welling (2016). Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
- Jimmy Lei Ba, Jamie Ryan Kiros and Geoffrey E. Hinton (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
- Junying Li, Deng Cai and Xiaofei He (2017). Learning graph-level representation for drug discovery. arXiv preprint arXiv:1709.03741.
The caramboxin molecule diagram depicted above was obtained and modified from Wikimedia Commons, available in the public domain.