Skip to content

Commit

Permalink
Merge pull request fastai#113 from appleby/lesson3-fixes
Browse files Browse the repository at this point in the history
Lesson 3 fixes
  • Loading branch information
jph00 authored May 25, 2017
2 parents 2d319dd + 0ee9a52 commit ebae087
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions deeplearning1/nbs/lesson3.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@
"outputs": [],
"source": [
"# Create a 'batch' of a single image\n",
"img = np.expand_dims(ndimage.imread('cat.jpg'),0)\n",
"img = np.expand_dims(ndimage.imread('data/dogscats/test/7.jpg'),0)\n",
"# Request the generator to create batches from this image\n",
"aug_iter = gen.flow(img)"
]
Expand Down Expand Up @@ -1011,18 +1011,36 @@
" MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),\n",
" Flatten(),\n",
" Dense(4096, activation='relu'),\n",
" Dropout(p),\n",
" BatchNormalization(),\n",
" Dense(4096, activation='relu'),\n",
" Dropout(p),\n",
" Dense(4096, activation='relu'),\n",
" BatchNormalization(),\n",
" Dropout(p),\n",
" Dense(1000, activation='softmax')\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_fc_weights_from_vgg16bn(model):\n",
" \"Load weights for model from the dense layers of the Vgg16BN model.\"\n",
" # See imagenet_batchnorm.ipynb for info on how the weights for\n",
" # Vgg16BN can be generated from the standard Vgg16 weights.\n",
" from vgg16bn import Vgg16BN\n",
" vgg16_bn = Vgg16BN()\n",
" _, fc_layers = split_at(vgg16_bn.model, Convolution2D)\n",
" copy_weights(fc_layers, model.layers)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true,
"hidden": true
Expand All @@ -1046,14 +1064,13 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": null,
"metadata": {
"collapsed": false,
"hidden": true
},
"outputs": [],
"source": [
"bn_model.load_weights('/data/jhoward/ILSVRC2012_img/bn_do3_1.h5')"
"load_fc_weights_from_vgg16bn(bn_model)"
]
},
{
Expand All @@ -1080,7 +1097,7 @@
"outputs": [],
"source": [
"for l in bn_model.layers: \n",
" if type(l)==Dense: l.set_weights(proc_wgts(l, 0.3, 0.6))"
" if type(l)==Dense: l.set_weights(proc_wgts(l, 0.5, 0.6))"
]
},
{
Expand Down

0 comments on commit ebae087

Please sign in to comment.