diff --git a/Notebooks/Chap08/8_3_Double_Descent.ipynb b/Notebooks/Chap08/8_3_Double_Descent.ipynb
index f60f12f1..aaac87e5 100644
--- a/Notebooks/Chap08/8_3_Double_Descent.ipynb
+++ b/Notebooks/Chap08/8_3_Double_Descent.ipynb
@@ -4,8 +4,7 @@
"metadata": {
"colab": {
"provenance": [],
- "gpuType": "T4",
- "include_colab_link": true
+ "gpuType": "T4"
},
"kernelspec": {
"name": "python3",
@@ -17,16 +16,6 @@
"accelerator": "GPU"
},
"cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
{
"cell_type": "markdown",
"source": [
@@ -51,10 +40,45 @@
"!pip install git+https://github.com/greydanus/mnist1d"
],
"metadata": {
- "id": "fn9BP5N5TguP"
+ "id": "fn9BP5N5TguP",
+ "outputId": "3ba15b8f-2395-4b1a-8a66-8ed80e9b5138",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/greydanus/mnist1d\n",
+ " Cloning https://github.com/greydanus/mnist1d to /tmp/pip-req-build-cbhxd1j8\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/greydanus/mnist1d /tmp/pip-req-build-cbhxd1j8\n",
+ " Resolved https://github.com/greydanus/mnist1d to commit 350929d12f4c9a4b7355e0c96604e41b9239bdb4\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from mnist1d==0.0.2.post9) (2.31.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from mnist1d==0.0.2.post9) (1.25.2)\n",
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from mnist1d==0.0.2.post9) (3.7.1)\n",
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from mnist1d==0.0.2.post9) (1.11.4)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (1.2.1)\n",
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (4.53.0)\n",
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (1.4.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (24.1)\n",
+ "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (9.4.0)\n",
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (3.1.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mnist1d==0.0.2.post9) (2.8.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->mnist1d==0.0.2.post9) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->mnist1d==0.0.2.post9) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->mnist1d==0.0.2.post9) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->mnist1d==0.0.2.post9) (2024.6.2)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->mnist1d==0.0.2.post9) (1.16.0)\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -73,10 +97,22 @@
"print('Using:', DEVICE)"
],
"metadata": {
- "id": "hFxuHpRqTgri"
+ "id": "hFxuHpRqTgri",
+ "outputId": "2a246d1e-8568-4605-df0a-28dad08935ec",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 30,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Using: cuda\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -99,13 +135,28 @@
"# data['x'], data['y'], data['x_test'], and data['y_test']\n",
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
- "print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
+ "print(\"Dimensionality of each example: {}\".format(data['x'].shape[-1]))"
],
"metadata": {
- "id": "PW2gyXL5UkLU"
+ "id": "PW2gyXL5UkLU",
+ "outputId": "a4186965-3025-4d00-b780-441baab421de",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Did or could not load data from ./mnist1d_data.pkl. Rebuilding dataset...\n",
+ "Examples in training set: 4000\n",
+ "Examples in test set: 4000\n",
+ "Dimensionality of each example: 40\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -141,13 +192,13 @@
"metadata": {
"id": "hAIvZOAlTnk9"
},
- "execution_count": null,
+ "execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
- "def fit_model(model, data):\n",
+ "def fit_model(model, data, n_epoch):\n",
"\n",
" # choose cross entropy loss function (equation 5.24)\n",
" loss_function = torch.nn.CrossEntropyLoss()\n",
@@ -164,9 +215,6 @@
" # load the data into a class that creates the batches\n",
" data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
"\n",
- " # loop over the dataset n_epoch times\n",
- " n_epoch = 1000\n",
- "\n",
" for epoch in range(n_epoch):\n",
" # loop over batches\n",
" for i, batch in enumerate(data_loader):\n",
@@ -200,7 +248,19 @@
"metadata": {
"id": "AazlQhheWmHk"
},
- "execution_count": null,
+ "execution_count": 33,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def count_parameters(model):\n",
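+      "  # count only the trainable parameters; this total is later compared with the number of training examples (N(weights) vs N(train))\n",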
+ " return sum(p.numel() for p in model.parameters() if p.requires_grad)"
+ ],
+ "metadata": {
+ "id": "AQNCmFNV6JpV"
+ },
+ "execution_count": 34,
"outputs": []
},
{
@@ -226,42 +286,333 @@
"# This code will take a while (~30 mins on GPU) to run! Go and make a cup of coffee!\n",
"\n",
"hidden_variables = np.array([2,4,6,8,10,14,18,22,26,30,35,40,45,50,55,60,70,80,90,100,120,140,160,180,200,250,300,400]) ;\n",
+ "\n",
"errors_train_all = np.zeros_like(hidden_variables)\n",
"errors_test_all = np.zeros_like(hidden_variables)\n",
+ "total_weights_all = np.zeros_like(hidden_variables)\n",
+ "\n",
+ "# loop over the dataset n_epoch times\n",
+ "n_epoch = 1000\n",
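+        "# (defined once here and passed to fit_model so every model size trains for the same number of epochs)\n",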
"\n",
"# For each hidden variable size\n",
"for c_hidden in range(len(hidden_variables)):\n",
" print(f'Training model with {hidden_variables[c_hidden]:3d} hidden variables')\n",
" # Get a model\n",
" model = get_model(hidden_variables[c_hidden]) ;\n",
+ " # Count and store number of weights\n",
+ " total_weights_all[c_hidden] = count_parameters(model)\n",
" # Train the model\n",
- " errors_train, errors_test = fit_model(model, data)\n",
+ " errors_train, errors_test = fit_model(model, data, n_epoch)\n",
" # Store the results\n",
" errors_train_all[c_hidden] = errors_train\n",
- " errors_test_all[c_hidden]= errors_test"
+        "  errors_test_all[c_hidden] = errors_test"
],
"metadata": {
- "id": "K4OmBZGHWXpk"
+ "id": "K4OmBZGHWXpk",
+ "outputId": "62a7aaf0-793a-4ab2-960e-c1127f15717c",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 35,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Training model with 2 hidden variables\n",
+ "Epoch 0, train loss 2.280202, train error 89.05, test loss 2.276196, test error 89.18\n",
+ "Epoch 100, train loss 1.925899, train error 71.75, test loss 1.761014, test error 70.57\n",
+ "Epoch 200, train loss 1.912995, train error 71.20, test loss 1.747572, test error 69.07\n",
+ "Epoch 300, train loss 1.908751, train error 71.32, test loss 1.741206, test error 69.07\n",
+ "Epoch 400, train loss 1.906859, train error 71.05, test loss 1.737478, test error 70.00\n",
+ "Epoch 500, train loss 1.906671, train error 71.43, test loss 1.732123, test error 68.95\n",
+ "Epoch 600, train loss 1.903380, train error 71.28, test loss 1.741086, test error 69.65\n",
+ "Epoch 700, train loss 1.902236, train error 70.88, test loss 1.733666, test error 69.22\n",
+ "Epoch 800, train loss 1.903633, train error 71.88, test loss 1.734591, test error 70.57\n",
+ "Epoch 900, train loss 1.901804, train error 71.03, test loss 1.742573, test error 68.10\n",
+ "Training model with 4 hidden variables\n",
+ "Epoch 0, train loss 2.267358, train error 86.97, test loss 2.269902, test error 87.62\n",
+ "Epoch 100, train loss 1.716064, train error 62.22, test loss 1.519437, test error 58.25\n",
+ "Epoch 200, train loss 1.705652, train error 61.88, test loss 1.507422, test error 58.40\n",
+ "Epoch 300, train loss 1.696660, train error 60.62, test loss 1.507704, test error 57.58\n",
+ "Epoch 400, train loss 1.691092, train error 61.60, test loss 1.499592, test error 56.85\n",
+ "Epoch 500, train loss 1.684060, train error 61.05, test loss 1.513196, test error 57.95\n",
+ "Epoch 600, train loss 1.677890, train error 61.10, test loss 1.504343, test error 56.85\n",
+ "Epoch 700, train loss 1.677930, train error 61.10, test loss 1.509796, test error 58.15\n",
+ "Epoch 800, train loss 1.676683, train error 61.17, test loss 1.506790, test error 58.28\n",
+ "Epoch 900, train loss 1.676809, train error 61.53, test loss 1.505417, test error 57.72\n",
+ "Training model with 6 hidden variables\n",
+ "Epoch 0, train loss 2.290775, train error 89.18, test loss 2.284365, test error 87.57\n",
+ "Epoch 100, train loss 1.634348, train error 58.42, test loss 1.467763, test error 55.67\n",
+ "Epoch 200, train loss 1.602023, train error 57.25, test loss 1.456007, test error 55.38\n",
+ "Epoch 300, train loss 1.587963, train error 56.50, test loss 1.453522, test error 55.33\n",
+ "Epoch 400, train loss 1.587952, train error 56.55, test loss 1.453863, test error 54.92\n",
+ "Epoch 500, train loss 1.579495, train error 56.42, test loss 1.452398, test error 55.40\n",
+ "Epoch 600, train loss 1.590881, train error 56.67, test loss 1.473001, test error 56.28\n",
+ "Epoch 700, train loss 1.581218, train error 56.58, test loss 1.470745, test error 55.90\n",
+ "Epoch 800, train loss 1.575105, train error 56.30, test loss 1.455353, test error 55.92\n",
+ "Epoch 900, train loss 1.579539, train error 56.15, test loss 1.449462, test error 55.67\n",
+ "Training model with 8 hidden variables\n",
+ "Epoch 0, train loss 2.278947, train error 85.18, test loss 2.266557, test error 84.68\n",
+ "Epoch 100, train loss 1.593652, train error 55.60, test loss 1.449531, test error 56.00\n",
+ "Epoch 200, train loss 1.557044, train error 53.92, test loss 1.453447, test error 55.03\n",
+ "Epoch 300, train loss 1.544829, train error 53.38, test loss 1.457726, test error 55.22\n",
+ "Epoch 400, train loss 1.538902, train error 54.05, test loss 1.456941, test error 55.33\n",
+ "Epoch 500, train loss 1.538077, train error 54.28, test loss 1.469433, test error 56.20\n",
+ "Epoch 600, train loss 1.529778, train error 53.60, test loss 1.470659, test error 55.72\n",
+ "Epoch 700, train loss 1.522470, train error 53.42, test loss 1.476947, test error 55.95\n",
+ "Epoch 800, train loss 1.520377, train error 53.22, test loss 1.481796, test error 56.55\n",
+ "Epoch 900, train loss 1.522042, train error 53.42, test loss 1.489359, test error 56.15\n",
+ "Training model with 10 hidden variables\n",
+ "Epoch 0, train loss 2.250067, train error 83.80, test loss 2.238750, test error 81.28\n",
+ "Epoch 100, train loss 1.570056, train error 55.80, test loss 1.444564, test error 54.62\n",
+ "Epoch 200, train loss 1.512308, train error 53.33, test loss 1.417375, test error 53.70\n",
+ "Epoch 300, train loss 1.485546, train error 52.62, test loss 1.415674, test error 53.62\n",
+ "Epoch 400, train loss 1.465965, train error 50.22, test loss 1.415631, test error 52.95\n",
+ "Epoch 500, train loss 1.450910, train error 50.50, test loss 1.412722, test error 52.70\n",
+ "Epoch 600, train loss 1.437163, train error 49.65, test loss 1.409063, test error 52.92\n",
+ "Epoch 700, train loss 1.430005, train error 49.38, test loss 1.402299, test error 52.83\n",
+ "Epoch 800, train loss 1.420938, train error 49.17, test loss 1.398151, test error 52.22\n",
+ "Epoch 900, train loss 1.417372, train error 49.25, test loss 1.407553, test error 52.58\n",
+ "Training model with 14 hidden variables\n",
+ "Epoch 0, train loss 2.250147, train error 84.45, test loss 2.224496, test error 83.50\n",
+ "Epoch 100, train loss 1.467325, train error 49.22, test loss 1.371005, test error 51.35\n",
+ "Epoch 200, train loss 1.371697, train error 46.47, test loss 1.341893, test error 49.60\n",
+ "Epoch 300, train loss 1.338310, train error 45.25, test loss 1.345958, test error 49.47\n",
+ "Epoch 400, train loss 1.329984, train error 44.95, test loss 1.364456, test error 49.85\n",
+ "Epoch 500, train loss 1.312067, train error 44.70, test loss 1.368620, test error 49.67\n",
+ "Epoch 600, train loss 1.307369, train error 44.00, test loss 1.383248, test error 49.97\n",
+ "Epoch 700, train loss 1.302358, train error 44.10, test loss 1.384645, test error 49.78\n",
+ "Epoch 800, train loss 1.309552, train error 44.42, test loss 1.399194, test error 50.60\n",
+ "Epoch 900, train loss 1.305113, train error 43.70, test loss 1.416064, test error 51.28\n",
+ "Training model with 18 hidden variables\n",
+ "Epoch 0, train loss 2.159715, train error 80.35, test loss 2.139665, test error 80.20\n",
+ "Epoch 100, train loss 1.385202, train error 47.75, test loss 1.362450, test error 51.58\n",
+ "Epoch 200, train loss 1.265389, train error 43.55, test loss 1.335479, test error 49.20\n",
+ "Epoch 300, train loss 1.224858, train error 41.67, test loss 1.384820, test error 48.47\n",
+ "Epoch 400, train loss 1.186068, train error 40.70, test loss 1.428581, test error 49.33\n",
+ "Epoch 500, train loss 1.167748, train error 39.30, test loss 1.472766, test error 49.95\n",
+ "Epoch 600, train loss 1.160923, train error 38.45, test loss 1.499211, test error 50.65\n",
+ "Epoch 700, train loss 1.145295, train error 38.40, test loss 1.528852, test error 50.67\n",
+ "Epoch 800, train loss 1.141608, train error 38.08, test loss 1.569496, test error 51.33\n",
+ "Epoch 900, train loss 1.132156, train error 38.22, test loss 1.602549, test error 51.35\n",
+ "Training model with 22 hidden variables\n",
+ "Epoch 0, train loss 2.211529, train error 79.80, test loss 2.206877, test error 79.95\n",
+ "Epoch 100, train loss 1.282252, train error 42.90, test loss 1.323739, test error 48.38\n",
+ "Epoch 200, train loss 1.161285, train error 38.65, test loss 1.423607, test error 50.78\n",
+ "Epoch 300, train loss 1.103865, train error 36.65, test loss 1.482603, test error 50.70\n",
+ "Epoch 400, train loss 1.070859, train error 36.42, test loss 1.543965, test error 51.38\n",
+ "Epoch 500, train loss 1.044615, train error 34.82, test loss 1.574064, test error 51.38\n",
+ "Epoch 600, train loss 1.030322, train error 34.82, test loss 1.633871, test error 51.47\n",
+ "Epoch 700, train loss 1.016668, train error 33.93, test loss 1.711600, test error 52.30\n",
+ "Epoch 800, train loss 0.997430, train error 32.82, test loss 1.763077, test error 52.95\n",
+ "Epoch 900, train loss 1.001320, train error 33.82, test loss 1.814692, test error 52.95\n",
+ "Training model with 26 hidden variables\n",
+ "Epoch 0, train loss 2.200904, train error 81.43, test loss 2.190686, test error 80.72\n",
+ "Epoch 100, train loss 1.189043, train error 39.80, test loss 1.398701, test error 51.97\n",
+ "Epoch 200, train loss 1.032893, train error 34.72, test loss 1.504426, test error 52.15\n",
+ "Epoch 300, train loss 0.975272, train error 33.55, test loss 1.621815, test error 53.30\n",
+ "Epoch 400, train loss 0.939880, train error 32.05, test loss 1.739801, test error 54.03\n",
+ "Epoch 500, train loss 0.907610, train error 30.85, test loss 1.822161, test error 54.30\n",
+ "Epoch 600, train loss 0.893433, train error 30.55, test loss 1.929679, test error 54.80\n",
+ "Epoch 700, train loss 0.864689, train error 29.20, test loss 1.968643, test error 54.75\n",
+ "Epoch 800, train loss 0.847545, train error 28.82, test loss 2.045258, test error 54.97\n",
+ "Epoch 900, train loss 0.843597, train error 28.97, test loss 2.113636, test error 55.55\n",
+ "Training model with 30 hidden variables\n",
+ "Epoch 0, train loss 2.105095, train error 78.75, test loss 2.046595, test error 77.50\n",
+ "Epoch 100, train loss 1.139368, train error 36.85, test loss 1.384092, test error 50.33\n",
+ "Epoch 200, train loss 0.948281, train error 31.70, test loss 1.521346, test error 50.53\n",
+ "Epoch 300, train loss 0.866517, train error 28.10, test loss 1.689052, test error 51.30\n",
+ "Epoch 400, train loss 0.822552, train error 26.95, test loss 1.871116, test error 53.58\n",
+ "Epoch 500, train loss 0.793155, train error 26.53, test loss 2.016083, test error 53.45\n",
+ "Epoch 600, train loss 0.747390, train error 25.15, test loss 2.175583, test error 54.60\n",
+ "Epoch 700, train loss 0.739243, train error 24.82, test loss 2.375611, test error 55.90\n",
+ "Epoch 800, train loss 0.705187, train error 23.97, test loss 2.521607, test error 56.12\n",
+ "Epoch 900, train loss 0.678345, train error 22.30, test loss 2.623145, test error 56.05\n",
+ "Training model with 35 hidden variables\n",
+ "Epoch 0, train loss 2.133684, train error 77.95, test loss 2.105965, test error 78.03\n",
+ "Epoch 100, train loss 1.025999, train error 34.07, test loss 1.387318, test error 49.42\n",
+ "Epoch 200, train loss 0.801432, train error 27.22, test loss 1.734866, test error 52.33\n",
+ "Epoch 300, train loss 0.665120, train error 22.30, test loss 2.111001, test error 54.28\n",
+ "Epoch 400, train loss 0.603412, train error 20.30, test loss 2.516824, test error 55.53\n",
+ "Epoch 500, train loss 0.557454, train error 19.50, test loss 2.875285, test error 55.90\n",
+ "Epoch 600, train loss 0.541823, train error 18.70, test loss 3.230060, test error 56.78\n",
+ "Epoch 700, train loss 0.489021, train error 16.43, test loss 3.542463, test error 56.67\n",
+ "Epoch 800, train loss 0.487397, train error 16.50, test loss 3.868127, test error 56.67\n",
+ "Epoch 900, train loss 0.452621, train error 15.62, test loss 4.107164, test error 57.90\n",
+ "Training model with 40 hidden variables\n",
+ "Epoch 0, train loss 2.131664, train error 78.45, test loss 2.104819, test error 79.57\n",
+ "Epoch 100, train loss 0.916092, train error 29.90, test loss 1.443106, test error 50.62\n",
+ "Epoch 200, train loss 0.637736, train error 20.70, test loss 1.931214, test error 52.22\n",
+ "Epoch 300, train loss 0.477490, train error 15.45, test loss 2.592646, test error 54.35\n",
+ "Epoch 400, train loss 0.375117, train error 11.97, test loss 3.317954, test error 55.75\n",
+ "Epoch 500, train loss 0.320583, train error 10.82, test loss 4.105999, test error 56.25\n",
+ "Epoch 600, train loss 0.282808, train error 9.62, test loss 4.787386, test error 56.17\n",
+ "Epoch 700, train loss 0.270115, train error 9.78, test loss 5.615366, test error 57.00\n",
+ "Epoch 800, train loss 0.271310, train error 10.05, test loss 6.246390, test error 56.17\n",
+ "Epoch 900, train loss 0.298141, train error 11.35, test loss 6.893893, test error 56.62\n",
+ "Training model with 45 hidden variables\n",
+ "Epoch 0, train loss 2.124081, train error 78.05, test loss 2.105962, test error 77.93\n",
+ "Epoch 100, train loss 0.826444, train error 26.68, test loss 1.503941, test error 49.47\n",
+ "Epoch 200, train loss 0.488764, train error 15.75, test loss 2.213665, test error 52.97\n",
+ "Epoch 300, train loss 0.339090, train error 11.00, test loss 3.221352, test error 54.25\n",
+ "Epoch 400, train loss 0.239335, train error 8.18, test loss 4.424680, test error 55.75\n",
+ "Epoch 500, train loss 0.163451, train error 4.93, test loss 5.617220, test error 55.38\n",
+ "Epoch 600, train loss 0.099984, train error 3.05, test loss 6.670816, test error 55.85\n",
+ "Epoch 700, train loss 0.015253, train error 0.00, test loss 7.652980, test error 55.70\n",
+ "Epoch 800, train loss 0.008259, train error 0.00, test loss 8.460521, test error 55.85\n",
+ "Epoch 900, train loss 0.005923, train error 0.00, test loss 8.991735, test error 55.92\n",
+ "Training model with 50 hidden variables\n",
+ "Epoch 0, train loss 2.164114, train error 78.28, test loss 2.153555, test error 78.55\n",
+ "Epoch 100, train loss 0.760776, train error 23.62, test loss 1.521621, test error 49.42\n",
+ "Epoch 200, train loss 0.388589, train error 11.07, test loss 2.433003, test error 52.42\n",
+ "Epoch 300, train loss 0.212128, train error 5.60, test loss 3.755425, test error 54.70\n",
+ "Epoch 400, train loss 0.101783, train error 2.12, test loss 5.258374, test error 55.00\n",
+ "Epoch 500, train loss 0.024106, train error 0.03, test loss 6.485045, test error 56.25\n",
+ "Epoch 600, train loss 0.010166, train error 0.00, test loss 7.247528, test error 56.03\n",
+ "Epoch 700, train loss 0.007013, train error 0.00, test loss 7.764050, test error 56.17\n",
+ "Epoch 800, train loss 0.005290, train error 0.00, test loss 8.180995, test error 56.25\n",
+ "Epoch 900, train loss 0.004218, train error 0.00, test loss 8.507399, test error 56.30\n",
+ "Training model with 60 hidden variables\n",
+ "Epoch 0, train loss 2.054841, train error 76.82, test loss 2.010036, test error 77.35\n",
+ "Epoch 100, train loss 0.533140, train error 15.78, test loss 1.685822, test error 50.55\n",
+ "Epoch 200, train loss 0.092306, train error 0.62, test loss 3.330975, test error 53.12\n",
+ "Epoch 300, train loss 0.018351, train error 0.00, test loss 4.631015, test error 54.05\n",
+ "Epoch 400, train loss 0.008906, train error 0.00, test loss 5.276160, test error 54.12\n",
+ "Epoch 500, train loss 0.005695, train error 0.00, test loss 5.702274, test error 54.15\n",
+ "Epoch 600, train loss 0.004078, train error 0.00, test loss 6.026574, test error 54.28\n",
+ "Epoch 700, train loss 0.003147, train error 0.00, test loss 6.281788, test error 54.22\n",
+ "Epoch 800, train loss 0.002537, train error 0.00, test loss 6.493565, test error 54.30\n",
+ "Epoch 900, train loss 0.002114, train error 0.00, test loss 6.669685, test error 54.45\n",
+ "Training model with 80 hidden variables\n",
+ "Epoch 0, train loss 2.028571, train error 74.40, test loss 1.985185, test error 75.28\n",
+ "Epoch 100, train loss 0.244489, train error 4.47, test loss 2.115000, test error 51.50\n",
+ "Epoch 200, train loss 0.020814, train error 0.00, test loss 3.415434, test error 51.40\n",
+ "Epoch 300, train loss 0.008094, train error 0.00, test loss 3.970142, test error 51.72\n",
+ "Epoch 400, train loss 0.004785, train error 0.00, test loss 4.299542, test error 51.83\n",
+ "Epoch 500, train loss 0.003310, train error 0.00, test loss 4.534029, test error 51.70\n",
+ "Epoch 600, train loss 0.002490, train error 0.00, test loss 4.716822, test error 51.78\n",
+ "Epoch 700, train loss 0.001978, train error 0.00, test loss 4.864742, test error 51.75\n",
+ "Epoch 800, train loss 0.001629, train error 0.00, test loss 4.988203, test error 51.80\n",
+ "Epoch 900, train loss 0.001378, train error 0.00, test loss 5.096081, test error 51.78\n",
+ "Training model with 120 hidden variables\n",
+ "Epoch 0, train loss 1.962679, train error 72.30, test loss 1.879106, test error 71.68\n",
+ "Epoch 100, train loss 0.048203, train error 0.00, test loss 2.235627, test error 49.70\n",
+ "Epoch 200, train loss 0.009333, train error 0.00, test loss 2.851372, test error 49.85\n",
+ "Epoch 300, train loss 0.004629, train error 0.00, test loss 3.129520, test error 49.95\n",
+ "Epoch 400, train loss 0.002970, train error 0.00, test loss 3.309089, test error 49.83\n",
+ "Epoch 500, train loss 0.002143, train error 0.00, test loss 3.442119, test error 49.92\n",
+ "Epoch 600, train loss 0.001658, train error 0.00, test loss 3.548133, test error 49.80\n",
+ "Epoch 700, train loss 0.001341, train error 0.00, test loss 3.635033, test error 49.72\n",
+ "Epoch 800, train loss 0.001120, train error 0.00, test loss 3.709674, test error 49.72\n",
+ "Epoch 900, train loss 0.000957, train error 0.00, test loss 3.774692, test error 49.65\n",
+ "Training model with 200 hidden variables\n",
+ "Epoch 0, train loss 1.913651, train error 70.00, test loss 1.842174, test error 72.20\n",
+ "Epoch 100, train loss 0.021130, train error 0.00, test loss 1.918857, test error 46.60\n",
+ "Epoch 200, train loss 0.006089, train error 0.00, test loss 2.209828, test error 46.88\n",
+ "Epoch 300, train loss 0.003314, train error 0.00, test loss 2.356240, test error 46.92\n",
+ "Epoch 400, train loss 0.002209, train error 0.00, test loss 2.455862, test error 46.85\n",
+ "Epoch 500, train loss 0.001633, train error 0.00, test loss 2.532098, test error 46.70\n",
+ "Epoch 600, train loss 0.001284, train error 0.00, test loss 2.592178, test error 46.67\n",
+ "Epoch 700, train loss 0.001051, train error 0.00, test loss 2.644237, test error 46.80\n",
+ "Epoch 800, train loss 0.000885, train error 0.00, test loss 2.687477, test error 46.78\n",
+ "Epoch 900, train loss 0.000762, train error 0.00, test loss 2.726535, test error 46.72\n",
+ "Training model with 360 hidden variables\n",
+ "Epoch 0, train loss 1.881302, train error 67.10, test loss 1.826783, test error 69.07\n",
+ "Epoch 100, train loss 0.012637, train error 0.00, test loss 1.789455, test error 46.08\n",
+ "Epoch 200, train loss 0.004398, train error 0.00, test loss 1.978857, test error 45.70\n",
+ "Epoch 300, train loss 0.002526, train error 0.00, test loss 2.081702, test error 45.45\n",
+ "Epoch 400, train loss 0.001732, train error 0.00, test loss 2.153348, test error 45.22\n",
+ "Epoch 500, train loss 0.001301, train error 0.00, test loss 2.208247, test error 45.28\n",
+ "Epoch 600, train loss 0.001034, train error 0.00, test loss 2.252916, test error 45.28\n",
+ "Epoch 700, train loss 0.000853, train error 0.00, test loss 2.290226, test error 45.40\n",
+ "Epoch 800, train loss 0.000724, train error 0.00, test loss 2.322630, test error 45.42\n",
+ "Epoch 900, train loss 0.000626, train error 0.00, test loss 2.351381, test error 45.33\n",
+ "Training model with 680 hidden variables\n",
+ "Epoch 0, train loss 1.808186, train error 63.28, test loss 1.783419, test error 66.82\n",
+ "Epoch 100, train loss 0.008044, train error 0.00, test loss 1.678780, test error 44.55\n",
+ "Epoch 200, train loss 0.003155, train error 0.00, test loss 1.810078, test error 44.47\n",
+ "Epoch 300, train loss 0.001885, train error 0.00, test loss 1.886664, test error 44.42\n",
+ "Epoch 400, train loss 0.001322, train error 0.00, test loss 1.939946, test error 44.35\n",
+ "Epoch 500, train loss 0.001008, train error 0.00, test loss 1.982137, test error 44.45\n",
+ "Epoch 600, train loss 0.000809, train error 0.00, test loss 2.015580, test error 44.45\n",
+ "Epoch 700, train loss 0.000673, train error 0.00, test loss 2.044045, test error 44.28\n",
+ "Epoch 800, train loss 0.000574, train error 0.00, test loss 2.068916, test error 44.28\n",
+ "Epoch 900, train loss 0.000499, train error 0.00, test loss 2.090607, test error 44.28\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
"source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+        "# Number of examples in the training set\n",
+ "num_training_examples = len(data['y'])\n",
+ "\n",
+ "# Find the index where total_weights_all is closest to num_training_examples\n",
+ "closest_index = np.argmin(np.abs(np.array(total_weights_all) - num_training_examples))\n",
+ "\n",
+ "# Get the corresponding value of hidden variables\n",
+ "hidden_variable_at_num_training_examples = hidden_variables[closest_index]\n",
+ "\n",
"# Plot the results\n",
"fig, ax = plt.subplots()\n",
- "ax.plot(hidden_variables, errors_train_all,'r-',label='train')\n",
- "ax.plot(hidden_variables, errors_test_all,'b-',label='test')\n",
- "ax.set_ylim(0,100);\n",
- "ax.set_xlabel('No hidden variables'); ax.set_ylabel('Error')\n",
+ "ax.plot(hidden_variables, errors_train_all, 'r-', label='train')\n",
+ "ax.plot(hidden_variables, errors_test_all, 'b-', label='test')\n",
+ "\n",
+ "# Add a vertical line at the point where total weights equal the number of training examples\n",
+ "ax.axvline(x=hidden_variable_at_num_training_examples, color='g', linestyle='--', label='N(weights) = N(train)')\n",
+ "\n",
+ "ax.set_ylim(0, 100)\n",
+ "ax.set_xlabel('No. hidden variables')\n",
+ "ax.set_ylabel('Error')\n",
"ax.legend()\n",
"plt.show()\n"
],
"metadata": {
- "id": "Rw-iRboTXbck"
+ "id": "Rw-iRboTXbck",
+ "outputId": "4b2ec111-9a74-4e10-821c-6f5129c4e109",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 455
+ }
+ },
+ "execution_count": 36,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "