forked from udlbook/udlbook
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,380 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyM4DdZDGoP1xGst+Nn+rwvt", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap18/18_2_1D_Diffusion_Model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# **Notebook 18.2: 1D Diffusion Model**\n", | ||
"\n", | ||
"This notebook investigates the diffusion encoder as described in sections 18.3 and 18.4 of the book.\n", | ||
"\n", | ||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", | ||
"\n", | ||
"Contact me at [email protected] if you find any mistakes or have any suggestions." | ||
], | ||
"metadata": { | ||
"id": "t9vk9Elugvmi" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib.colors import ListedColormap\n", | ||
"from operator import itemgetter\n", | ||
"from scipy import stats\n", | ||
"from IPython.display import display, clear_output" | ||
], | ||
"metadata": { | ||
"id": "OLComQyvCIJ7" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"#Create pretty colormap as in book\n", | ||
"my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 
'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", | ||
"my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n", | ||
"r = np.floor(my_colormap_vals_dec/(256*256))\n", | ||
"g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n", | ||
"b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", | ||
"my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n", | ||
"my_colormap = ListedColormap(my_colormap_vals)" | ||
], | ||
"metadata": { | ||
"id": "4PM8bf6lO0VE" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Probability distribution for normal\n", | ||
"def norm_pdf(x, mu, sigma):\n", | ||
" return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)" | ||
], | ||
"metadata": { | ||
"id": "ONGRaQscfIOo" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# True distribution is a mixture of four Gaussians\n", | ||
"class TrueDataDistribution:\n", | ||
" # Constructor initializes parameters\n", | ||
" def __init__(self):\n", | ||
" self.mu = [1.5, -0.216, 0.45, -1.875]\n", | ||
" self.sigma = [0.3, 0.15, 0.525, 0.075]\n", | ||
" self.w = [0.2, 0.3, 0.35, 0.15]\n", | ||
"\n", | ||
" # Return PDF\n", | ||
" def pdf(self, x):\n", | ||
" return(self.w[0] *norm_pdf(x,self.mu[0],self.sigma[0]) + self.w[1] *norm_pdf(x,self.mu[1],self.sigma[1]) + self.w[2] *norm_pdf(x,self.mu[2],self.sigma[2]) + self.w[3] *norm_pdf(x,self.mu[3],self.sigma[3]))\n", | ||
"\n", | ||
" # Draw samples\n", | ||
" def sample(self, n):\n", | ||
" hidden = np.random.choice(4, n, p=self.w)\n", | ||
" epsilon = np.random.normal(size=(n))\n", | ||
" mu_list = list(itemgetter(*hidden)(self.mu))\n", | ||
" sigma_list = list(itemgetter(*hidden)(self.sigma))\n", | ||
" return mu_list + sigma_list * epsilon" | ||
], | ||
"metadata": { | ||
"id": "gZvG0MKhfY8Y" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Define ground truth probability distribution that we will model\n", | ||
"true_dist = TrueDataDistribution()\n", | ||
"# Let's visualize this\n", | ||
"x_vals = np.arange(-3,3,0.01)\n", | ||
"pr_x_true = true_dist.pdf(x_vals)\n", | ||
"fig,ax = plt.subplots()\n", | ||
"fig.set_size_inches(8,2.5)\n", | ||
"ax.plot(x_vals, pr_x_true, 'r-')\n", | ||
"ax.set_xlabel(\"$x$\")\n", | ||
"ax.set_ylabel(\"$Pr(x)$\")\n", | ||
"ax.set_ylim(0,1.0)\n", | ||
"ax.set_xlim(-3,3)\n", | ||
"plt.show()" | ||
], | ||
"metadata": { | ||
"id": "iJu_uBiaeUVv" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"To train the model to describe this distribution, we'll need to generate pairs of samples drawn from $Pr(z_t|x)$ (diffusion kernel) and $q(z_{t-1}|z_{t},x)$ (equation 18.15).\n", | ||
"\n" | ||
], | ||
"metadata": { | ||
"id": "DRHUG_41i4t_" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# The diffusion kernel returns the parameters of Pr(z_{t}|x)\n", | ||
"def diffusion_kernel(x, t, beta):\n", | ||
" alpha = np.power(1-beta,t)\n", | ||
" dk_mean = x * np.sqrt(alpha)\n", | ||
" dk_std = np.sqrt(1-alpha)\n", | ||
" return dk_mean, dk_std\n", | ||
"\n", | ||
"# Compute mean and variance q(z_{t-1}|z_{t},x)\n", | ||
"def conditional_diffusion_distribution(x,z_t,t,beta):\n", | ||
" # TODO -- Implement this function\n", | ||
" # Replace this line\n", | ||
" cd_mean = 0; cd_std = 1\n", | ||
"\n", | ||
" return cd_mean, cd_std\n", | ||
"\n", | ||
"def get_data_pairs(x_train,t,beta):\n", | ||
" # Find diffusion kernel for every x_train and draw samples\n", | ||
" dk_mean, dk_std = diffusion_kernel(x_train, t, beta)\n", | ||
" z_t = np.random.normal(size=x_train.shape) * dk_std + dk_mean\n", | ||
" # Find conditional diffusion distribution for each x_train, z pair and draw samlpes\n", | ||
" cd_mean, cd_std = conditional_diffusion_distribution(x_train,z_t,t,beta)\n", | ||
" if t == 1:\n", | ||
" z_tminus1 = x_train\n", | ||
" else:\n", | ||
" z_tminus1 = np.random.normal(size=x_train.shape) * cd_std + cd_mean\n", | ||
"\n", | ||
" return z_t, z_tminus1" | ||
], | ||
"metadata": { | ||
"id": "x6B8t72Ukscd" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"We also need models $\\mbox{f}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the mean of the distribution over $z_{t-1}$. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." | ||
], | ||
"metadata": { | ||
"id": "aSG_4uA8_zZ-" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# This code is really ugly! Don't look too closely at it!\n", | ||
"# All you need to know is that it is a model that trains from pairs zt, zt_minus1\n", | ||
"# And can then predict zt\n", | ||
"class NonParametricModel():\n", | ||
" # Constructor initializes parameters\n", | ||
" def __init__(self):\n", | ||
"\n", | ||
" # Bin width used to quantize the input (lookup-table resolution)\n", | ||
" self.inc = 0.01\n", | ||
" # Inputs are clipped to the range [-max_val, max_val]\n", | ||
" self.max_val = 3.0\n", | ||
" # Per-bin average offset (zt_minus1 - zt); filled in by train()\n", | ||
" self.model = []\n", | ||
"\n", | ||
" # Learns a model that predicts z_t_minus1 given z_t\n", | ||
" def train(self, zt, zt_minus1):\n", | ||
" # Clip both inputs so every point falls inside the binning range below\n", | ||
" zt = np.clip(zt,-self.max_val,self.max_val)\n", | ||
" zt_minus1 = np.clip(zt_minus1,-self.max_val,self.max_val)\n", | ||
" bins = np.arange(-self.max_val,self.max_val+self.inc,self.inc)\n", | ||
" # Per-bin mean of the offsets zt_minus1-zt, computed as sum / count\n", | ||
" numerator, *_ = stats.binned_statistic(zt, zt_minus1-zt, statistic='sum',bins=bins)\n", | ||
" denominator, *_ = stats.binned_statistic(zt, zt_minus1-zt, statistic='count',bins=bins)\n", | ||
" # The +1 avoids division by zero for empty bins (and slightly shrinks estimates toward zero)\n", | ||
" self.model = numerator / (denominator + 1)\n", | ||
"\n", | ||
" # Predicts z_t_minus1 by adding the learned per-bin offset to zt\n", | ||
" def predict(self, zt):\n", | ||
" bin_index = np.floor((zt+self.max_val)/self.inc)\n", | ||
" # Guard against indices just outside the table after clipping/quantization\n", | ||
" bin_index = np.clip(bin_index,0, len(self.model)-1).astype('uint32')\n", | ||
" return zt + self.model[bin_index]" | ||
], | ||
"metadata": { | ||
"id": "ZHViC0pL_yy5" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Sample data from distribution (this would usually be our collected training set)\n", | ||
"n_sample = 100000\n", | ||
"x_train = true_dist.sample(n_sample)\n", | ||
"\n", | ||
"# Define model parameters\n", | ||
"T = 100\n", | ||
"beta = 0.01511\n", | ||
"\n", | ||
"all_models = []\n", | ||
"for t in range(0,T):\n", | ||
" clear_output(wait=True)\n", | ||
" display(\"Training timestep %d\"%(t))\n", | ||
" zt,zt_minus1 = get_data_pairs(x_train,t+1,beta)\n", | ||
" all_models.append(NonParametricModel())\n", | ||
" # The model at index t maps data from z_{t+1} to z_{t}\n", | ||
" all_models[t].train(zt,zt_minus1)" | ||
], | ||
"metadata": { | ||
"id": "CzVFybWoBygu" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories. See equation 18.16." | ||
], | ||
"metadata": { | ||
"id": "ZPc9SEvtl14U" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"def sample(model, T, sigma_t, n_samples):\n", | ||
" # Create the output array\n", | ||
" # Each row represents a time step, first row will be sampled data\n", | ||
" # Each column represents a different sample\n", | ||
" samples = np.zeros((T+1,n_samples))\n", | ||
"\n", | ||
" # TODO -- Initialize the samples z_{T} at samples[T,:] from standard normal distribution\n", | ||
" # Replace this line\n", | ||
" samples[T,:] = np.zeros((1,n_samples))\n", | ||
"\n", | ||
"\n", | ||
" # For t=100...99..98... ...0\n", | ||
" for t in range(T,0,-1):\n", | ||
" clear_output(wait=True)\n", | ||
" display(\"Predicting z_{%d} from z_{%d}\"%(t-1,t))\n", | ||
" # TODO Predict samples[t-1,:] from samples[t,:] using the appropriate model\n", | ||
" # Replace this line:\n", | ||
" samples[t-1,:] = np.zeros((1,n_samples))\n", | ||
"\n", | ||
"\n", | ||
" # If not the last time step\n", | ||
" if t>0:\n", | ||
" # TODO Add noise to the samples at z_t-1 we just generated with mean zero, standard deviation sigma_t\n", | ||
" # Replace this line\n", | ||
" samples[t-1,:] = samples[t-1,:]\n", | ||
"\n", | ||
" return samples" | ||
], | ||
"metadata": { | ||
"id": "A-ZMFOvACIOw" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Now let's run the diffusion process for a whole bunch of samples" | ||
], | ||
"metadata": { | ||
"id": "ECAUfHNi9NVW" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"sigma_t=0.12288\n", | ||
"n_samples = 100000\n", | ||
"samples = sample(all_models, T, sigma_t, n_samples)\n", | ||
"\n", | ||
"\n", | ||
"# Plot the data\n", | ||
"sampled_data = samples[0,:]\n", | ||
"bins = np.arange(-3,3.05,0.05)\n", | ||
"\n", | ||
"fig,ax = plt.subplots()\n", | ||
"fig.set_size_inches(8,2.5)\n", | ||
"ax.set_xlim([-3,3])\n", | ||
"plt.hist(sampled_data, bins=bins, density =True)\n", | ||
"ax.set_ylim(0, 0.8)\n", | ||
"plt.show()" | ||
], | ||
"metadata": { | ||
"id": "M-TY5w9Q8LYW" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Let's plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." | ||
], | ||
"metadata": { | ||
"id": "jYrAW6tN-gJ4" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"fig, ax = plt.subplots()\n", | ||
"t_vals = np.arange(0,101,1)\n", | ||
"ax.plot(samples[:,0],t_vals,'r-')\n", | ||
"ax.plot(samples[:,1],t_vals,'g-')\n", | ||
"ax.plot(samples[:,2],t_vals,'b-')\n", | ||
"ax.plot(samples[:,3],t_vals,'c-')\n", | ||
"ax.plot(samples[:,4],t_vals,'m-')\n", | ||
"ax.set_xlim([-3,3])\n", | ||
"ax.set_ylim([101, 0])\n", | ||
"ax.set_xlabel('value')\n", | ||
"ax.set_ylabel('z_{t}')\n", | ||
"plt.show()" | ||
], | ||
"metadata": { | ||
"id": "4XU6CDZC_kFo" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Notice that the samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0." | ||
], | ||
"metadata": { | ||
"id": "SGTYGGevAktz" | ||
} | ||
} | ||
] | ||
} |