forked from udlbook/udlbook
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyNyLnpoXgKN+RGCuTUszCAZ", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# **Notebook 15.2: Wassserstein Distance**\n", | ||
"\n", | ||
"This notebook investigates the GAN toy example as illustred in figure 15.1 in the book.\n", | ||
"\n", | ||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", | ||
"\n", | ||
"Contact me at [email protected] if you find any mistakes or have any suggestions." | ||
], | ||
"metadata": { | ||
"id": "t9vk9Elugvmi" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib import cm\n", | ||
"from matplotlib.colors import ListedColormap\n", | ||
"from scipy.optimize import linprog" | ||
], | ||
"metadata": { | ||
"id": "OLComQyvCIJ7" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define two probability distributions\n", | ||
"p = np.array([5, 3, 2, 1, 8, 7, 5, 9, 2, 1])\n", | ||
"q = np.array([4, 10,1, 1, 4, 6, 3, 2, 0, 1])\n", | ||
"p = p/np.sum(p);\n", | ||
"q= q/np.sum(q);\n", | ||
"\n", | ||
"# Draw those distributions\n", | ||
"fig, ax =plt.subplots(2,1);\n", | ||
"x = np.arange(0,p.size,1)\n", | ||
"ax[0].bar(x,p, color=\"#cccccc\")\n", | ||
"ax[0].set_ylim([0,0.35])\n", | ||
"ax[0].set_ylabel(\"p(x=i)\")\n", | ||
"\n", | ||
"ax[1].bar(x,q,color=\"#f47a60\")\n", | ||
"ax[1].set_ylim([0,0.35])\n", | ||
"ax[1].set_ylabel(\"q(x=j)\")\n", | ||
"plt.show()" | ||
], | ||
"metadata": { | ||
"id": "ZIfQwhd-AV6L" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# TODO Define the distance matrix from figure 15.8d\n", | ||
"# Replace this line\n", | ||
"dist_mat = np.zeros((10,10))\n", | ||
"\n", | ||
"# vectorize the distance matrix\n", | ||
"c = dist_mat.flatten()" | ||
], | ||
"metadata": { | ||
"id": "EZSlZQzWBKTm" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define pretty colormap\n", | ||
"my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", | ||
"my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n", | ||
"r = np.floor(my_colormap_vals_dec/(256*256))\n", | ||
"g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n", | ||
"b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", | ||
"my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n", | ||
"\n", | ||
"def draw_2D_heatmap(data, title, my_colormap):\n", | ||
" # Make grid of intercept/slope values to plot\n", | ||
" xv, yv = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 1, 10))\n", | ||
" fig,ax = plt.subplots()\n", | ||
" fig.set_size_inches(4,4)\n", | ||
" plt.imshow(data, cmap=my_colormap)\n", | ||
" ax.set_title(title)\n", | ||
" ax.set_xlabel('$q$'); ax.set_ylabel('$p$')\n", | ||
" plt.show()" | ||
], | ||
"metadata": { | ||
"id": "ABRANmp6F8iQ" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"draw_2D_heatmap(dist_mat,'Distance $|i-j|$', my_colormap)" | ||
], | ||
"metadata": { | ||
"id": "G0HFPBXyHT6V" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define b to be the verticalconcatenation of p and q\n", | ||
"b = np.hstack((p,q))[np.newaxis].transpose()" | ||
], | ||
"metadata": { | ||
"id": "SfqeT3KlHWrt" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# TODO: Now construct the matrix A that has the initial distribution constraints\n", | ||
"# so that Ap=b where p is the transport plan P vectorized rows first so p = np.flatten(P)\n", | ||
"# Replace this line:\n", | ||
"A = np.zeros((20,100))\n" | ||
], | ||
"metadata": { | ||
"id": "7KrybL96IuNW" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Now we have all of the things we need. The vectorized distance matrix $\\mathbf{c}$, the constraint matrix $\\mathbf{A}$, the vectorized and concatenated original distribution $\\mathbf{b}$. We can run the linear programming optimization." | ||
], | ||
"metadata": { | ||
"id": "zEuEtU33S8Ly" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# We don't need the constraint that p>0 as this is the default\n", | ||
"opt = linprog(c, A_eq=A, b_eq=b)\n", | ||
"print(opt)" | ||
], | ||
"metadata": { | ||
"id": "wCfsOVbeSmF5" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Extract the answer and display" | ||
], | ||
"metadata": { | ||
"id": "vpkkOOI2agyl" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"P = np.array(opt.x).reshape(10,10)\n", | ||
"draw_2D_heatmap(P,'Transport plan $\\mathbf{P}$', my_colormap)" | ||
], | ||
"metadata": { | ||
"id": "nZGfkrbRV_D0" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Compute the Wasserstein distance\n" | ||
], | ||
"metadata": { | ||
"id": "ZEiRYRVgalsJ" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"was = np.sum(P * dist_mat)\n", | ||
"print(\"Wasserstein distance = \", was)" | ||
], | ||
"metadata": { | ||
"id": "yiQ_8j-Raq3c" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"TODO -- Compute the\n", | ||
"\n", | ||
"* Forward KL divergence $D_{KL}[p,q]$ between these distributions\n", | ||
"* Reverse KL divergence $D_{KL}[q,p]$ between these distributions\n", | ||
"* Jensen-Shannon divergence $D_{JS}[p,q]$ between these distributions\n", | ||
"\n", | ||
"What do you conclude?" | ||
], | ||
"metadata": { | ||
"id": "zf8yTusua71s" | ||
} | ||
} | ||
] | ||
} |