Skip to content

Commit

Permalink
Added PyTorch implementation for creating custom graph convolution in…
Browse files Browse the repository at this point in the history
… TUTORIAL (deepchem#4134)

* Added PyTorch implementation for custom graph convolution network

* Added PyTorch implementation for custom graph convolution network

* Added explanatory text to the PyTorch implementation for the Graph Convolution tutorial

* Added explanatory text to the PyTorch implementation for the Graph Convolution tutorial

* Added a note regarding warnings while importing

* Typo fix
  • Loading branch information
spellsharp authored Oct 7, 2024
1 parent 73c877a commit 8c46096
Showing 1 changed file with 198 additions and 29 deletions.
227 changes: 198 additions & 29 deletions examples/tutorials/Introduction_to_Graph_Convolutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -58,9 +58,16 @@
"Let's use the MoleculeNet suite to load the Tox21 dataset. To featurize the data in a way that graph convolutional networks can use, we set the featurizer option to `'GraphConv'`. The MoleculeNet call returns a training set, a validation set, and a test set for us to use. It also returns `tasks`, a list of the task names, and `transformers`, a list of data transformations that were applied to preprocess the dataset. (Most deep networks are quite finicky and require a set of data transformations to ensure that training proceeds stably.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: While importing deepchem, if you see any warnings, ignore them for now. Deepchem is a vast library and there are many things that can cause minor warnings to occur. Almost always, it doesn't require any action from your side."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -90,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -104,10 +111,10 @@
{
"data": {
"text/plain": [
"0.28185401916503905"
"0.29102970123291017"
]
},
"execution_count": 2,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -132,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -147,8 +154,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training set score: {'roc_auc_score': 0.96959686893055}\n",
"Test set score: {'roc_auc_score': 0.795793783300876}\n"
"Training set score: {'roc_auc_score': 0.970785822904073}\n",
"Test set score: {'roc_auc_score': 0.7112009940440461}\n"
]
}
],
Expand Down Expand Up @@ -176,9 +183,32 @@
"Apart from this we are going to apply standard neural network layers such as [Dense](https://keras.io/api/layers/core_layers/dense/), [BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/) and [Softmax](https://keras.io/api/layers/activation_layers/softmax/) layer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a custom Graph Convolution network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you may have seen in the previous tutorials, DeepChem offers both PyTorch and Tensorflow functionalities. However, most of our work moving forward will leverage the PyTorch ecosystem. <br />\n",
"\n",
"Let's look at the Tensorflow implementation first."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensorflow"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -192,10 +222,10 @@
"\n",
"batch_size = 100\n",
"\n",
"class MyGraphConvModel(tf.keras.Model):\n",
"class GraphConvModelTensorflow(tf.keras.Model):\n",
"\n",
" def __init__(self):\n",
" super(MyGraphConvModel, self).__init__()\n",
" super(GraphConvModelTensorflow, self).__init__()\n",
" self.gc1 = GraphConv(128, activation_fn=tf.nn.tanh)\n",
" self.batch_norm1 = layers.BatchNormalization()\n",
" self.gp1 = GraphPool()\n",
Expand Down Expand Up @@ -243,15 +273,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "31Wr0t2zcj1q"
},
"outputs": [],
"source": [
"model = dc.models.KerasModel(MyGraphConvModel(), loss=dc.models.losses.CategoricalCrossEntropy())"
"model = dc.models.KerasModel(GraphConvModelTensorflow(), loss=dc.models.losses.CategoricalCrossEntropy())"
]
},
{
Expand All @@ -266,16 +296,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<deepchem.feat.mol_graphs.ConvMol at 0x14d0b1650>"
"<deepchem.feat.mol_graphs.ConvMol at 0x7bf66bfa1160>"
]
},
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -295,7 +325,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"metadata": {
"colab": {},
"colab_type": "code",
Expand Down Expand Up @@ -331,7 +361,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -345,10 +375,10 @@
{
"data": {
"text/plain": [
"0.21941944122314452"
"0.23354644775390626"
]
},
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -369,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -385,8 +415,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training set score: {'roc_auc_score': 0.8425638289185731}\n",
"Test set score: {'roc_auc_score': 0.7378436684114341}\n"
"Training set score: {'roc_auc_score': 0.8370577643901682}\n",
"Test set score: {'roc_auc_score': 0.6610993488016647}\n"
]
}
],
Expand All @@ -397,12 +427,151 @@
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "tvOYgj52cj16"
},
"metadata": {},
"source": [
"## PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before working on the PyTorch implementation, we must import a few crucial layers from the `torch_models` collection. These are PyTorch implementations of `GraphConv`, `GraphPool` and `GraphGather` which we used in the tensorflow's implementation as well."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from deepchem.models.torch_models.layers import GraphConv, GraphGather, GraphPool"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch's `GraphConv` requires the number of input features to be specified, hence we can extract that piece of information by following steps:\n",
"1. First we get a sample from the dataset. \n",
"2. Next we slice and separate the node_features (which is the first element of the list, hence the index 0). \n",
"3. Finally, we obtain the number of features by finding the shape of the array."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of input features: 75\n"
]
}
],
"source": [
"sample_batch = next(data_generator(train_dataset))\n",
"node_features = sample_batch[0][0]\n",
"num_input_features = node_features.shape[1]\n",
"print(f\"Number of input features: {num_input_features}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class GraphConvModelTorch(nn.Module):\n",
" def __init__(self):\n",
" super(GraphConvModelTorch, self).__init__()\n",
" \n",
" self.gc1 = GraphConv(out_channel=128, number_input_features=num_input_features, activation_fn=nn.Tanh())\n",
" self.batch_norm1 = nn.BatchNorm1d(128)\n",
" self.gp1 = GraphPool()\n",
"\n",
" self.gc2 = GraphConv(out_channel=128, number_input_features=128, activation_fn=nn.Tanh())\n",
" self.batch_norm2 = nn.BatchNorm1d(128)\n",
" self.gp2 = GraphPool()\n",
"\n",
" self.dense1 = nn.Linear(128, 256)\n",
" self.act3 = nn.Tanh()\n",
" self.batch_norm3 = nn.BatchNorm1d(256)\n",
" self.readout = GraphGather(batch_size=batch_size, activation_fn=nn.Tanh())\n",
" \n",
" self.dense2 = nn.Linear(512, n_tasks * 2) \n",
" \n",
" self.logits = lambda data: data.view(-1, n_tasks, 2)\n",
" self.softmax = nn.Softmax(dim=-1)\n",
" \n",
" def forward(self, inputs):\n",
" gc1_output = self.gc1(inputs)\n",
" batch_norm1_output = self.batch_norm1(gc1_output)\n",
" gp1_output = self.gp1([batch_norm1_output] + inputs[1:])\n",
"\n",
" gc2_output = self.gc2([gp1_output] + inputs[1:])\n",
" batch_norm2_output = self.batch_norm2(gc2_output)\n",
" gp2_output = self.gp2([batch_norm2_output] + inputs[1:])\n",
"\n",
" dense1_output = self.act3(self.dense1(gp2_output))\n",
" batch_norm3_output = self.batch_norm3(dense1_output)\n",
" readout_output = self.readout([batch_norm3_output] + inputs[1:])\n",
" \n",
" dense2_output = self.dense2(readout_output)\n",
" logits_output = self.logits(dense2_output)\n",
" softmax_output = self.softmax(logits_output)\n",
" return softmax_output"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.2121513557434082"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = dc.models.TorchModel(GraphConvModelTorch(), loss=dc.models.losses.CategoricalCrossEntropy())\n",
"model.fit_generator(data_generator(train_dataset, epochs=50))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set score: {'roc_auc_score': 0.9838238607233897}\n",
"Test set score: {'roc_auc_score': 0.6923516284964811}\n"
]
}
],
"source": [
"print('Training set score:', model.evaluate_generator(data_generator(train_dataset), [metric], transformers))\n",
"print('Test set score:', model.evaluate_generator(data_generator(test_dataset), [metric], transformers))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Success! The model we've constructed behaves nearly identically to `GraphConvModel`. If you're looking to build your own custom models, you can follow the example we've provided here to do so. We hope to see exciting constructions from your end soon!"
"Success! Both the models we've constructed behave nearly identically to `GraphConvModel`. If you're looking to build your own custom models, you can follow the examples we've provided here to do so. We hope to see exciting constructions from your end soon!"
]
},
{
Expand Down Expand Up @@ -446,7 +615,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 8c46096

Please sign in to comment.