Commit

update capsule tutorials
VoVAllen committed Oct 22, 2018
1 parent d031332 commit 026d35c
Showing 6 changed files with 112 additions and 33 deletions.
145 changes: 112 additions & 33 deletions tutorial/capsule/Capsule Tutorial(WIP).ipynb
@@ -18,16 +18,19 @@
"## Model Overview\n",
"\n",
"### Introduction\n",
"Capsule Networks were first introduced in 2011 by Geoffrey Hinton et al. in a paper called [Transforming Autoencoders](https://www.cs.toronto.edu/~fritz/absps/transauto6.pdf), but it was only in November 2017 that Sara Sabour, Nicholas Frosst, and Geoffrey Hinton published a paper called Dynamic Routing Between Capsules, in which they introduced a CapsNet architecture that reached state-of-the-art performance on MNIST.\n",
"\n",
"### What's a capsule?\n",
"> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part. \n",
"\n",
"Generally speaking, the idea of a capsule is to encode all the information about a feature in vector form, substituting the scalars in a traditional neural network with vectors and using the norm of each vector to carry the meaning of the original scalar. \n",
"![figure_1](./capsule_f1.png)\n",
"\n",
"### Dynamic Routing Algorithm\n",
"Because of this different structure, a capsule network uses different operations to compute its outputs. The figure below shows the comparison, drawn by [Max Pechyonkin](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66O). \n",
"<img src=\"./capsule_f2.png\" style=\"height:250px;\"/><br/>\n",
"\n",
"The key idea is that the output of each capsule is a squashed weighted sum of its input vectors, with the weights determined by dynamic routing. We will go into the details in a later section, along with the code implementation.\n"
]
},
{
@@ -38,7 +41,7 @@
"\n",
"### 1. Consider capsule routing as a graph structure\n",
"\n",
"We can consider each capsule as a node in a graph, and connect all the nodes between layers.\n",
"<img src=\"./capsule_f3.png\" style=\"height:200px;\"/>"
]
},
@@ -50,23 +53,25 @@
"source": [
"def construct_graph(self):\n",
" g = dgl.DGLGraph()\n",
" g.add_nodes(self.input_capsule_num + self.output_capsule_num)\n",
" input_nodes = list(range(self.input_capsule_num))\n",
" output_nodes = list(range(self.input_capsule_num, self.input_capsule_num + self.output_capsule_num))\n",
" u, v = [], []\n",
" for i in input_nodes:\n",
" for j in output_nodes:\n",
" u.append(i)\n",
" v.append(j)\n",
" g.add_edges(u, v)\n",
" return g, input_nodes, output_nodes"
]
},
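As a self-contained sketch (plain Python, no DGL; the capsule counts `3` and `2` are made up for illustration), the nested loop above builds a fully connected bipartite edge list between the two capsule layers:

```python
# Fully connected bipartite edge list between input and output capsules,
# mirroring the nested loop in construct_graph (plain Python, no DGL).
def bipartite_edges(input_capsule_num, output_capsule_num):
    input_nodes = list(range(input_capsule_num))
    output_nodes = list(range(input_capsule_num,
                              input_capsule_num + output_capsule_num))
    u, v = [], []
    for i in input_nodes:
        for j in output_nodes:
            u.append(i)
            v.append(j)
    return u, v

u, v = bipartite_edges(3, 2)
print(list(zip(u, v)))
# Every input capsule connects to every output capsule: 3 * 2 = 6 edges.
```

In the real notebook, `g.add_edges(u, v)` then registers exactly these pairs with DGL.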
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Initialization & Affine Transformation\n",
"- Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$ and store them as edge attribute\n",
"- Initialize node features as zero\n",
"<img src=\"./capsule_f4.png\" style=\"height:200px;\"/>"
]
},
@@ -76,36 +81,63 @@
"metadata": {},
"outputs": [],
"source": [
"# x is the input vector with shape [batch_size, input_capsule_dim, input_num]\n",
"# Transpose x to [batch_size, input_num, input_capsule_dim] \n",
"x = x.transpose(1, 2)\n",
"# Expand x to [batch_size, input_num, output_num, input_capsule_dim, 1]\n",
"x = torch.stack([x] * self.output_capsule_num, dim=2).unsqueeze(4)\n",
"# Expand W from [input_num, output_num, output_capsule_dim, input_capsule_dim] \n",
"# to [batch_size, input_num, output_num, output_capsule_dim, input_capsule_dim] \n",
"W = self.weight.expand(self.batch_size, *self.weight.size())\n",
"# u_hat's shape is [input_num, output_num, batch_size, output_capsule_dim]\n",
"u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n",
"\n",
"b_ij = torch.zeros(self.input_capsule_num, self.output_capsule_num).to(self.device)\n",
"\n",
"self.g.set_e_repr({'b_ij': b_ij.view(-1)})\n",
"self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})\n",
"\n",
"# Initialize all node features as zero\n",
"node_features = torch.zeros(self.input_capsule_num + self.output_capsule_num, self.batch_size,\n",
" self.output_capsule_dim).to(self.device)\n",
"self.g.set_n_repr({'h': node_features})"
]
},
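To see what the affine transformation computes per edge, here is a plain-Python toy (made-up 2-D inputs, 3-D outputs, and hand-picked matrices; the real notebook does this for all edges at once with batched `torch.matmul`):

```python
# Toy illustration of the per-edge prediction u_hat[i][j] = W[i][j] @ x[i]
# (plain Python, tiny made-up dimensions instead of the notebook's tensors).
def matvec(W, x):
    return [sum(w * xv for w, xv in zip(row, x)) for row in W]

x = [[1.0, 0.0],   # input capsule 0 (dim 2)
     [0.0, 2.0]]   # input capsule 1 (dim 2)
# One 3x2 matrix per (input i, output j) edge; here a single output capsule.
W = [
    [[[1, 0], [0, 1], [1, 1]]],   # W[0][0]
    [[[2, 0], [0, 2], [1, 1]]],   # W[1][0]
]
u_hat = [[matvec(W[i][j], x[i]) for j in range(len(W[i]))]
         for i in range(len(x))]
print(u_hat)  # one predicted output vector per (input, output) edge
```

Each `u_hat[i][j]` is input capsule `i`'s "vote" for what output capsule `j` should be.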
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Write Message Passing functions and Squash function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.1 Squash function\n",
"The squashing function ensures that short vectors shrink to almost zero length while long vectors are scaled to a length slightly below 1.\n",
"![squash](./squash.png)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def squash(s):\n",
" msg_sq = torch.sum(s ** 2, dim=2, keepdim=True)\n",
" msg = torch.sqrt(msg_sq)\n",
" s = (msg_sq / (1.0 + msg_sq)) * (s / msg)\n",
" return s"
]
},
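Applied to single vectors in plain Python (a standalone sketch of the same formula, not the notebook's tensor version), squash preserves direction while bounding length below 1:

```python
import math

# Plain-Python squash for one vector: scale by ||s||^2/(1+||s||^2) and normalize.
def squash_vec(s):
    sq = sum(x * x for x in s)
    norm = math.sqrt(sq)
    scale = (sq / (1.0 + sq)) / norm
    return [scale * x for x in s]

short = squash_vec([0.1, 0.0])   # short vectors shrink toward zero
long = squash_vec([10.0, 0.0])   # long vectors approach (but stay below) length 1
print(short, long)
```

Note the toy version assumes a nonzero input; the notebook's batched code has the same implicit assumption when dividing by the norm.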
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.2 Message Functions\n",
"In the first stage, we define a message function that gathers all the attributes needed in later computations."
]
},
{
@@ -114,34 +146,81 @@
"metadata": {},
"outputs": [],
"source": [
"def capsule_msg(src, edge):\n",
" return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.3 Reduce Functions\n",
"Next, we define a reduce function that aggregates the information collected by the message function into node features.\n",
"This step implements lines 4 and 5 of the routing algorithm: softmax over $b_{ij}$, then a weighted sum of the input features. Note that the softmax is taken over dimension $j$, not $i$. \n",
"<img src=\"./capsule_f5.png\" style=\"height:300px\">"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def capsule_reduce(node, msg):\n",
" b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
" # line 4\n",
" c_i = F.softmax(b_ij_c, dim=0)\n",
" # line 5\n",
" s_j = (c_i.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)\n",
" return {'h': s_j}"
]
},
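Because the softmax runs over dimension $j$, each input capsule distributes a total coupling of 1 across the output capsules. A plain-Python sketch with made-up logits (the notebook does this with `F.softmax`):

```python
import math

# Softmax of the routing logits b_ij over the OUTPUT dimension j:
# for each input capsule i, the coupling coefficients c_ij sum to 1.
def softmax_over_j(b):
    out = []
    for row in b:                              # one row per input capsule i
        m = max(row)                           # subtract max for stability
        exps = [math.exp(x - m) for x in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out

b_ij = [[0.0, 0.0], [1.0, -1.0], [2.0, 2.0]]   # made-up logits, 3 inputs x 2 outputs
c_ij = softmax_over_j(b_ij)
print(c_ij)  # each row sums to 1
```

Softmaxing over $i$ instead would normalize each output capsule's incoming weights, which is not what the routing algorithm specifies.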
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.4 Node Update Functions\n",
"Squash the intermediate representations into the node features $v_j$.\n",
"![step6](./step6.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def capsule_update(msg):\n",
" # line 6\n",
" v_j = squash(msg['h'])\n",
" return {'h': v_j}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.5 Edge Update Functions\n",
"Update the routing parameters\n",
"![step7](./step7.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def update_edge(self, u, v, edge):\n",
" # line 7\n",
" return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}"
]
},
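The update in line 7 is just a dot-product agreement score per edge: the logit grows when a prediction points the same way as the output capsule. A plain-Python sketch with made-up vectors:

```python
# Routing-logit update on a single edge (line 7): b_ij += <v_j, u_hat_j|i>.
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

b_ij = 0.0
v_j = [0.6, 0.8]              # current output of capsule j (made up)
u_hat_agree = [0.6, 0.8]      # prediction pointing the same way
u_hat_disagree = [0.8, -0.6]  # orthogonal prediction

b_agree = b_ij + dot(v_j, u_hat_agree)        # grows (dot product ~ 1.0)
b_disagree = b_ij + dot(v_j, u_hat_disagree)  # stays ~ 0 (orthogonal vectors)
print(b_agree, b_disagree)
```

Over iterations, this rich-get-richer dynamic routes each input capsule's output toward the output capsules that agree with it.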
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Executing algorithm\n",
"Call the `update_all` and `update_edge` functions to execute the algorithm."
]
},
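Putting the steps together, the whole routing procedure can be sketched end to end in plain Python (no DGL or PyTorch; `route`, the toy predictions, and all dimensions here are invented for illustration):

```python
import math

# End-to-end sketch of dynamic routing: fixed predictions u_hat, then a few
# iterations of softmax -> weighted sum -> squash -> agreement update.
def squash_vec(s):
    sq = sum(x * x for x in s)
    return [(sq / (1.0 + sq)) / math.sqrt(sq) * x for x in s]

def route(u_hat, iterations=3):
    n_in, n_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_out for _ in range(n_in)]   # routing logits b_ij
    v = None
    for _ in range(iterations):
        # line 4: softmax over j
        c = []
        for row in b:
            m = max(row)
            e = [math.exp(x - m) for x in row]
            z = sum(e)
            c.append([x / z for x in e])
        # lines 5-6: weighted sum of predictions, then squash
        v = []
        for j in range(n_out):
            s = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_in))
                 for d in range(dim)]
            v.append(squash_vec(s))
        # line 7: agreement update of the logits
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += sum(v[j][d] * u_hat[i][j][d] for d in range(dim))
    return v

# Two input capsules, two output capsules, 2-D vectors (all made up):
# both inputs vote the same way for output 0, but disagree about output 1.
u_hat = [[[1.0, 0.0], [0.0, 1.0]],
         [[1.0, 0.0], [0.0, -0.5]]]
v = route(u_hat)
print(v)  # the agreed-upon output capsule ends up with the longer vector
```

The DGL version in this tutorial distributes exactly these steps across the graph: the reduce function covers lines 4 and 5, the node update covers line 6, and the edge update covers line 7.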
{
Binary file modified tutorial/capsule/capsule_f4.png
Binary file added tutorial/capsule/capsule_f5.png
Binary file added tutorial/capsule/squash.png
Binary file added tutorial/capsule/step6.png
Binary file added tutorial/capsule/step7.png
