diff --git a/Makefile b/Makefile index 95a7857..232dab7 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ OPENCV=0 OPENMP=0 DEBUG=0 -OBJ=main.o image.o args.o test.o matrix.o list.o data.o classifier.o net.o connected_layer.o activation_layer.o convolutional_layer.o maxpool_layer.o +OBJ=main.o image.o args.o test.o matrix.o list.o data.o classifier.o net.o connected_layer.o activation_layer.o convolutional_layer.o maxpool_layer.o batchnorm_layer.o EXOBJ=test.o VPATH=./src/:./ diff --git a/README.md b/README.md index c7849b5..17f03f3 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,4 @@ but I'm working on it! - [Homework 0](./hw0.md) - [Homework 1](./hw1.md) +- [Homework 2](./hw2.md) diff --git a/collate_hw0.sh b/collate_hw0.sh index beedd2a..9463eca 100755 --- a/collate_hw0.sh +++ b/collate_hw0.sh @@ -17,5 +17,5 @@ prepare hw0.ipynb tar cvzf hw0.tar.gz submit rm -rf submit/ -echo "Done. Please upload submit.tar.gz to Canvas." +echo "Done. Please upload hw0.tar.gz to Canvas." diff --git a/collate_hw1.sh b/collate_hw1.sh index 0478f6b..f6d7b7c 100644 --- a/collate_hw1.sh +++ b/collate_hw1.sh @@ -16,5 +16,5 @@ prepare hw1.ipynb tar cvzf hw1.tar.gz submit rm -rf submit/ -echo "Done. Please upload submit.tar.gz to Canvas." +echo "Done. Please upload hw1.tar.gz to Canvas." diff --git a/collate_hw2.sh b/collate_hw2.sh new file mode 100644 index 0000000..2ed168e --- /dev/null +++ b/collate_hw2.sh @@ -0,0 +1,19 @@ +rm -rf submit/ +mkdir -p submit + +prepare () { + if [[ $(git diff origin -- $1 | wc -c) -eq 0 ]]; then + echo "WARNING: $1 is unchanged according to git." + fi + cp $1 submit/ +} + +echo "Creating tarball..." +prepare src/batchnorm_layer.c +prepare tryhw2.py +prepare hw2.ipynb + +tar cvzf hw2.tar.gz submit +rm -rf submit/ +echo "Done. Please upload hw2.tar.gz to Canvas." + diff --git a/hw2.md b/hw2.md new file mode 100644 index 0000000..37c26d6 --- /dev/null +++ b/hw2.md @@ -0,0 +1,102 @@ +# CSE 490g1 / 599g1 Homework 2 # + +Welcome friends, + +For the third assignment we'll be implementing a powerful tool for improving optimization, batch normalization! + +You'll have to copy over your answers from the previous assignment. + +## 7. Batch Normalization ## + +The idea behind [batch normalization](https://arxiv.org/pdf/1502.03167.pdf) is simple: we'll normalize the layer output so every neuron has zero mean and unit variance. However, this simple technique provides huge benefits for model stability, convergence, and regularization. + +### Batch Norm and Convolutions ### + +Batch normalization after fully connected layers is easy. You simply calculate the batch statistics for each neuron and then normalize. With our framework, every row is a different example in a batch and every column is a different neuron so we will calculate statistics for each column and then normalize so that every column has mean 0 and variance 1. + +With convolutional layers we are going to normalize the output of a filter over a batch of images. Each filter produces a single channel in the output of a convolutional layer. Thus for batch norm, we are normalizing across a batch of channels in the output. So, for example, we calculate the mean and variance for all the 1st channels across all the examples, all the 2nd channels across all the examples, etc. Another way of thinking about it is we are normalizing the output of a single filter, which gets applied both to all of the examples in a batch but also at numerous spatial locations for every image. 
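+
+For one filter `k`, the statistic we're after looks roughly like this (an illustrative sketch only; `out`, `batch`, `out_h`, and `out_w` are made-up names for a 3-index view of the convolutional output, not our actual flattened matrix type):
+
+    // mean of filter k, taken over every example in the batch and every spatial position
+    float sum = 0;
+    int b, p;
+    for(b = 0; b < batch; ++b){
+        for(p = 0; p < out_h*out_w; ++p){
+            sum += out[b][k][p];
+        }
+    }
+    float mu_k = sum / (batch * out_h * out_w);
+
+The variance for filter `k` is computed over exactly the same set of elements.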
+
+Thus for our batch normalization functions we will be normalizing across rows, but also across some columns, depending on the spatial component of the feature map. Check out `batchnorm_layer.c`; I've already filled in `mean` as an example of calculating the mean of a batch.
+
+The `groups` parameter will tell you how many groups (i.e. channels) there are in the output. So, if your convolutional layer outputs a `32 x 32 x 8` image and has a batch size of 128, the matrix `x` will have 128 rows and 8192 columns. We want to calculate a mean for every channel, thus the `groups` parameter will be 8 and our matrix `m` will have 1 row and 8 columns (since there are 8 channels in the output).
+
+We also calculate an `n` parameter that tells us the number of elements per group in one example. The images we are processing are `32 x 32` so the `n` parameter in this case will be the integer 1024. The total number of elements in matrix `x` is the number of examples in the batch (`x.rows`) times the number of groups (`groups`) times the number of elements per group (`n`).
+
+After a fully connected layer, the `groups` parameter would always be the number of outputs, we would calculate separate means for each neuron in the output, and the `n` parameter would be 1.
+
+### Forward propagation ###
+
+These are the forward propagation equations from the [original paper](https://arxiv.org/abs/1502.03167). Note: in the original paper's terminology we'll just use `x̂` as the output; we'll skip the scaling and shifting. Here `N` is the number of elements each statistic is averaged over (batch size times spatial size for convolutional layers):
+
+    mu      = 1/N * sum_i( x_i )
+    sigma^2 = 1/N * sum_i( (x_i - mu)^2 )
+    x̂_i     = (x_i - mu) / sqrt(sigma^2 + eps)
+
+### 7.1 `variance` ###
+
+Fill in the section to compute the variance of a feature map. As in the `mean` computation, we will compute a variance for each filter. We need the previously computed `mean` for this computation so it is passed in as a parameter. Remember, variance is just the average squared difference of an element from the mean:
+
+![variance equation](https://wikimedia.org/api/rest_v1/media/math/render/svg/0c5c6e7bbd52e69c29e2d5cfe21989313aba55d4)
+
+Don't take the square root just yet, that would be the standard deviation!
+
+### 7.2 `normalize` ###
+
+To normalize our output, we simply subtract the mean from every element and divide by the standard deviation (now you'll need a square root). When you're dividing by the standard deviation it is wise to add in a small term (the epsilon in the batchnorm equations) to prevent dividing by zero. Especially if you are using ReLUs, you may occasionally have a batch with 0 variance.
+
+You should use `eps = 0.00001f`.
+
+### Understanding the forward pass ###
+
+`forward_batchnorm_layer` shows how we process the forward pass of batch normalization. Mostly we're doing what you'd expect: calculating the mean and variance and normalizing with them.
+
+We are also keeping track of a rolling average of our mean and variance. During training, or when processing large batches, we can calculate means and variances, but if we just want to process a single image it is hard to get good statistics; we don't have a batch to norm against! Thus if we have a batch size of 1 we normalize using our rolling averages.
+
+We assume the `l.rolling_mean` and `l.rolling_variance` matrices are initialized when the layer is created.
+
+We also have the matrix pointer `l.x` which will keep track of the input to the batch norm process. We will need to remember this for the backward step!
+
+### Backward propagation ###
+
+The backward propagation step looks like this (again from the paper, with `x̂` as our output, so the `dL/dx̂` terms are just the `dy` our backward function receives):
+
+    dL/dsigma^2 = sum_i( dL/dx̂_i * (x_i - mu) ) * -1/2 * (sigma^2 + eps)^(-3/2)
+    dL/dmu      = sum_i( dL/dx̂_i ) * -1/sqrt(sigma^2 + eps) + dL/dsigma^2 * sum_i( -2*(x_i - mu) ) / N
+    dL/dx_i     = dL/dx̂_i / sqrt(sigma^2 + eps) + dL/dsigma^2 * 2*(x_i - mu)/N + dL/dmu / N
+
+So to backward propagate we'll need to calculate these intermediate results, dL/dmu and dL/dsigma^2. Then, using them, we can calculate dL/dx.
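+
+To connect those equations to our matrix layout, here is a minimal sketch of the dL/dmu accumulation for `delta_mean(matrix d, matrix v)` (one possible approach, mirroring the indexing of the provided `mean`; the paper's second dL/dmu term is dropped because sum_i(x_i - mu) is exactly zero):
+
+    int groups = v.cols;
+    matrix dm = make_matrix(1, groups);
+    int n = d.cols / groups;
+    int i, j;
+    for(i = 0; i < d.rows; ++i){
+        for(j = 0; j < d.cols; ++j){
+            // every element in group j/n contributes -1/sqrt(sigma^2 + eps) times its dL/dx̂
+            dm.data[j/n] += d.data[i*d.cols + j] * (-1.0f / sqrtf(v.data[j/n] + 0.00001f));
+        }
+    }
+
+The same column-to-group bookkeeping shows up again in 7.4 and 7.5.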
+### 7.3 `delta_mean` ###
+
+Calculate dL/dmu.
+
+### 7.4 `delta_variance` ###
+
+Calculate dL/dsigma^2.
+
+### 7.5 `delta_batch_norm` ###
+
+Using the intermediate results, calculate dL/dx.
+
+### 7.6 Using your batchnorm ###
+
+Try out batchnorm! To add it after a layer, just make this simple change:
+
+    make_convolutional_layer(16, 16, 8, 16, 3, 1)
+    make_batchnorm_layer(16) # groups parameter should be the same as the output channels of the previous layer
+    make_activation_layer(RELU)
+
+You should be able to add it after convolutional or connected layers. The standard for batch norm is to use it at every layer except the output. First, train the `conv_net` as usual. Then try it with batchnorm. Does it do better?
+
+In class we learned about annealing your learning rate to get better convergence. We ALSO learned that with batch normalization you can use larger learning rates because it's more stable. Increase the starting learning rate to `.1` and train for multiple rounds with successively smaller learning rates. Using just this model, what's the best performance you can get?
+
+## PyTorch Section ##
+
+Upload `homework2_colab.ipynb` to Colab and train a neural language model.
+
+## Turn it in ##
+
+First run the `collate_hw2.sh` script by running:
+
+    bash collate_hw2.sh
+
+This will create the file `hw2.tar.gz` in your directory with all the code you need to submit. The command will check that your files have changed relative to the version stored in the `git` repository. If a file hasn't changed, figure out why; maybe you need to download your ipynb from Google?
+
+Submit `hw2.tar.gz` in the file upload field for Homework 2 on Canvas.
+
diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c
new file mode 100644
index 0000000..ff12f33
--- /dev/null
+++ b/src/batchnorm_layer.c
@@ -0,0 +1,145 @@
+#include <stdlib.h>
+#include <math.h>
+#include <assert.h>
+#include "uwnet.h"
+
+// Take mean of matrix x over rows and spatial dimension
+// matrix x: matrix with data
+// int groups: number of distinct means to take, usually equal to # outputs
+//     after connected layers or # channels after convolutional layers
+// returns: (1 x groups) matrix with means
+matrix mean(matrix x, int groups)
+{
+    assert(x.cols % groups == 0);
+    matrix m = make_matrix(1, groups);
+    int n = x.cols / groups;
+    int i, j;
+    for(i = 0; i < x.rows; ++i){
+        for(j = 0; j < x.cols; ++j){
+            m.data[j/n] += x.data[i*x.cols + j];
+        }
+    }
+    for(i = 0; i < m.cols; ++i){
+        m.data[i] = m.data[i] / x.rows / n;
+    }
+    return m;
+}
+
+// Take variance over matrix x given mean m
+matrix variance(matrix x, matrix m, int groups)
+{
+    matrix v = make_matrix(1, groups);
+    // TODO: 7.1 - Calculate variance
+    return v;
+}
+
+// Normalize x given mean m and variance v
+// returns: y = (x-m)/sqrt(v + epsilon)
+matrix normalize(matrix x, matrix m, matrix v, int groups)
+{
+    matrix norm = make_matrix(x.rows, x.cols);
+    // TODO: 7.2 - Normalize x
+    return norm;
+}
+
+
+// Run a batchnorm layer on input
+// layer l: layer to run
+// matrix x: input to layer
+// returns: the result of running the layer y = (x - mu) / sigma
+matrix forward_batchnorm_layer(layer l, matrix x)
+{
+    // Saving our input
+    // Probably don't change this
+    free_matrix(*l.x);
+    *l.x = copy_matrix(x);
+
+    if(x.rows == 1){
+        return normalize(x, l.rolling_mean, l.rolling_variance, l.channels);
+    }
+
+    float s = 0.1;
+    matrix m = mean(x, l.channels);
+    matrix v = variance(x, m, l.channels);
+    matrix y = normalize(x, m, v, l.channels);
+
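+    // The four calls below keep an exponential moving average of the batch
+    // statistics with smoothing factor s = 0.1:
+    //     rolling_mean     = (1 - s)*rolling_mean     + s*m
+    //     rolling_variance = (1 - s)*rolling_variance + s*v
+    // These rolling values are what the batch-size-1 case above normalizes with.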
scal_matrix(1-s, l.rolling_mean); + axpy_matrix(s, m, l.rolling_mean); + scal_matrix(1-s, l.rolling_variance); + axpy_matrix(s, v, l.rolling_variance); + + free_matrix(m); + free_matrix(v); + + return y; +} + +matrix delta_mean(matrix d, matrix v) +{ + int groups = v.cols; + matrix dm = make_matrix(1, groups); + // TODO 7.3 - Calculate dL/dm + return dm; +} + + +matrix delta_variance(matrix d, matrix x, matrix m, matrix v) +{ + int groups = m.cols; + matrix dv = make_matrix(1, groups); + // TODO 7.4 - Calculate dL/dv + return dv; +} + +matrix delta_batch_norm(matrix d, matrix dm, matrix dv, matrix m, matrix v, matrix x) +{ + matrix dx = make_matrix(d.rows, d.cols); + // TODO 7.5 - Calculate dL/dx + return dx; +} + + +// Run an batchnorm layer on input +// layer l: pointer to layer to run +// matrix dy: derivative of loss wrt output, dL/dy +// returns: derivative of loss wrt input, dL/dx +matrix backward_batchnorm_layer(layer l, matrix dy) +{ + matrix x = *l.x; + + matrix m = mean(x, l.channels); + matrix v = variance(x, m, l.channels); + + matrix dm = delta_mean(dy, v); + matrix dv = delta_variance(dy, x, m, v); + matrix dx = delta_batch_norm(dy, dm, dv, m, v, x); + + free_matrix(m); + free_matrix(v); + free_matrix(dm); + free_matrix(dv); + + return dx; +} + +// Update batchnorm layer..... nothing happens tho +// layer l: layer to update +// float rate: SGD learning rate +// float momentum: SGD momentum term +// float decay: l2 normalization term +void update_batchnorm_layer(layer l, float rate, float momentum, float decay){} + +layer make_batchnorm_layer(int groups) +{ + layer l = {0}; + l.channels = groups; + l.x = calloc(1, sizeof(matrix)); + + l.rolling_mean = make_matrix(1, groups); + l.rolling_variance = make_matrix(1, groups); + + l.forward = forward_batchnorm_layer; + l.backward = backward_batchnorm_layer; + l.update = update_batchnorm_layer; + return l; +} diff --git a/src/test.c b/src/test.c index b882b24..129089d 100644 --- a/src/test.c +++ b/src/test.c @@ -9,6 +9,13 @@ #include "image.h" #include "test.h" #include "args.h" +// Forward declare for tests +matrix mean(matrix x, int groups); +matrix variance(matrix x, matrix m, int groups); +matrix normalize(matrix x, matrix m, matrix v, int groups); +matrix delta_mean(matrix d, matrix v); +matrix delta_variance(matrix d, matrix x, matrix m, matrix v); +matrix delta_batch_norm(matrix d, matrix dm, matrix dv, matrix m, matrix v, matrix x); int tests_total = 0; int tests_fail = 0; @@ -310,6 +317,84 @@ void test_maxpool_layer() free_layer(max_l3); } +void test_batchnorm_layer() +{ + matrix a = load_matrix("data/test/a.matrix"); + matrix y = load_matrix("data/test/y.matrix"); + + matrix mu_a = mean(a, 64); + matrix mu_a_s = mean(a, 8); + matrix sig_a = variance(a, mu_a, 64); + matrix sig_a_s = variance(a, mu_a_s, 8); + matrix norm_a = normalize(a, mu_a, sig_a, 64); + matrix norm_a_s = normalize(a, mu_a_s, sig_a_s, 8); + + matrix truth_mu_a = load_matrix("data/test/mu_a.matrix"); + matrix truth_mu_a_s = load_matrix("data/test/mu_a_s.matrix"); + matrix truth_sig_a = load_matrix("data/test/sig_a.matrix"); + matrix truth_sig_a_s = load_matrix("data/test/sig_a_s.matrix"); + matrix truth_norm_a = load_matrix("data/test/norm_a.matrix"); + matrix truth_norm_a_s = load_matrix("data/test/norm_a_s.matrix"); + + TEST(same_matrix(truth_mu_a, mu_a)); + TEST(same_matrix(truth_mu_a_s, mu_a_s)); + TEST(same_matrix(truth_sig_a, sig_a)); + TEST(same_matrix(truth_sig_a_s, sig_a_s)); + TEST(same_matrix(truth_norm_a, norm_a)); + 
TEST(same_matrix(truth_norm_a_s, norm_a_s)); + + + + matrix dm = delta_mean(y, sig_a); + matrix dm_s = delta_mean(y, sig_a_s); + + matrix dv = delta_variance(y, a, mu_a, sig_a); + matrix dv_s = delta_variance(y, a, mu_a_s, sig_a_s); + + matrix dbn = delta_batch_norm(y, dm, dv, mu_a, sig_a, a); + matrix dbn_s = delta_batch_norm(y, dm_s, dv_s, mu_a_s, sig_a_s, a); + + matrix truth_dm = load_matrix("data/test/dm.matrix"); + matrix truth_dm_s = load_matrix("data/test/dm_s.matrix"); + matrix truth_dv = load_matrix("data/test/dv.matrix"); + matrix truth_dv_s = load_matrix("data/test/dv_s.matrix"); + matrix truth_dbn = load_matrix("data/test/dbn.matrix"); + matrix truth_dbn_s = load_matrix("data/test/dbn_s.matrix"); + TEST(same_matrix(truth_dm, dm)); + TEST(same_matrix(truth_dm_s, dm_s)); + TEST(same_matrix(truth_dv, dv)); + TEST(same_matrix(truth_dv_s, dv_s)); + TEST(same_matrix(truth_dbn, dbn)); + TEST(same_matrix(truth_dbn_s, dbn_s)); + + free_matrix(truth_mu_a); + free_matrix(truth_mu_a_s); + free_matrix(truth_sig_a); + free_matrix(truth_sig_a_s); + free_matrix(truth_norm_a); + free_matrix(truth_norm_a_s); + free_matrix(mu_a); + free_matrix(mu_a_s); + free_matrix(sig_a); + free_matrix(sig_a_s); + free_matrix(norm_a); + free_matrix(norm_a_s); + free_matrix(truth_dm); + free_matrix(truth_dm_s); + free_matrix(truth_dv); + free_matrix(truth_dv_s); + free_matrix(truth_dbn); + free_matrix(truth_dbn_s); + free_matrix(dm); + free_matrix(dm_s); + free_matrix(dv); + free_matrix(dv_s); + free_matrix(dbn); + free_matrix(dbn_s); + free_matrix(a); + free_matrix(y); +} + void make_matrix_test() { srand(1); @@ -418,7 +503,6 @@ void make_matrix_test() - // im2col tests //image im = load_image("data/test/dog.jpg"); @@ -443,8 +527,46 @@ void make_matrix_test() col2mat2.cols = col2im_res2.w*col2im_res2.h; col2mat2.data = col2im_res2.data; save_matrix(col2mat2, "data/test/col2mat2.matrix"); + + + + // Batch norm test + + matrix mu_a = mean(a, 64); + matrix mu_a_s = mean(a, 8); + + matrix sig_a = variance(a, mu_a, 64); + matrix sig_a_s = variance(a, mu_a_s, 8); + + matrix norm_a = normalize(a, mu_a, sig_a, 64); + matrix norm_a_s = normalize(a, mu_a_s, sig_a_s, 8); + + save_matrix(mu_a, "data/test/mu_a.matrix"); + save_matrix(mu_a_s, "data/test/mu_a_s.matrix"); + save_matrix(sig_a, "data/test/sig_a.matrix"); + save_matrix(sig_a_s, "data/test/sig_a_s.matrix"); + save_matrix(norm_a, "data/test/norm_a.matrix"); + save_matrix(norm_a_s, "data/test/norm_a_s.matrix"); + + matrix dm = delta_mean(y, sig_a); + matrix dm_s = delta_mean(y, sig_a_s); + + save_matrix(dm, "data/test/dm.matrix"); + save_matrix(dm_s, "data/test/dm_s.matrix"); + + matrix dv = delta_variance(y, a, mu_a, sig_a); + matrix dv_s = delta_variance(y, a, mu_a_s, sig_a_s); + + save_matrix(dv, "data/test/dv.matrix"); + save_matrix(dv_s, "data/test/dv_s.matrix"); + + matrix dbn = delta_batch_norm(y, dm, dv, mu_a, sig_a, a); + matrix dbn_s = delta_batch_norm(y, dm_s, dv_s, mu_a_s, sig_a_s, a); + save_matrix(dbn, "data/test/dbn.matrix"); + save_matrix(dbn_s, "data/test/dbn_s.matrix"); } + void test_matrix_speed() { int i; @@ -477,6 +599,7 @@ void run_tests() test_im2col(); test_col2im(); test_maxpool_layer(); + test_batchnorm_layer(); printf("%d tests, %d passed, %d failed\n", tests_total, tests_total-tests_fail, tests_fail); } diff --git a/src/uwnet.h b/src/uwnet.h index d0255e4..fc8c4bc 100644 --- a/src/uwnet.h +++ b/src/uwnet.h @@ -43,6 +43,7 @@ layer make_connected_layer(int inputs, int outputs); layer make_activation_layer(ACTIVATION activation); layer 
make_convolutional_layer(int w, int h, int c, int filters, int size, int stride); layer make_maxpool_layer(int w, int h, int c, int size, int stride); +layer make_batchnorm_layer(int groups); typedef struct { diff --git a/tryhw2.py b/tryhw2.py index c088fdd..51157e0 100644 --- a/tryhw2.py +++ b/tryhw2.py @@ -1,15 +1,12 @@ from uwnet import * def conv_net(): l = [ make_convolutional_layer(32, 32, 3, 8, 3, 2), - make_batchnorm_layer(8), make_activation_layer(RELU), make_maxpool_layer(16, 16, 8, 3, 2), make_convolutional_layer(8, 8, 8, 16, 3, 1), - make_batchnorm_layer(16), make_activation_layer(RELU), make_maxpool_layer(8, 8, 16, 3, 2), make_convolutional_layer(4, 4, 16, 32, 3, 1), - make_batchnorm_layer(32), make_activation_layer(RELU), make_connected_layer(512, 10), make_activation_layer(SOFTMAX)] diff --git a/uwnet.py b/uwnet.py index a4e8610..ca47b8f 100644 --- a/uwnet.py +++ b/uwnet.py @@ -170,10 +170,9 @@ def load_image_classification_data(images, labels): make_maxpool_layer.argtypes = [c_int, c_int, c_int, c_int, c_int] make_maxpool_layer.restype = LAYER -#def make_convolutional_layer(w, h, c, filters, size, stride, activation, batchnorm = 0): -# l = make_convolutional_layer_lib(w, h, c, filters, size, stride, activation) -# l.batchnorm = batchnorm -# return l +make_batchnorm_layer = lib.make_batchnorm_layer +make_batchnorm_layer.argtypes = [c_int] +make_batchnorm_layer.restype = LAYER save_weights_lib = lib.save_weights save_weights_lib.argtypes = [NET, c_char_p]