Skip to content

Commit

Permalink
Created using Colaboratory
Browse files Browse the repository at this point in the history
  • Loading branch information
udlbook committed Oct 4, 2023
1 parent aa5d89a commit f50de74
Showing 1 changed file with 246 additions and 0 deletions.
246 changes: 246 additions & 0 deletions Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNyLnpoXgKN+RGCuTUszCAZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 15.2: Wassserstein Distance**\n",
"\n",
"This notebook investigates the GAN toy example as illustred in figure 15.1 in the book.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at [email protected] if you find any mistakes or have any suggestions."
],
"metadata": {
"id": "t9vk9Elugvmi"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"from matplotlib.colors import ListedColormap\n",
"from scipy.optimize import linprog"
],
"metadata": {
"id": "OLComQyvCIJ7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define two probability distributions\n",
"p = np.array([5, 3, 2, 1, 8, 7, 5, 9, 2, 1])\n",
"q = np.array([4, 10,1, 1, 4, 6, 3, 2, 0, 1])\n",
"p = p/np.sum(p);\n",
"q= q/np.sum(q);\n",
"\n",
"# Draw those distributions\n",
"fig, ax =plt.subplots(2,1);\n",
"x = np.arange(0,p.size,1)\n",
"ax[0].bar(x,p, color=\"#cccccc\")\n",
"ax[0].set_ylim([0,0.35])\n",
"ax[0].set_ylabel(\"p(x=i)\")\n",
"\n",
"ax[1].bar(x,q,color=\"#f47a60\")\n",
"ax[1].set_ylim([0,0.35])\n",
"ax[1].set_ylabel(\"q(x=j)\")\n",
"plt.show()"
],
"metadata": {
"id": "ZIfQwhd-AV6L"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# TODO Define the distance matrix from figure 15.8d\n",
"# Replace this line\n",
"dist_mat = np.zeros((10,10))\n",
"\n",
"# vectorize the distance matrix\n",
"c = dist_mat.flatten()"
],
"metadata": {
"id": "EZSlZQzWBKTm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define pretty colormap\n",
"my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
"my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n",
"r = np.floor(my_colormap_vals_dec/(256*256))\n",
"g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n",
"b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
"my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
"\n",
"def draw_2D_heatmap(data, title, my_colormap):\n",
" # Make grid of intercept/slope values to plot\n",
" xv, yv = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 1, 10))\n",
" fig,ax = plt.subplots()\n",
" fig.set_size_inches(4,4)\n",
" plt.imshow(data, cmap=my_colormap)\n",
" ax.set_title(title)\n",
" ax.set_xlabel('$q$'); ax.set_ylabel('$p$')\n",
" plt.show()"
],
"metadata": {
"id": "ABRANmp6F8iQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"draw_2D_heatmap(dist_mat,'Distance $|i-j|$', my_colormap)"
],
"metadata": {
"id": "G0HFPBXyHT6V"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define b to be the verticalconcatenation of p and q\n",
"b = np.hstack((p,q))[np.newaxis].transpose()"
],
"metadata": {
"id": "SfqeT3KlHWrt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# TODO: Now construct the matrix A that has the initial distribution constraints\n",
"# so that Ap=b where p is the transport plan P vectorized rows first so p = np.flatten(P)\n",
"# Replace this line:\n",
"A = np.zeros((20,100))\n"
],
"metadata": {
"id": "7KrybL96IuNW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now we have all of the things we need. The vectorized distance matrix $\\mathbf{c}$, the constraint matrix $\\mathbf{A}$, the vectorized and concatenated original distribution $\\mathbf{b}$. We can run the linear programming optimization."
],
"metadata": {
"id": "zEuEtU33S8Ly"
}
},
{
"cell_type": "code",
"source": [
"# We don't need the constraint that p>0 as this is the default\n",
"opt = linprog(c, A_eq=A, b_eq=b)\n",
"print(opt)"
],
"metadata": {
"id": "wCfsOVbeSmF5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Extract the answer and display"
],
"metadata": {
"id": "vpkkOOI2agyl"
}
},
{
"cell_type": "code",
"source": [
"P = np.array(opt.x).reshape(10,10)\n",
"draw_2D_heatmap(P,'Transport plan $\\mathbf{P}$', my_colormap)"
],
"metadata": {
"id": "nZGfkrbRV_D0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Compute the Wasserstein distance\n"
],
"metadata": {
"id": "ZEiRYRVgalsJ"
}
},
{
"cell_type": "code",
"source": [
"was = np.sum(P * dist_mat)\n",
"print(\"Wasserstein distance = \", was)"
],
"metadata": {
"id": "yiQ_8j-Raq3c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"TODO -- Compute the\n",
"\n",
"* Forward KL divergence $D_{KL}[p,q]$ between these distributions\n",
"* Reverse KL divergence $D_{KL}[q,p]$ between these distributions\n",
"* Jensen-Shannon divergence $D_{JS}[p,q]$ between these distributions\n",
"\n",
"What do you conclude?"
],
"metadata": {
"id": "zf8yTusua71s"
}
}
]
}

0 comments on commit f50de74

Please sign in to comment.