Skip to content

Commit

Permalink
Smaller figure grid, remove tics from bar plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust authored Aug 16, 2018
1 parent b2e0358 commit 05f230d
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions samples/core/tutorials/keras/basic_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,6 @@
},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"for i in range(25):\n",
" plt.subplot(5,5,i+1)\n",
Expand Down Expand Up @@ -882,6 +879,8 @@
"\n",
"def plot_value_array(predictions_array, true_label):\n",
" plt.grid('off')\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n",
" plt.ylim([0, 1]) \n",
" predicted_label = np.argmax(predictions_array)\n",
Expand Down Expand Up @@ -961,16 +960,15 @@
"source": [
"# Plot the first X test images, their predicted label, and the true label\n",
"# Color correct predictions in blue, incorrect predictions in red\n",
"num_images = int(50)\n",
"plt.figure(figsize=(24,20))\n",
"num_rows = 5\n",
"num_cols = 3\n",
"num_images = num_rows*num_cols\n",
"plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n",
"for i in range(num_images):\n",
" plt.subplot(10,num_images / 5,2*i+1)\n",
" plot_image(predictions[i], test_labels[i], test_images[i])\n",
" \n",
" plt.subplot(10,num_images / 5,2*i+2)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plot_value_array(predictions[i], test_labels[i])"
" plt.subplot(num_rows, 2*num_cols, 2*i+1)\n",
" plot_image(predictions[i], test_labels[i], test_images[i])\n",
" plt.subplot(num_rows, 2*num_cols, 2*i+2)\n",
" plot_value_array(predictions[i], test_labels[i])\n",
],
"execution_count": 0,
"outputs": []
Expand Down

0 comments on commit 05f230d

Please sign in to comment.