forked from udlbook/udlbook
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,328 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyOoGS+lY+EhGthebSO4smpj", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap11/11_3_Batch_Normalization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# **Notebook 11.3: Batch normalization**\n", | ||
"\n", | ||
"This notebook investigates the use of batch normalization in residual networks.\n", | ||
"\n", | ||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", | ||
"\n", | ||
"Contact me at [email protected] if you find any mistakes or have any suggestions.\n", | ||
"\n" | ||
], | ||
"metadata": { | ||
"id": "t9vk9Elugvmi" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Run this if you're in a Colab to make a local copy of the MNIST 1D repository\n", | ||
"!git clone https://github.com/greydanus/mnist1d" | ||
], | ||
"metadata": { | ||
"id": "D5yLObtZCi9J" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import numpy as np\n", | ||
"import os\n", | ||
"import torch, torch.nn as nn\n", | ||
"from torch.utils.data import TensorDataset, DataLoader\n", | ||
"from torch.optim.lr_scheduler import StepLR\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import mnist1d\n", | ||
"import random" | ||
], | ||
"metadata": { | ||
"id": "YrXWAH7sUWvU" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"args = mnist1d.data.get_dataset_args()\n", | ||
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n", | ||
"\n", | ||
"# The training and test input and outputs are in\n", | ||
"# data['x'], data['y'], data['x_test'], and data['y_test']\n", | ||
"print(\"Examples in training set: {}\".format(len(data['y'])))\n", | ||
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n", | ||
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))" | ||
], | ||
"metadata": { | ||
"id": "twI72ZCrCt5z" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Load in the data\n", | ||
"train_data_x = data['x'].transpose()\n", | ||
"train_data_y = data['y']\n", | ||
"val_data_x = data['x_test'].transpose()\n", | ||
"val_data_y = data['y_test']\n", | ||
"# Print out sizes\n", | ||
"print(\"Train data: %d examples (columns), each of which has %d dimensions (rows)\"%((train_data_x.shape[1],train_data_x.shape[0])))\n", | ||
"print(\"Validation data: %d examples (columns), each of which has %d dimensions (rows)\"%((val_data_x.shape[1],val_data_x.shape[0])))" | ||
], | ||
"metadata": { | ||
"id": "8bKADvLHbiV5" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def print_variance(name, data):\n", | ||
" # First dimension(rows) is batch elements\n", | ||
" # Second dimension(columns) is neurons.\n", | ||
" np_data = data.detach().numpy()\n", | ||
" # Compute variance across neurons and average these variances over members of the batch\n", | ||
" neuron_variance = np.mean(np.var(np_data, axis=0))\n", | ||
" # Print out the name and the variance\n", | ||
" print(\"%s variance=%f\"%(name,neuron_variance))" | ||
], | ||
"metadata": { | ||
"id": "3bBpJIV-N-lt" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# He initialization of weights\n", | ||
"def weights_init(layer_in):\n", | ||
" if isinstance(layer_in, nn.Linear):\n", | ||
" nn.init.kaiming_uniform_(layer_in.weight)\n", | ||
" layer_in.bias.data.fill_(0.0)" | ||
], | ||
"metadata": { | ||
"id": "YgLaex1pfhqz" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def run_one_step_of_model(model, x_train, y_train):\n", | ||
" # choose cross entropy loss function (equation 5.24 in the loss notes)\n", | ||
" loss_function = nn.CrossEntropyLoss()\n", | ||
" # construct SGD optimizer and initialize learning rate and momentum\n", | ||
" optimizer = torch.optim.SGD(model.parameters(), lr = 0.05, momentum=0.9)\n", | ||
"\n", | ||
" # load the data into a class that creates the batches\n", | ||
" data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=200, shuffle=True, worker_init_fn=np.random.seed(1))\n", | ||
"\n", | ||
" # Initialize model weights\n", | ||
" model.apply(weights_init)\n", | ||
"\n", | ||
" # Get a batch\n", | ||
" for i, data in enumerate(data_loader):\n", | ||
" # retrieve inputs and labels for this batch\n", | ||
" x_batch, y_batch = data\n", | ||
" # zero the parameter gradients\n", | ||
" optimizer.zero_grad()\n", | ||
" # forward pass -- calculate model output\n", | ||
" pred = model(x_batch)\n", | ||
" # compute the loss\n", | ||
" loss = loss_function(pred, y_batch)\n", | ||
" # backward pass\n", | ||
" loss.backward()\n", | ||
" # SGD update\n", | ||
" optimizer.step()\n", | ||
" # Break out of this loop -- we just want to see the first\n", | ||
" # iteration, but usually we would continue\n", | ||
" break" | ||
], | ||
"metadata": { | ||
"id": "DFlu45pORQEz" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# convert training data to torch tensors\n", | ||
"x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n", | ||
"y_train = torch.tensor(train_data_y.astype('long'))" | ||
], | ||
"metadata": { | ||
"id": "i7Q0ScWgRe4G" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# This is a simple residual model with 5 residual branches in a row\n", | ||
"class ResidualNetwork(torch.nn.Module):\n", | ||
" def __init__(self, input_size, output_size, hidden_size=100):\n", | ||
" super(ResidualNetwork, self).__init__()\n", | ||
" self.linear1 = nn.Linear(input_size, hidden_size)\n", | ||
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear6 = nn.Linear(hidden_size, output_size)\n", | ||
"\n", | ||
" def count_params(self):\n", | ||
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" print_variance(\"Input\",x)\n", | ||
" f = self.linear1(x)\n", | ||
" print_variance(\"First preactivation\",f)\n", | ||
" res1 = f+ self.linear2(f.relu())\n", | ||
" print_variance(\"After first residual connection\",res1)\n", | ||
" res2 = res1 + self.linear3(res1.relu())\n", | ||
" print_variance(\"After second residual connection\",res2)\n", | ||
" res3 = res2 + self.linear4(res2.relu())\n", | ||
" print_variance(\"After third residual connection\",res3)\n", | ||
" res4 = res3 + self.linear4(res3.relu())\n", | ||
" print_variance(\"After fourth residual connection\",res4)\n", | ||
" res5 = res4 + self.linear4(res4.relu())\n", | ||
" print_variance(\"After fifth residual connection\",res5)\n", | ||
" return self.linear6(res5)" | ||
], | ||
"metadata": { | ||
"id": "FslroPJJffrh" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define the model and run for one step\n", | ||
"# Monitoring the variance at each point in the network\n", | ||
"n_hidden = 100\n", | ||
"n_input = 40\n", | ||
"n_output = 10\n", | ||
"model = ResidualNetwork(n_input, n_output, n_hidden)\n", | ||
"run_one_step_of_model(model, x_train, y_train)" | ||
], | ||
"metadata": { | ||
"id": "NYw8I_3mmX5c" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Notice that the variance roughly doubles at each step so it increases exponentially as in figure 11.6b in the book." | ||
], | ||
"metadata": { | ||
"id": "0kZUlWkkW8jE" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# TODO Adapt the residual network below to add a batch norm operation\n", | ||
"# before the contents of each residual link as in figure 11.6c in the book\n", | ||
"# Use the torch function nn.BatchNorm1d\n", | ||
"class ResidualNetworkWithBatchNorm(torch.nn.Module):\n", | ||
" def __init__(self, input_size, output_size, hidden_size=100):\n", | ||
" super(ResidualNetworkWithBatchNorm, self).__init__()\n", | ||
" self.linear1 = nn.Linear(input_size, hidden_size)\n", | ||
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n", | ||
" self.linear6 = nn.Linear(hidden_size, output_size)\n", | ||
"\n", | ||
" def count_params(self):\n", | ||
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" print_variance(\"Input\",x)\n", | ||
" f = self.linear1(x)\n", | ||
" print_variance(\"First preactivation\",f)\n", | ||
" res1 = f+ self.linear2(f.relu())\n", | ||
" print_variance(\"After first residual connection\",res1)\n", | ||
" res2 = res1 + self.linear3(res1.relu())\n", | ||
" print_variance(\"After second residual connection\",res2)\n", | ||
" res3 = res2 + self.linear4(res2.relu())\n", | ||
" print_variance(\"After third residual connection\",res3)\n", | ||
" res4 = res3 + self.linear4(res3.relu())\n", | ||
" print_variance(\"After fourth residual connection\",res4)\n", | ||
" res5 = res4 + self.linear4(res4.relu())\n", | ||
" print_variance(\"After fifth residual connection\",res5)\n", | ||
" return self.linear6(res5)" | ||
], | ||
"metadata": { | ||
"id": "5JvMmaRITKGd" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define the model\n", | ||
"n_hidden = 100\n", | ||
"n_input = 40\n", | ||
"n_output = 10\n", | ||
"model = ResidualNetworkWithBatchNorm(n_input, n_output, n_hidden)\n", | ||
"run_one_step_of_model(model, x_train, y_train)" | ||
], | ||
"metadata": { | ||
"id": "2U3DnlH9Uw6c" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Note that the variance now increases linearly as in figure 11.6c." | ||
], | ||
"metadata": { | ||
"id": "R_ucFq9CXq8D" | ||
} | ||
} | ||
] | ||
} |