forked from matthewvowels1/Awesome-VAEs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9c874bc
commit 84e8d27
Showing
6 changed files
with
172 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"cells": [], | ||
"metadata": {}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# create synthetic batch of z and ASSUME that there are 10 examples per class in a batch of 50\n", | ||
"classes=5\n", | ||
"examples_per_class = 10\n", | ||
"batch_size = classes*examples_per_class\n", | ||
"z_dim = 5\n", | ||
"\n", | ||
"# here rows are embeddings, columns are dimensions of the embedding\n", | ||
"z_batch = np.random.randn(batch_size,z_dim)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# this is mu_j in the paper:\n", | ||
"class_means = np.zeros((z_dim,classes)) \n", | ||
"\n", | ||
"for i in range(0,batch_size, examples_per_class):\n", | ||
" \n", | ||
" class_means[(int(i/examples_per_class))] = z_batch[i:i+examples_per_class].mean()\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# I expect you can combine this loop that calculates intra_spread with the one for mu_j above easily enough, \n", | ||
"# to save computation\n", | ||
"\n", | ||
"cum_diff_sq = 0\n", | ||
"\n", | ||
"for i in range(class_means.shape[0]):\n", | ||
" \n", | ||
" z = z_batch[i*examples_per_class:i*examples_per_class+examples_per_class] # extract per class data\n", | ||
" mu = class_means[i] # extract class[i] mean\n", | ||
" mu_diff = mu-z # calculate difference between the data and its corresponding class mean\n", | ||
" # next calculate the per row squared L2 norm (i.e. dot product)\n", | ||
" mu_diff_norm = np.asarray([np.dot(mu_diff[i].T, mu_diff[i]) for i in range(mu_diff.shape[0])])\n", | ||
" cum_diff_sq += mu_diff_norm.sum() # sum these squared norms over examples AND over classes\n", | ||
"\n", | ||
"intra_spread = cum_diff_sq/batch_size # divide by total number of datapoints in batch\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# for inter-spread you essentially need to find the min of all pairwise distances\n", | ||
"# you might want to look into efficient ways to compute this (there might be a keras function already, mind)\n", | ||
"# if you only have a few classes then maybe not a big deal.\n", | ||
"\n", | ||
"inter_separation = np.min([np.dot((class_means[i] - class_means[j]).T,(class_means[i] - class_means[j])) for i in range(class_means.shape[0]) for j in range(class_means.shape[0]) if i > j])\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ii_loss = intra_spread - inter_separation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.