-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
661 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
.ipynb_checkpoints | ||
data/ | ||
.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import numpy\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"from model.net import *\n", | ||
"from utils.training import *\n", | ||
"from data.data import *" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Experiment Config" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"args = {\n", | ||
" 'USE_CUDA': True if torch.cuda.is_available() else False,\n", | ||
" 'BATCH_SIZE': 32,\n", | ||
" 'N_EPOCHS': 30,\n", | ||
" 'LEARNING_RATE': 0.01,\n", | ||
" 'MOMENTUM': 0.9,\n", | ||
" 'DATASET_NAME':'mnist',\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model Loading" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#Config for 49 16d vectors in the Primary Capsule. Set Softmax dimension to 0 in this case\n", | ||
"class Config:\n", | ||
" def __init__(self):\n", | ||
" # CNN (cnn)\n", | ||
" self.cnn_in_channels = 1\n", | ||
" self.cnn_out_channels = 12\n", | ||
" self.cnn_kernel_size = 15\n", | ||
"\n", | ||
" # Primary Capsule (pc)\n", | ||
" self.pc_num_capsules = 16\n", | ||
" self.pc_in_channels = 12\n", | ||
" self.pc_out_channels = 1\n", | ||
" self.pc_kernel_size = 8\n", | ||
" self.pc_num_routes = 1 * 7 * 7\n", | ||
"\n", | ||
" # Digit Capsule 1 (dc)\n", | ||
" self.dc_num_capsules = 49\n", | ||
" self.dc_num_routes = 1 * 7 * 7\n", | ||
" self.dc_in_channels = 16\n", | ||
" self.dc_out_channels = 16\n", | ||
" \n", | ||
" # Digit Capsule 2 (dc)\n", | ||
" self.dc_2_num_capsules = 10\n", | ||
" self.dc_2_num_routes = 1 * 7 * 7\n", | ||
" self.dc_2_in_channels = 16\n", | ||
" self.dc_2_out_channels = 16\n", | ||
"\n", | ||
" # Decoder\n", | ||
" self.input_width = 28\n", | ||
" self.input_height = 28\n", | ||
"\n", | ||
"torch.manual_seed(1)\n", | ||
"config = Config()\n", | ||
"\n", | ||
"net = CapsNet(config)\n", | ||
"# capsule_net = torch.nn.DataParallel(capsule_net)\n", | ||
"if args['USE_CUDA']:\n", | ||
" net = net.cuda()\n", | ||
" \n", | ||
"net.load_state_dict(torch.load('./CapsNetMNIST.pth'), map_location='cpu')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loading Dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainloader, testloader = dataset(args)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training CapsuleNet" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# torch.save(capsule_net.state_dict(), \"./CapsNetMNIST.pth \")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#Config for 16 1d vectors in Capsule Layer. Set the Softmax Dimension to 1 in this case\n", | ||
"# class Config:\n", | ||
"# def __init__(self, dataset='mnist'):\n", | ||
"# # CNN (cnn)\n", | ||
"# self.cnn_in_channels = 1\n", | ||
"# self.cnn_out_channels = 12\n", | ||
"# self.cnn_kernel_size = 15\n", | ||
"\n", | ||
"# # Primary Capsule (pc)\n", | ||
"# self.pc_num_capsules = 1\n", | ||
"# self.pc_in_channels = 12\n", | ||
"# self.pc_out_channels = 16\n", | ||
"# self.pc_kernel_size = 8\n", | ||
"# self.pc_num_routes = 16 * 7 * 7\n", | ||
"\n", | ||
"# # Digit Capsule 1 (dc)\n", | ||
"# self.dc_num_capsules = 49\n", | ||
"# self.dc_num_routes = 16 * 7 * 7\n", | ||
"# self.dc_in_channels = 1\n", | ||
"# self.dc_out_channels = 1 #16\n", | ||
" \n", | ||
"# # Digit Capsule 2 (dc)\n", | ||
"# self.dc_2_num_capsules = 10\n", | ||
"# self.dc_2_num_routes = 7 * 7\n", | ||
"# self.dc_2_in_channels = 1 #16\n", | ||
"# self.dc_2_out_channels = 16\n", | ||
"\n", | ||
"# # Decoder\n", | ||
"# self.input_width = 28\n", | ||
"# self.input_height = 28" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.