Revert "Remove duplicate weight initialization"
This reverts commit 87cf590.
aleksandrskoselevs committed Aug 23, 2024
1 parent 87cf590 · commit 305a055
Showing 1 changed file with 95 additions and 92 deletions.
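
Note on the change: the substantive edit in this revert is small. It restores a second `model.apply(weights_init)` call immediately before the training loop (the "duplicate weight initialization" that 87cf590 removed). The many other paired additions and deletions come from the notebook being re-serialized in Colab's key order (`source` before `metadata`, with `execution_count` and `outputs` moved to the end of each cell), which touches nearly every cell without changing any code. Below is a minimal sketch of the restored pattern; `D_i`, `D_k`, and `weights_init` appear in the diff, but the `nn.Sequential` model is collapsed in this view, so the architecture and `D_o = 10` are assumptions for illustration only.

```python
import torch.nn as nn

# D_i and D_k are taken from the diff; the model definition itself is
# collapsed in this view, so the Sequential below and D_o = 10 are
# assumptions for illustration only.
D_i, D_k, D_o = 40, 200, 10
model = nn.Sequential(nn.Linear(D_i, D_k), nn.ReLU(),
                      nn.Linear(D_k, D_k), nn.ReLU(),
                      nn.Linear(D_k, D_o))

# He (Kaiming) initialization, as defined in the notebook
def weights_init(layer_in):
    if isinstance(layer_in, nn.Linear):
        nn.init.kaiming_uniform_(layer_in.weight)
        layer_in.bias.data.fill_(0.0)

model.apply(weights_init)  # first call, right after the definition

# ... loss function, optimizer, DataLoader (see the diff) ...

model.apply(weights_init)  # duplicate call restored by this revert
```

Re-running initialization right before training makes the training cell self-contained: executing it always starts optimization from freshly drawn He-initialized weights, whatever was run before.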
187 changes: 95 additions & 92 deletions Notebooks/Chap09/9_5_Augmentation.ipynb
@@ -1,20 +1,32 @@
 {
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
   "cells": [
     {
       "cell_type": "markdown",
       "metadata": {
-        "colab_type": "text",
-        "id": "view-in-github"
+        "id": "view-in-github",
+        "colab_type": "text"
       },
       "source": [
         "<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap09/9_5_Augmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
       ]
     },
     {
       "cell_type": "markdown",
-      "metadata": {
-        "id": "el8l05WQEO46"
-      },
       "source": [
         "# **Notebook 9.5: Augmentation**\n",
         "\n",
@@ -23,27 +35,25 @@
         "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
         "\n",
         "Contact me at [email protected] if you find any mistakes or have any suggestions.\n"
-      ]
+      ],
+      "metadata": {
+        "id": "el8l05WQEO46"
+      }
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "syvgxgRr3myY"
-      },
-      "outputs": [],
       "source": [
         "# Run this if you're in a Colab to install MNIST 1D repository\n",
         "!pip install git+https://github.com/greydanus/mnist1d"
-      ]
+      ],
+      "metadata": {
+        "id": "syvgxgRr3myY"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "ckrNsYd13pMe"
-      },
-      "outputs": [],
       "source": [
         "import torch, torch.nn as nn\n",
         "from torch.utils.data import TensorDataset, DataLoader\n",
@@ -52,15 +62,15 @@
         "import matplotlib.pyplot as plt\n",
         "import mnist1d\n",
         "import random"
-      ]
+      ],
+      "metadata": {
+        "id": "ckrNsYd13pMe"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "D_Woo9U730lZ"
-      },
-      "outputs": [],
       "source": [
         "args = mnist1d.data.get_dataset_args()\n",
         "data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
@@ -70,15 +80,15 @@
         "print(\"Examples in training set: {}\".format(len(data['y'])))\n",
         "print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
         "print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
-      ]
+      ],
+      "metadata": {
+        "id": "D_Woo9U730lZ"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "JfIFWFIL33eF"
-      },
-      "outputs": [],
       "source": [
         "D_i = 40 # Input dimensions\n",
         "D_k = 200 # Hidden dimensions\n",
@@ -99,17 +109,17 @@
         "    nn.init.kaiming_uniform_(layer_in.weight)\n",
         "    layer_in.bias.data.fill_(0.0)\n",
         "\n",
-        "# Initialize model weights\n",
+        "# Call the function you just defined\n",
         "model.apply(weights_init)"
-      ]
+      ],
+      "metadata": {
+        "id": "JfIFWFIL33eF"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "YFfVbTPE4BkJ"
-      },
-      "outputs": [],
       "source": [
         "# choose cross entropy loss function (equation 5.24)\n",
         "loss_function = torch.nn.CrossEntropyLoss()\n",
Expand All @@ -126,6 +136,9 @@
"# load the data into a class that creates the batches\n",
"data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
"\n",
"# Initialize model weights\n",
"model.apply(weights_init)\n",
"\n",
"# loop over the dataset n_epoch times\n",
"n_epoch = 50\n",
"# store the loss and the % correct at each epoch\n",
@@ -156,15 +169,15 @@
         "  errors_train[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
         "  errors_test[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
         "  print(f'Epoch {epoch:5d}, train error {errors_train[epoch]:3.2f}, test error {errors_test[epoch]:3.2f}')"
-      ]
+      ],
+      "metadata": {
+        "id": "YFfVbTPE4BkJ"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "FmGDd4vB8LyM"
-      },
-      "outputs": [],
       "source": [
         "# Plot the results\n",
         "fig, ax = plt.subplots()\n",
@@ -175,24 +188,24 @@
         "ax.set_title('Train Error %3.2f, Test Error %3.2f'%(errors_train[-1],errors_test[-1]))\n",
         "ax.legend()\n",
         "plt.show()"
-      ]
+      ],
+      "metadata": {
+        "id": "FmGDd4vB8LyM"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "markdown",
-      "metadata": {
-        "id": "55XvoPDO8Qp-"
-      },
       "source": [
         "The best test performance is about 33%. Let's see if we can improve on that by augmenting the data."
-      ]
+      ],
+      "metadata": {
+        "id": "55XvoPDO8Qp-"
+      }
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "IP6z2iox8MOF"
-      },
-      "outputs": [],
       "source": [
         "def augment(input_vector):\n",
         "  # Create output vector\n",
@@ -208,15 +221,15 @@
         "  data_out = np.array(data_out)\n",
         "\n",
         "  return data_out"
-      ]
+      ],
+      "metadata": {
+        "id": "IP6z2iox8MOF"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "bzN0lu5J95AJ"
-      },
-      "outputs": [],
       "source": [
         "n_data_orig = data['x'].shape[0]\n",
         "# We'll double the amount of data\n",
@@ -234,15 +247,15 @@
         "  # Augment the point and store\n",
         "  augmented_x[c_augment,:] = augment(data['x'][random_data_index,:])\n",
         "  augmented_y[c_augment] = data['y'][random_data_index]\n"
-      ]
+      ],
+      "metadata": {
+        "id": "bzN0lu5J95AJ"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "hZUNrXpS_kRs"
-      },
-      "outputs": [],
       "source": [
         "# choose cross entropy loss function (equation 5.24)\n",
         "loss_function = torch.nn.CrossEntropyLoss()\n",
@@ -292,15 +305,15 @@
         "  errors_train_aug[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
         "  errors_test_aug[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
         "  print(f'Epoch {epoch:5d}, train error {errors_train_aug[epoch]:3.2f}, test error {errors_test_aug[epoch]:3.2f}')"
-      ]
+      ],
+      "metadata": {
+        "id": "hZUNrXpS_kRs"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "IcnAW4ixBnuc"
-      },
-      "outputs": [],
       "source": [
         "# Plot the results\n",
         "fig, ax = plt.subplots()\n",
@@ -312,31 +325,21 @@
         "ax.set_title('TrainError %3.2f, Test Error %3.2f'%(errors_train_aug[-1],errors_test_aug[-1]))\n",
         "ax.legend()\n",
         "plt.show()"
-      ]
+      ],
+      "metadata": {
+        "id": "IcnAW4ixBnuc"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "markdown",
-      "metadata": {
-        "id": "jgsR7ScJHc9b"
-      },
       "source": [
         "Hopefully, you should see an improvement in performance when we augment the data."
-      ]
-    }
-  ],
-  "metadata": {
-    "colab": {
-      "include_colab_link": true,
-      "provenance": []
-    },
-    "kernelspec": {
-      "display_name": "Python 3",
-      "name": "python3"
-    },
-    "language_info": {
-      "name": "python"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 0
-}
+      ],
+      "metadata": {
+        "id": "jgsR7ScJHc9b"
+      }
+    }
+  ]
+}
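
Note: the `augment` function and the data-doubling cell are largely collapsed in the diff above. As a reading aid, here is a sketch of the pattern the visible lines imply. The dataset loading and the assignment lines inside the loop are taken verbatim from the diff; the body of `augment` is not shown, so the circular-shift transform below is purely hypothetical, as are the helper name `n_data_aug` and the exact loop header and array setup.

```python
import random
import numpy as np
import mnist1d

# Load MNIST-1D exactly as the notebook does (visible in the diff above);
# assumes the pickle is present, otherwise change download to True.
args = mnist1d.data.get_dataset_args()
data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)

# Hypothetical augment(): the notebook's actual body is collapsed in this
# view, so this is only an illustration of a label-preserving 1-D transform
# that fits the visible skeleton (build data_out, np.array(...), return it).
def augment(input_vector):
    n = len(input_vector)
    shift = random.randint(-4, 4)  # assumed: small random circular shift
    data_out = [input_vector[(i + shift) % n] for i in range(n)]
    data_out = np.array(data_out)
    return data_out

# Double the dataset by augmenting randomly chosen training examples;
# loop skeleton reconstructed around the visible lines, sizes assumed.
n_data_orig = data['x'].shape[0]
n_data_aug = 2 * n_data_orig
augmented_x = np.zeros((n_data_aug, data['x'].shape[1]))
augmented_y = np.zeros(n_data_aug)
for c_augment in range(n_data_aug):
    random_data_index = random.randint(0, n_data_orig - 1)
    # Augment the point and store
    augmented_x[c_augment, :] = augment(data['x'][random_data_index, :])
    augmented_y[c_augment] = data['y'][random_data_index]
```

Any label-preserving distortion would fit the same skeleton; the notebook then retrains the same model on the doubled dataset and compares test error against the unaugmented baseline quoted above.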
