{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Working With Image Data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_digits\n",
"%matplotlib inline\n",
"\n",
"# Handwritten-digit dataset bundled with scikit-learn (each row reshapes to an 8x8 image below).\n",
"digits_data = load_digits()"
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['data', 'target', 'target_names', 'images', 'DESCR'])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "digits_data.keys()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"# One class label per image, row-aligned with the feature frame `data` built below.\n",
"labels = pd.Series(digits_data['target'])"
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...54555657585960616263
00.00.05.013.09.01.00.00.00.00.0...0.00.00.00.06.013.010.00.00.00.0
\n", "

1 rows × 64 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 \\\n", "0 0.0 0.0 5.0 13.0 9.0 1.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 \n", "\n", " 58 59 60 61 62 63 \n", "0 6.0 13.0 10.0 0.0 0.0 0.0 \n", "\n", "[1 rows x 64 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.DataFrame(digits_data['data'])\n", "data.head(1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAKt0lEQVR4nO3dXYhc9RnH8d+vq9L6EoxNKJINXRckIIWauAQkIDR2S6yivaiSgEKl4E0VpQWjveud3oi9KIJErWCqZKOCiNUKKq3QWneS2BpXSxJTMlWbhEZ8KTREn17sBKJd3TNnzts+/X5gcV+G/T/D5uuZmT17/o4IAcjjK20PAKBaRA0kQ9RAMkQNJEPUQDKn1fFNV6xYERMTE3V861YdO3as0fX6/X5jay1btqyxtcbHxxtba2xsrLG1mnTw4EEdPXrUC32tlqgnJiY0Oztbx7du1czMTKPrbd26tbG1pqenG1vrrrvuamyt5cuXN7ZWk6ampr7wazz8BpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSKRS17U2237K9z/YddQ8FoLxFo7Y9JulXkq6QdJGkLbYvqnswAOUUOVKvl7QvIg5ExHFJj0m6pt6xAJRVJOpVkg6d8nF/8LnPsH2T7Vnbs0eOHKlqPgBDKhL1Qn/e9T9XK4yI+yNiKiKmVq5cOfpkAEopEnVf0upTPh6X9E494wAYVZGoX5V0oe0LbJ8habOkp+odC0BZi14kISJO2L5Z0nOSxiQ9GBF7a58MQCmFrnwSEc9IeqbmWQBUgDPKgGSIGkiGqIFkiBpIhqiBZIgaSIaogWRq2aEjqyZ3zJCkt99+u7G1mtxS6LzzzmtsrR07djS2liRde+21ja63EI7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU2SHjgdtH7b9ehMDARhNkSP1ryVtqnkOABVZNOqI+L2kfzUwC4AKVPacmm13gG6oLGq23QG6gVe/gWSIGkimyK+0HpX0R0lrbPdt/7j+sQCUVWQvrS1NDAKgGjz8BpIhaiAZogaSIWogGaIGkiFqIBmiBpJZ8tvu9Hq9xtZqchscSdq/f39ja01OTja21vT0dGNrNfnvQ2LbHQA1IGogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJki1yhbbftF23O299q+tYnBAJRT5NzvE5J+FhG7bJ8jqWf7+Yh4o+bZAJRQZNuddyNi1+D9DyXNSVpV92AAyhnqObXtCUlrJb2ywNfYdgfogMJR2z5b0uOSbouIDz7/dbbdAbqhUNS2T9d80Nsj4ol6RwIwiiKvflvSA5LmIuKe+kcCMIoiR+oNkm6QtNH2nsHb92ueC0BJRbb
deVmSG5gFQAU4owxIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZJb8XlrHjh1rbK1169Y1tpbU7P5WTbrkkkvaHiE1jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJFLjz4Vdt/tv3aYNudXzQxGIByipwm+h9JGyPio8Glgl+2/duI+FPNswEoociFB0PSR4MPTx+8RZ1DASiv6MX8x2zvkXRY0vMRwbY7QEcVijoiPomIiyWNS1pv+1sL3IZtd4AOGOrV74h4X9JLkjbVMg2AkRV59Xul7XMH739N0nclvVn3YADKKfLq9/mSHrY9pvn/CeyIiKfrHQtAWUVe/f6L5vekBrAEcEYZkAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8mw7c4QpqenG1srsyZ/ZsuXL29sra7gSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDKFox5c0H+3bS46CHTYMEfqWyXN1TUIgGoU3XZnXNKVkrbVOw6AURU9Ut8r6XZJn37RDdhLC+iGIjt0XCXpcET0vux27KUFdEORI/UGSVfbPijpMUkbbT9S61QASls06oi4MyLGI2JC0mZJL0TE9bVPBqAUfk8NJDPU5Ywi4iXNb2ULoKM4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJLPltd5rcVqXX+9LT35e0JrfCmZ2dbWyt6667rrG1uoIjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRQ6TXRwJdEPJX0i6URETNU5FIDyhjn3+zsRcbS2SQBUgoffQDJFow5Jv7Pds33TQjdg2x2gG4pGvSEi1km6QtJPbF/2+Ruw7Q7QDYWijoh3Bv89LOlJSevrHApAeUU2yDvL9jkn35f0PUmv1z0YgHKKvPr9DUlP2j55+99ExLO1TgWgtEWjjogDkr7dwCwAKsCvtIBkiBpIhqiBZIgaSIaogWSIGkiGqIFklvy2O5OTk42t1eR2MZI0MzOTcq0mbd26te0RGseRGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZApFbftc2zttv2l7zvaldQ8GoJyi537/UtKzEfFD22dIOrPGmQCMYNGobS+TdJmkH0lSRByXdLzesQCUVeTh96SkI5Iesr3b9rbB9b8/g213gG4oEvVpktZJui8i1kr6WNIdn78R2+4A3VAk6r6kfkS8Mvh4p+YjB9BBi0YdEe9JOmR7zeBTl0t6o9apAJRW9NXvWyRtH7zyfUDSjfWNBGAUhaKOiD2SpmqeBUAFOKMMSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWTYS2sId999d2NrSc3uAzU11dy5Rb1er7G1/h9xpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGklk0attrbO855e0D27c1MRyA4S16mmhEvCXpYkmyPSbpH5KerHkuACUN+/D7ckn7I+LvdQwDYHTDRr1Z0qMLfYFtd4BuKBz14JrfV0uaWejrbLsDdMMwR+orJO2KiH/WNQyA0Q0T9RZ9wUNvAN1RKGrbZ0qalvREveMAGFXRbXf+LenrNc8CoAKcUQYkQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMo6I6r+pfUTSsH+euULS0cqH6Yas94371Z5vRsSCfzlVS9Rl2J6NiOY2dGpQ1vvG/eomHn4DyRA1kEyXor6/7QFqlPW+cb86qDPPqQFUo0tHagAVIGogmU5EbXuT7bds77N9R9vzVMH2atsv2p6zvdf2rW3PVCXbY7Z323667VmqZPtc2zttvzn42V3a9kzDav059WCDgL9p/nJJfUmvStoSEW+0OtiIbJ8v6fyI2GX7HEk9ST9Y6vfrJNs/lTQlaVlEXNX2PFWx/bCkP0TEtsEVdM+MiPfbnmsYXThSr5e0LyIORMRxSY9JuqblmUYWEe9GxK7B+x9KmpO0qt2pqmF7XNKVkra
1PUuVbC+TdJmkByQpIo4vtaClbkS9StKhUz7uK8k//pNsT0haK+mVdiepzL2Sbpf0aduDVGxS0hFJDw2eWmyzfVbbQw2rC1F7gc+l+T2b7bMlPS7ptoj4oO15RmX7KkmHI6LX9iw1OE3SOkn3RcRaSR9LWnKv8XQh6r6k1ad8PC7pnZZmqZTt0zUf9PaIyHJ55Q2SrrZ9UPNPlTbafqTdkSrTl9SPiJOPqHZqPvIlpQtRvyrpQtsXDF6Y2CzpqZZnGplta/652VxE3NP2PFWJiDsjYjwiJjT/s3ohIq5veaxKRMR7kg7ZXjP41OWSltwLm4Wu+12niDhh+2ZJz0kak/RgROxteawqbJB0g6S/2t4z+NzPI+KZFmfC4m6RtH1wgDkg6caW5xla67/SAlCtLjz8BlAhogaSIWogGaIGkiFqIBmiBpIhaiCZ/wLr8rHX1UUh+gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "first_image = data.iloc[0]\n", "np_image = first_image.values\n", "np_image = np_image.reshape(8,8)\n", "\n", "plt.imshow(np_image, cmap='gray_r')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAADeCAYAAAAU9Eo0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQoElEQVR4nO3dQWyVddbH8d95OyGZGLQYqlEhVBI1YQPRho3J6Cw0uqIuNLICNmzGxLLCJex0YayL2ZA3UjbGyKLAwqizwdnSxjbqZJwAltCQaEssmcSFwZx3IczbKaXP4d7n/zzn4X4/G0o9vf/Dz9uTy+U5fczdBQDI63/abgAAsDEGNQAkx6AGgOQY1ACQHIMaAJL7Q4kH3bp1q4+Ojvb9OD///HNlzeLiYmXNgw8+GDpv27ZtlTVDQ0Ohx9rIwsKClpeXrZevrSvb3377rbLmm2++qazZtGlT6Lxdu3aF6uowOzu77O4j9/p1dWV7/fr1ypqrV69W1kR7GR4eDtXVoddspWbzXVhY6Puc2x555JHKmu3bt/d9zkZzocigHh0d1czMTN+Pc/r06cqao0ePVta89NJLofPefffdypotW7aEHmsjY2NjPX9tXdmurKyEzqqjRlItPUeZ2ZVevq6ubKempiprJiYmKmvef//90Hn79u0L1dWh12ylZvM9dOhQ3+fctn///sqaycnJvs/ZaC6E3vows1fM7Hszu2hm7/TdEf6DbMsi33LItjmVg9rMhiT9VdKrknZJ2m9mzf099j5GtmWRbzlk26zIK+q9ki66+2V3/1XSJ5Ka+7vW/Y1syyLfcsi2QZFB/YSk1f/ysXjrc//FzA6b2YyZzSwtLdXV3/2ObMuqzJdse8Zzt0GRQb3ev0Le8QNC3P2Eu4+5+9jISE//KDyIyLasynzJtmc8dxsUGdSLklZfe7JN0rUy7Qwcsi2LfMsh2wZFBvUFSU+Z2ZNmtknSm5LOlW1rYJBtWeRbDtk2qPI6ane/aWZvSfpC0pCkj9z9u+KdKXaN9A8//FBZE1mckaSHH364subTTz+trHn99ddD57WZbcSNGzcqa+bn50OPdfbs2cqauq8HLpFv5PrzyDXSEQcOHAjVnT9/vrJmz549fXbz39p87h47dqyy5u23366smZubC51X5/JMr0ILL+7+maTPCvcykMi2LPIth2ybw8/6AIDkGNQAkByDGgCSY1ADQHIMagBIjkENAMkxqAEguSI3DoiYnZ2trIkss1y6dKmyZufOnaGeIjcYiPQdXXhpU+SHr9epyR9uX1Jk4SWyKHTy5MnKmjNnzoR6Gh8fr6zJsLRRl8iiSuSuN9HFpLqXhXrBK2oASI5BDQDJMagBIDkGNQAkx6AGgOQY1ACQHIMaAJJjUANAcq0tvETuuvLss89W1kSXWSKee+652h4ru7oWXnbv3l3L43TF6OhoZc1DDz1UWRO5S0lkuUaKLdhEHiuyJNIVkef3hx9+GHqs6J12qhw8eLDnr+UVNQAkx6AGgOQY1A
CQHIMaAJJjUANAcgxqAEiOQQ0AyTGoASC51AsvkTuu1CnS05YtWxropD+Tk5OVNfPz8w10MpjOnz9fWRO5e0v0riynTp2qrPnqq68qa7pyF57InVkimezYsSN0XmRZqPRc4BU1ACTHoAaA5BjUAJAcgxoAkmNQA0ByDGoASI5BDQDJMagBIDkGNQAk19pmYmSTZ3Z2tpazIhuHkjQzM1NZ88Ybb/TbTnGRW/5EtufOnj3bfzMDaM+ePbXUzM3Nhc6LbOF9/fXXlTVd2Ux87bXXKmsim5+R7wEpdvu10kKD2swWJP1b0m+Sbrr7WMmmBgnZlkW+5ZBtc+7lFfWf3X25WCeDjWzLIt9yyLYBvEcNAMlFB7VL+tLMZs3s8HoFZnbYzGbMbGZpaam+Du9/ZFvWhvmSbV947jYkOqifd/dnJb0q6S9m9qe1Be5+wt3H3H1sZGSk1ibvc2Rb1ob5km1feO42JDSo3f3arV9/kjQtaW/JpgYJ2ZZFvuWQbXMqB7WZPWBmm29/LOllSd+WbmwQkG1Z5FsO2TYrctXHo5Kmzex2/cfu/nnRrgYH2ZZFvuWQbYMqB7W7X5a0u+6Dd+7cWVkTWUA5ffp0LTVRR48ere2xSmU7PDxcWRO5XVdk4SXzLb1K5VuHyLLF1NRU8T561Wa2kVtoRZZUMiyyRHF5HgAkx6AGgOQY1ACQHIMaAJJjUANAcgxqAEiOQQ0AyTGoASC51u7wEll4ee+99yprIgsoY2Oxn2de1x1luiCyFBNZLLhy5UrovMjdSiJ3PemCyDJR5A4kKysrofNeeOGFypqJiYnQY3VBZFlofHy8fCMN4hU1ACTHoAaA5BjUAJAcgxoAkmNQA0ByDGoASI5BDQDJMagBIDlz9/of1GxJ0upNiK2Slms/qLxSfe9w955uyUy2IT3lu062UjfzTZetxHM34K7ZFhnUdxxiNuPusfXARLrQdxd6XE9X+u5Kn6t1peeu9LlWG33z1gcAJMegBoDkmhrUJxo6p25d6LsLPa6nK313pc/VutJzV/pcq/G+G3mPGgDQO976AIDkGNQAkFzxQW1mr5jZ92Z20czeKX1eHcxswcy+MbM5M5tpu5+76WK2UjfyJduyuphvm9kWfY/azIYk/UvSS5IWJV2QtN/d/1Hs0BqY2YKkMXdPezF+V7OV8udLtmV1Nd82sy39inqvpIvuftndf5X0iaR9hc8cFGRbDtmWRb73qPSgfkLS1VW/X7z1uexc0pdmNmtmh9tu5i66mq2UP1+yLaur+baWbemb29o6n+vC9YDPu/s1M3tE0t/M7J/u/ve2m1qjq9lK+fMl27K6mm9r2ZZ+Rb0oafuq32+TdK3wmX1z92u3fv1J0rR+/6taNp3MVupEvmRbVifzbTPb0oP6gqSnzOxJM9sk6U1J5wqf2Rcze8DMNt/+WNLLkr5tt6t1dS5bqTP5km1Zncu37WyLvvXh7jfN7C1JX0gakvSRu39X8swaPCpp2syk3/P52N0/b7elO3U0W6kD+ZJtWR3Nt9VsWSEHgOTYTASA5BjUAJAcgxoAkmNQA0ByDGoASI5BDQDJMagBIDkGNQAkx6AGgOQY1ACQHIMaAJJjUANAcgxqAEiOQQ0AyTGoASA5BjUAJMegBoDkGNQAkByDGgCSY1ADQHIMagBIjkENAMkxqAEgOQY1ACTHoAaA5BjUAJAcgxoAkmNQA0ByDGoASO4PJR5069atPjo62vfjXL16tbLm+vXrlTVDQ0Oh85555pnKmk2bNoUeayMLCwtaXl62Xr62rmxXVlYqay5dulRZ89hjj4XOe/zxx0N1dZidnV1295F7/bq6sv3xxx8raxYXFytr/vjHP4bO2759e2XN5s2bQ49VpddspfryjXzPLywsVNZke+5uNBdCg9rMXpH0oaQhSf/r7u9uVD86OqqZmZl7bnStiYmJypqpqanKmuHh4dB5586dq6yp44k2Njb2n4/byvbs2bOVNePj45U1hw
8fDp137NixUF0dzOzKqo/D+daV7eTkZGXNkSNHKmuefvrp2s578cUXQ49VpddspfryjXzPHzp0qLIm23N39VxYq/KtDzMbkvRXSa9K2iVpv5ntqq27AUa2ZZFvOWTbrMh71HslXXT3y+7+q6RPJO0r29bAINuyyLccsm1QZFA/IWn1m8WLtz6H/pFtWeRbDtk2KDKo13tz2+8oMjtsZjNmNrO0tNR/Z4OBbMuqzJdse8Zzt0GRQb0oafU/K2+TdG1tkbufcPcxdx8bGenpH4UHEdmWVZkv2faM526DIoP6gqSnzOxJM9sk6U1J1ZdHIIJsyyLfcsi2QZWX57n7TTN7S9IX+v0ynI/c/bvinQ0Asi2LfMsh22aFrqN2988kfVa4l56cP3++siZy3WW0ru5rKtvKdnp6urLmzJkzlTUffPBB6LxItgcPHgw91r3I/NytMj8/H6qL/H+q6zrq1drKNrLMsnv37sqa48ePh86L7BPs2bMn9Fi9YoUcAJJjUANAcgxqAEiOQQ0AyTGoASA5BjUAJMegBoDkGNQAkFyRO7zUJfID0SOiF6NHlme6IPLnqGtRKLJYIMWWBkosvLQh8ueI3DgA64vkG3l+HzhwIHReZKGIhRcAGHAMagBIjkENAMkxqAEgOQY1ACTHoAaA5BjUAJAcgxoAkku98BIRufg9csG6JM3NzfXZTXfUdceP0dHR2s6L3D2n7jvslBC5Awl6F3nORRZeogt1GRbheEUNAMkxqAEgOQY1ACTHoAaA5BjUAJAcgxoAkmNQA0ByDGoASC71wktdiwPR5Y7IXRoiSzHRJZBSbty4UVnTdI/Dw8ONntem6IJVXQYp2zpF50uGfHlFDQDJMagBIDkGNQAkx6AGgOQY1ACQHIMaAJJjUANAcgxqAEiOQQ0AybW2mVjXht/U1FT/zdzDeZFtprY3E3fv3l1Zc/LkyQY6+X8rKyuVNZHN0C5o+lZcdd1WbdBEb7EVud1faaFBbWYLkv4t6TdJN919rGRTg4RsyyLfcsi2OffyivrP7r5crJPBRrZlkW85ZNsA3qMGgOSig9olfWlms2Z2eL0CMztsZjNmNrO0tFRfh/c/si1rw3zJti88dxsSHdTPu/uzkl6V9Bcz+9PaAnc/4e5j7j42MjJSa5P3ObIta8N8ybYvPHcbEhrU7n7t1q8/SZqWtLdkU4OEbMsi33LItjmVg9rMHjCzzbc/lvSypG9LNzYIyLYs8i2HbJsVuerjUUnTZna7/mN3/7xoV4ODbMsi33LItkGVg9rdL0uq3qC4R5OTk5U1kaWYyMXo4+PjkZYaX1QplW3kzxG52D+yTBRd7qjzsaJK5Vulrls37dixI1TXxqJQW9lKseWpiYmJypr5+fnQeZE5dOzYscqaqjn0yy+/3PW/cXkeACTHoAaA5BjUAJAcgxoAkmNQA0ByDGoASI5BDQDJMagBILnW7vASWYCILMVEao4cORJpSQcOHKisuV/upnHmzJnKmsjSQHS5I7JgU9eiSNsiC1aRJYroAlDkeyCykNEVkexOnTpV23mR75UbN25U1hw/frznHnhFDQDJMagBIDkGNQAkx6AGgOQY1ACQHIMaAJJjUANAcgxqAEjO3L3+BzVbknRl1ae2Slqu/aDySvW9w917uiUz2Yb0lO862UrdzDddthLP3YC7ZltkUN9xiNmMu48VP6hmXei7Cz2upyt9d6XP1brSc1f6XKuNvnnrAwCSY1ADQHJNDeoTDZ1Tty703YUe19OVvrvS52pd6bkrfa7VeN+NvEcNAOgdb30AQHIMagBIrvigNrNXzOx7M7toZu+UPq8OZrZgZt+Y2ZyZzbTdz910MVupG/mSbVldzLfNbIu+R21mQ5L+JeklSYuSLkja7+7/KHZoDcxsQdKYu6e9GL+r2Ur58yXbsrqab5vZln5FvVfSRXe/7O6/SvpE0r7CZw4Ksi2HbMsi33tUelA/Ienqqt8v3vpcdi7pSzObNbPDbT
dzF13NVsqfL9mW1dV8W8u29M1tbZ3PdeF6wOfd/ZqZPSLpb2b2T3f/e9tNrdHVbKX8+ZJtWV3Nt7VsS7+iXpS0fdXvt0m6VvjMvrn7tVu//iRpWr//VS2bTmYrdSJfsi2rk/m2mW3pQX1B0lNm9qSZbZL0pqRzhc/si5k9YGabb38s6WVJ37bb1bo6l63UmXzJtqzO5dt2tkXf+nD3m2b2lqQvJA1J+sjdvyt5Zg0elTRtZtLv+Xzs7p+329KdOpqt1IF8ybasjubbaraskANAcmwmAkByDGoASI5BDQDJMagBIDkGNQAkx6AGgOQY1ACQ3P8BiQhKWaxLIVAAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f, axarr = plt.subplots(2, 4)\n", "\n", "axarr[0, 0].imshow(data.iloc[0].values.reshape(8,8), cmap='gray_r')\n", "axarr[0, 1].imshow(data.iloc[99].values.reshape(8,8), cmap='gray_r')\n", "axarr[0, 2].imshow(data.iloc[199].values.reshape(8,8), cmap='gray_r')\n", "axarr[0, 3].imshow(data.iloc[299].values.reshape(8,8), cmap='gray_r')\n", "\n", "axarr[1, 0].imshow(data.iloc[999].values.reshape(8,8), cmap='gray_r')\n", "axarr[1, 1].imshow(data.iloc[1099].values.reshape(8,8), cmap='gray_r')\n", "axarr[1, 2].imshow(data.iloc[1199].values.reshape(8,8), cmap='gray_r')\n", "axarr[1, 3].imshow(data.iloc[1299].values.reshape(8,8), cmap='gray_r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-Nearest Neighbors Model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9677233358079684" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import KFold\n", "\n", "# 50% Train / test validation\n", "def train_knn(nneighbors, train_features, train_labels):\n", " knn = KNeighborsClassifier(n_neighbors = nneighbors)\n", " knn.fit(train_features, train_labels)\n", " return knn\n", "\n", "def test(model, test_features, test_labels):\n", " predictions = model.predict(test_features)\n", " train_test_df = pd.DataFrame()\n", " train_test_df['correct_label'] = test_labels\n", " train_test_df['predicted_label'] = predictions\n", " overall_accuracy = sum(train_test_df[\"predicted_label\"] == train_test_df[\"correct_label\"])/len(train_test_df) \n", " return overall_accuracy\n", "\n", "def cross_validate(k):\n", " fold_accuracies = []\n", " kf = KFold(n_splits = 4, random_state=2)\n", " for train_index, test_index in kf.split(data):\n", " train_features, test_features = data.loc[train_index], 
data.loc[test_index]\n", " train_labels, test_labels = labels.loc[train_index], labels.loc[test_index]\n", " model = train_knn(k, train_features, train_labels)\n", " overall_accuracy = test(model, test_features, test_labels)\n", " fold_accuracies.append(overall_accuracy)\n", " return fold_accuracies\n", " \n", "knn_one_accuracies = cross_validate(1)\n", "np.mean(knn_one_accuracies)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "k_values = list(range(1,10))\n", "k_overall_accuracies = []\n", "\n", "for k in k_values:\n", " k_accuracies = cross_validate(k)\n", " k_mean_accuracy = np.mean(k_accuracies)\n", " k_overall_accuracies.append(k_mean_accuracy)\n", " \n", "plt.figure(figsize=(8,4))\n", "plt.title(\"Mean Accuracy vs. k\")\n", "plt.plot(k_values, k_overall_accuracies)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural Network With One Hidden Layer" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import KFold\n", "from sklearn.neural_network import MLPClassifier\n", "\n", "# 50% Train / test validation\n", "def train_nn(neuron_arch, train_features, train_labels):\n", " mlp = MLPClassifier(hidden_layer_sizes=neuron_arch)\n", " mlp.fit(train_features, train_labels)\n", " return mlp\n", "\n", "def test(model, test_features, test_labels):\n", " predictions = model.predict(test_features)\n", " train_test_df = pd.DataFrame()\n", " train_test_df['correct_label'] = test_labels\n", " train_test_df['predicted_label'] = predictions\n", " overall_accuracy = sum(train_test_df[\"predicted_label\"] == train_test_df[\"correct_label\"])/len(train_test_df) \n", " return overall_accuracy\n", "\n", "def cross_validate(neuron_arch):\n", " fold_accuracies = []\n", " kf = KFold(n_splits = 4, random_state=2)\n", " for train_index, test_index in kf.split(data):\n", " train_features, test_features = data.loc[train_index], data.loc[test_index]\n", " train_labels, test_labels = labels.loc[train_index], labels.loc[test_index]\n", " \n", " model = train_nn(neuron_arch, train_features, train_labels)\n", " overall_accuracy = test(model, test_features, test_labels)\n", " fold_accuracies.append(overall_accuracy)\n", " return fold_accuracies" ] }, { "cell_type": "code", 
"execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", 
"/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEICAYAAACQ4bezAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5RdZX3/8fdnJplcJxeSyT0kAUIgQLgYoUgVKggJRhFYWqIgWCutgr/Wij+hstBiEdufVm0Fu2hL5WYporYYIlQRvNQbQchACIFwTZgzySQhM7lNJjPz/f2x94Sdk5nJmWSSc86cz2utWbMvz97nu5+zz/nu59mXo4jAzMzMSlNVsQMwMzOznjlRm5mZlTAnajMzsxLmRG1mZlbCnKjNzMxKmBO1mZlZCXOiNrOyIelHki7vh/XMlBSSBvVHXN2s/68l/Wsv81+RdE4P886StPZgxGXlyYm6jKQf7jZJ4/OmP5V+6cwsUlyzJHVKurUYr1+u0i/kkHRL3vRfSrqiSGEdsANNNGmSe1nSVklrJf1n17yIWBgRd/RPpPsd3xck3d3N9JB0FEBEfCki/vTQR9e7bIxWPpyoy8/LwOKuEUknAMOKFw4AHwbeAC6RNORQvvDBahEdQtuADx+Kg6xyqKu0tXwZcE5EjATmA48UNyrrL+WwD5YiJ+rycxdJYuxyOXBntoCkIZK+Iuk1Sesk/bOkYem8sZKWSGqS9EY6PC2z7GOSvijpfyVtkfQ/+S34bnwYuB7YBbwnL5bjJP1Y0qY0lr9Op1enLacX09d5QtL07rok05j+NB2+Io3ta5I2AV+QdKSkn0raKGmDpHskjcksP13S99Nt3ijpm2kdbUoPdLrKTZC0Q1JdN/W5WdLxmWl1adkJksan9bg5XecvJBX62doMfBv4fE8FJP2JpJXp+/WwpBnp9P2pqypJ10t6VdJ6SXdKGp23vsvTfWeDpM9l1n2qpGWSWtL38h8K2cA+7lNvBR6OiBcBIqIxIm7rZft+me7rb6St8IWZ
srMk/Tx9zZ9IukXdtITTsqMl/ZuknKTXJf2tpOpCtq+H9e3R6pZ0WVrnG7N1ms4bJunb6TY8m9ZBdv4USd9L99+XJf2fvNe5L30ft0haIWn+fsTb42dI0mckfS+v/D9J+no63GPddbcP9jU2c6IuR78BRkk6Nv0w/DGQ/+Xzd8DRwEnAUcBU4IZ0XhXw78AM4HBgB/DNvOU/CHwEmADUANf0FIyktwPTgHuB+8gcREiqBX4CPARMSWPpah39FUnPwPnAKOBPgO0FbD/AacBLaXw3AQJuTl/jWGA66RdCWkdLgFeBmSR1cW9E7ExjvjSz3sXATyKiKftiadnvk+nJAD4A/Cwi1gOfBtYCdcBE4K+Bvjyb9ybgYklz8mdIel+6vovS9f8C+I8+rDu/rq5I//4IOAIYyd7v/x8Cc4CzgRskHZtO/wbwjYgYBRxJ8n4XqtB96jckPQyfkTS/gGR5GrAKGA/8PfBvkpTO+w7wO2Acyf5wWS/ruQNoJ9lHTwbOBfql61rSXOBb6etPSeOZlinyeZL6PBI4j+Tgu2vZKuCHwHKSffds4C8lnZdZ/r0k+/IY4AH2fj8LCpMePkMk3y8LMol7EMn3zl3p/H3VXf4+aH0VEf4rkz/gFeAcktbrzcAC4MfAIJLEMJPkA7cNODKz3OnAyz2s8yTgjcz4Y8D1mfFPAA/1EtO/Av+VeZ1dwIR0fDHwZA/LrQIu6Gb6zHRbBuXF9Kfp8BXAa/uop/d1vW4aU1N2fZlypwFrgKp0fBnwgR7WeQ7wUmb8f4EPp8M3Av8NHNXH9/MsYG06/PfAf6bDvwSuSId/BHw0s0wVyQHNjP2pK5IDpU9kxuek79mgzPqmZeb/DrgkHf458DfA+EK3az/3qQ+RHOBtAzYC1/ayfasz84an8U8iOQhtB4Zn5t8N3J2/n5EcXO0EhmXKLgYe7SG+LwBtJL0h2b/o2gfSMl2vdQPJwWHX8iPS5c9Jx18CFmTmX5nZL07r5j28Dvj3zOv8JDNvLrCjl7rdHWOhn6HMfvixdHgR8Gw63GvddbcP+q/vfz5fUJ7uIvnSnEVetzdJq2s48MSbDQsEdHVFDQe+RpLkx6bzayVVR0RHOt6YWd92klbXXpR0p7+f9Og5In4t6TWS1tPXSY7KX+xhG3qbty9r8uKYAPwj8HagliSZvZF5nVcjoj1/JRHxW0nbgDMl5UhaBA/08Jo/BYZJOo2kfk4CfpDO+38kX5j/k9b5bRHx5T5u098BL0o6MW/6DOAbkr6amSaS1lVDAetdkzc+haR3ocurvJmsuvT0/n+U5KDkOUkvA38TEUsKiKG3de4lIu4B7pE0mCRh3CPpyYh4uLf1RsT2tP5HkrSwN0VEtpdmDcn+kG8GMBjIZT4zVexdd1n3RUS2NwZJPfWiTMmuKyK2SdrY03z2fH9mAFMkbc5MqybpWemSX7dDJQ3qbp/vyT4+Q5C0mj8O/AtJL1RXa7qQuuutHq0A7vouQxHxKslFZeeTdMlmbSDpzj4uIsakf6MjuTAHkm7aOcBpkXRhviOdLvruQpJu61slNUpqJEkgXd3fa0i687rT07xt6f/hmWmT8srkfyHenE6bl27Tpby5PWuAw9XzRSx3pOUvA+6PiNbuCkVEJ0lX72KSA5ElEbElnbclIj4dEUeQnKP/K0ln9/B63YqIjSQHN1/Mm7UG+LPMezkmIoZFxK/Yv7pqIPly7dLV8lxXQIwvRMRiki7MvwPulzRiX8vtr4jYFRHfBeqB4/dVPk8OOCw9MO3SXZKGpI53kvQUdNXxqIg4ru9R9xjL7tdOYxrX03yS9yQb28t5739tRJzfT7F16e0zBPBfwDwl12ksAu7JxLevuvNPNB4gJ+ry9VHgnRGxLTsxTSj/AnwtPUpG0tTMOa1akkS+WdJh9HIRUwEuB24HTiBpYZ4EnAGcpOQirSXAJEl/qeSCrNq0RQpJl/kXJc1WYp6kcZGcH34duFTJBWd/Qs/Jvkst
sDXdpqnAZzLzfkfyRfhlSSMkDZV0Rmb+XSQHHJeyd+9Evu+QnJv7UDoMgKRFko5Kz422AB3pX1/9A/A2knOEXf4ZuE7ScelrjZb0foD9rKv/AD6l5EKrkcCXSLrc99n6knSppLp0H+tq4e3Pdvb2GldIene6r1QpuTjsOOC3fVlPejC7jOQCuhpJp5N3oWOmbA74H+Crkkalr3ukpDMPcHO63A8skvSHkmpIeiWy3733kbzHY5Vc2PnJzLzfAS2SPqvkorNqScdL2uOCsz6qST8HXX/V9P4ZIj2AvZ/0vH9EvJZOP9h1ZzhRl62IeDEilvUw+7PAauA3klpIzvd1Xaj0dZLbuTaQXLjz0P68fvphPhv4eiRX5nb9PZGu8/K0xfkuki/IRuAFkouYIElK95F8yFuAf+PN28w+RvJFsZHkS/pX+wjnb4BTgGbgQTK9DGl3/ntIurVfI7no648z89cCvyc56s92J+4lIn5L0oqdQnLOrstskjreCvwauDUiHkvr6UdKr3Tfl4hoITlXfVhm2g9IWq/3pu/lM8DCzGJ9ravbefPUyctAK3smht4sAFZI2kpyYdklPfVAHIAWkovnXiM5GPh74OMR8cv9WNeHSK5R2Aj8LfCfJK2/7nyY5CK3Z0m6fO8HJu/Ha+4lIlYAV5EkuVy6/ux95n9D0t39Msnn4a7Msl3770np/A0kB7mjDyCkFSQH611/H6GXz1DGHSQH5XflTT9odWcJRbhXwiqbpNuBhoi4vtix2MGj5MEpz0XEgfQiVSxJhwPPAZPSg0o7RHwxmVU0JQ8auYjkthIbQNLu4U0kLdFzgQuAvl7kZ+y+TeyvSK5ed5I+xJyorWJJ+iLwKeDmiHi52PFYv5tE0oU7jqSr+eMR8WRxQyo/6QWD60i65xcUOZyK5K5vMzOzElbQxWSSFkhaJWm1pGu7mT9D0iOS6pU84i/7SMoOJT8a8ZSknu5RNTMzs27ss0WdXrr/PMnVu2uBx4HFEfFspsx3Se4rvUPSO4GPRMRl6bytmXt492n8+PExc+bMPm+ImZlZuXriiSc2RERdd/MKOUd9Kslj+l4CkHQvyUUZz2bKzCU51wfwKMnN8ftl5syZLFvW011HZmZmA4+kV3uaV0jX91T2fATc2nRa1nLg4nT4QpJHUnY9eWeokl/c+Y2SHxjoLsAr0zLLmpqauitiZmZWkQpJ1N09WjK/v/wakuclPwmcSfK0pK4nHR0eEfNJn/8saa8nJ0XEbRExPyLm19V12/I3MzOrSIV0fa9lz+fQTiPvxwAiooHkXlTSxxJeHBHNmXlExEuSHiO5X3V/f4zBzMysohTSon4cmJ0+G7gGuIS8XxiSND69IR6Sn2C7PZ0+VtKQrjIkz4HOnts2MzOzXuwzUacP678aeBhYSfLzbisk3SjpvWmxs4BVkp4n+bm8rh8HPxZYJmk5yUVmX85eLW5mZma9K7kHnsyfPz981beZmVUSSU+k13Ptxb+eZWZmVsL8rG+zCrFtZzuPPLee1eu2IIkqiSpBVZWQoEpCpP/T8Tfnp8PdlanqGt+zTHa8qoo9XzO/TFXyP78M7PkaPZapyqx3d+xvlt9je3ooY1aqnKjNBrAdbR08umo9S+ob+Olz62nd1VnskEpWNpnvcZCRJnbRdUDQS5m88T0PDrqW2/vAJn99hZbZ6+CnmwOS7uLc40Cnhzh7P9DpvUxPcfWlTHcHUz2vb//L7Kt+usoUkxO12QDTuquDnz3fxJL6HI+sXMf2tg7Gj6zh/W+ZzqJ5k3nrzMOQoDOgM4LOCCIgMuOdAZH+75pG7L1MZ6ZM7J52YGXoZpnOzLToYb29lYnMdr25vXsv01uZbuunE4LuyuRvb3d1+uZrZst0dHb2Uqb7OAt9H6O3bcnMs73lJ/NbPngK58ydeEhe24nabABoa+/kl6ubWLI8x4+fXceWne2MHT6YC06aynvmTea0I8ZRXbVnq6BaUN3t84ys0uUfpO19
ANJ7mX0eTKUHOD2W6UzX01OZznR4r1gzBzL7KtOZOThk7+3Lzo90Pdkyh48bfsjeDydqszK1q6OTX724kSXLG3h4RSMtre2MGjqIhSdMYtG8KZx+5DgGV/t6Ues7ST6QKyFO1GZlpKMz+M1LG1lSn+OhZ3K8sX0XI4cM4ty5E1l04mT+8Kg6agY5OZsNJE7UZiWuszN4/JVNLKnP8aNncmzY2sbwmmrOOXYii+ZN5h1H1zF0cHWxwzSzg8SJ2qwEdXYGT655gx8uz7H06Rzrt+xk6OAqzj4mSc5nzZnAsBonZ7NK4ERtViIiguVrm1myvIGlT+doaG6lZlAVZx1dx6ITp3D2MRMYMcQfWbNK40+9WRFFBCsaWlhSn+PBpxtYs2kHg6vFO2bX8ZkFczjn2InUDh1c7DDNrIicqM0OsYhg1botLFme48Gnc7y8YRuDqsQZR43nk++czXlzJzF6uJOzmSWcqM0OkdXrt7CkPseS+hyr12+lSnD6keO48h1HsOC4SYwdUVPsEM2sBDlRmx1Er2zYxpL6BpbU53iucQsSnDrzMC5/3/EsPH4S40cOKXaIZlbinKjN+tmaTdt3n3N+5vUWAObPGMsX3jOXhSdMZuKooUWO0MzKiRO1WT9o2LyDpU/n+GF9juVrNgNw4vQxXP/uYzn/hMlMGTOsyBGaWblyojbbT+taWln6dHLO+YlX3wDg+KmjuHbhMbz7hMlMP+zQPQvYzAYuJ2qzPtiwdSc/SpPz717ZRAQcM6mWa849mnfPm8Ks8SOKHaKZDTBO1Gb78Ma2Nh5a0ciS+gZ+/eJGOgOOmjCSvzh7NovmTeaoCbXFDtHMBrCCErWkBcA3gGrgXyPiy3nzZwC3A3XAJuDSiFibmT8KWAn8ICKu7qfYzQ6a5u27ePjZRpbU5/jf1Rvo6AxmjR/BVX90FIvmTeHoiSOL/mPyZlYZ9pmoJVUDtwDvAtYCj0t6ICKezRT7CnBnRNwh6Z3AzcBlmflfBH7Wf2Gb9b8trbv4ycp1LFme4+cvNLGrI5h+2DA+9vYjWDRvMsdNGeXkbGaHXCEt6lOB1RHxEoCke4ELgGyingt8Kh1+FPivrhmS3gJMBB4C5vdDzGb9ZtvOdh55bj1Lljfw2PNNtLV3MmX0UK5420wWzZvCvGmjnZzNrKgKSdRTgTWZ8bXAaXlllgMXk3SPXwjUShoHvAF8laR1fXZPLyDpSuBKgMMPP7zQ2M32y462Dh5dtZ4l9Q389Ln1tO7qZOKoIXzotMNZNG8yJ08fS1WVk7OZlYZCEnV331iRN34N8E1JVwA/B14H2oFPAEsjYk1vrZKIuA24DWD+/Pn56zY7YK27OvjZ8008WJ/jJyvXsb2tg/Eja3j/W6azaN5k3jrzMCdnMytJhSTqtcD0zPg0oCFbICIagIsAJI0ELo6IZkmnA2+X9AlgJFAjaWtEXNsv0Zv1oq29k1+ubmLJ8hw/fnYdW3a2M3b4YC44aSrvmTeZ044YR7WTs5mVuEIS9ePAbEmzSFrKlwAfzBaQNB7YFBGdwHUkV4ATER/KlLkCmO8kbQfTro5OfvXiRpYsb+DhFY20tLYzauggFp4wiUXzpnD6keMYXF1V7DDNzAq2z0QdEe2SrgYeJrk96/aIWCHpRmBZRDwAnAXcLClIur6vOogxm+2hozP47Usb+WF9joeeyfHG9l2MHDKIc+dOZNGJk/nDo+qoGeTkbGblSRGldUp4/vz5sWzZsmKHYSWuszN4/JVNLKnP8aNncmzY2sbwmmrOOXYii+ZN5h1H1zF0cHWxwzQzK4ikJyKi2zuj/GQyKxudncGTa95gSX2OpU/nWNeyk6GDqzj7mCQ5nzVnAsNqnJzNbGBxoraSFhHUr21mSX0DD9bnaGhupWZQFWcdXceiE6dw9jETGDHEu7GZDVz+hrOSExGsaGjZ/ZvOazbtYHC1eMfsOj6zYA7nHDuR2qGDix2mmdkh4URtJSEiWLVuCw/WJ79M
9fKGbQyqEmccNZ5PvnM2582dxOjhTs5mVnmcqK2oVq/fypL6BpbU51i9fitVgtOPHMeV7ziCBcdNYuyImmKHaGZWVE7Udsi9smHb7uT8XOMWJDh15mFc/r7jWXj8JMaPHFLsEM3MSoYTtR0SazZt58Gncyypb+CZ11sAmD9jLF94z1wWnjCZiaOGFjlCM7PS5ERtB03D5h0sfTrHD+tzLF+zGYATp4/h+ncfy/knTGbKmGFFjtDMrPQ5UVu/Wt/Smracczzx6hsAHD91FNcuPIZ3nzCZ6YcNL3KEZmblxYna9kt7RyevbNzOqsYtrGps4bnGLTzXuIXXNm0H4JhJtVxz7tG8e94UZo0fUeRozczKlxO19SoiaNqyk+cat7AqTcar1rXw/LqttLV3AlAlmDV+BCdMG80fv3U65x03kaMm1BY5cjOzgcGJ2nbb3taetpC37E7Mq9ZtYdO2tt1lJtQOYc6kWq5420zmTKxlzqRajpow0s/VNjM7SJyoK1BHZ/DKxm1JQs61pK3kpNu66zdahtdUc/TEWs6dO5E5k2o5ZtIojplU6/uazcwOMSfqASwiaNq6c3creWUu6bZ+Yd1WduZ1Wx8/ZTQXnzItTcq1TB87nKoqFXkLzMzMiXqA2N7WzvPrtr55YVdu727rutohHDOplg+fPoM5aQvZ3dZmZqXNibrM7NFtnV5xvapxC69muq2HDa7m6Em1vOvYiRwzuXZ31/Vh7rY2Mys7TtQlLLnaumWPi7ueX7dlj27rmeNHMHfKKC48eRrHTHa3tZnZQONEXQJ2tHXw/LotPJd2W3edU97YTbf1ZX8wY3cLefZEd1ubmQ10TtRFdNvPX+Q7v32t227rc46duPvCrjmTahnnH6owM6tIBSVqSQuAbwDVwL9GxJfz5s8AbgfqgE3ApRGxNp3+/XS5wcA/RcQ/92P8ZWvp0zm+tPQ5Tp11GBee/ObV1ocf5m5rMzN70z4TtaRq4BbgXcBa4HFJD0TEs5liXwHujIg7JL0TuBm4DMgBb4uInZJGAs+kyzb0+5aUkRebtvKZ7y7n5MPHcPdHT6NmUFWxQzIzsxJVSIY4FVgdES9FRBtwL3BBXpm5wCPp8KNd8yOiLSJ2ptOHFPh6A9r2tnY+fvcTDBlczS0fPMVJ2szMelVIlpgKrMmMr02nZS0HLk6HLwRqJY0DkDRdUn26jr/rrjUt6UpJyyQta2pq6us2lI2I4HM/eIYX1m/lG5ec5J95NDOzfSokUXd3wjTyxq8BzpT0JHAm8DrQDhARayJiHnAUcLmkiXutLOK2iJgfEfPr6ur6tAHl5J7fvsYPnnydT51zNG+fPXC308zM+k8hiXotMD0zPg3Yo1UcEQ0RcVFEnAx8Lp3WnF8GWAG8/YAiLlP1azdz4w+f5aw5dVz9R0cVOxwzMysThSTqx4HZkmZJqgEuAR7IFpA0XlLXuq4juQIcSdMkDUuHxwJnAKv6K/hy8ca2Nj5+9++pqx3C1z5wkq/qNjOzgu0zUUdEO3A18DCwErgvIlZIulHSe9NiZwGrJD0PTARuSqcfC/xW0nLgZ8BXIuLpft6GktbZGXzqvqdo2rKTWz90in99yszM+qSg+6gjYimwNG/aDZnh+4H7u1nux8C8A4yxrH3z0dU8tqqJL77veE6cPqbY4ZiZWZnxvUEH0S9eaOJrP3me9500hUtPO7zY4ZiZWRlyoj5IGjbv4C/ufYrZE0bypYtOQPJ5aTMz6zsn6oOgrb2Tq77ze3bu6uBbl76F4TV+pLqZme0fZ5CD4EtLV/Lka5u59UOncGTdyGKHY2ZmZcwt6n72w+UNfPtXr/AnZ8zi/BMmFzscMzMrc07U/Wj1+i189nv1vGXGWK47/5hih2NmZgOAE3U/2baznT+/+/cMS39sY3C1q9bMzA6cz1H3g4jguu8/zUtNW7nro6cxafTQYodkZmYDhJt9/eCu37zKA8sb+PS5czjjqPHFDsfMzAYQJ+oD9ORr
b/DFJc9y9jET+PiZRxY7HDMzG2CcqA/Apm1tXHXP75k4aij/4B/bMDOzg8DnqPdTR2fwF/c+yYatbXzv429j9PDBxQ7JzMwGICfq/fSPj7zAL17YwJcuPIETpo0udjhmZjZAuet7Pzy2aj3/+NMXuOiUqSw+dXqxwzEzswHMibqPXt+8g7/8z6eYM7GWm97nH9swM7ODy4m6D3a2d/CJe35PR0fwrUvfwrCa6mKHZGZmA5zPUffBTQ+uZPmazfzzpacwa/yIYodjZmYVwC3qAv33U69z569f5WNvn8WC4/1jG2ZmdmgUlKglLZC0StJqSdd2M3+GpEck1Ut6TNK0dPpJkn4taUU674/7ewMOhefXbeHa7z3NW2eO5f8u8I9tmJnZobPPRC2pGrgFWAjMBRZLmptX7CvAnRExD7gRuDmdvh34cEQcBywAvi5pTH8Ffyhs3dnOn9/9BCOGDOKb/rENMzM7xArJOqcCqyPipYhoA+4FLsgrMxd4JB1+tGt+RDwfES+kww3AeqCuPwI/FCKCz36vnlc2bOOfFp/MxFH+sQ0zMzu0CknUU4E1mfG16bSs5cDF6fCFQK2kcdkCkk4FaoAX9y/UQ+/bv3qFB+tzfOa8Yzj9yHH7XsDMzKyfFZKou7tROPLGrwHOlPQkcCbwOtC+ewXSZOAu4CMR0bnXC0hXSlomaVlTU1PBwR9MT7z6Bjc9uJJzjp3In595RLHDMTOzClVIol4LZB+/NQ1oyBaIiIaIuCgiTgY+l05rBpA0CngQuD4iftPdC0TEbRExPyLm19UVv2d8w9adXHXP75kyZhhf/cCJfqiJmZkVTSGJ+nFgtqRZkmqAS4AHsgUkjZfUta7rgNvT6TXAD0guNPtu/4V9cH32/no2bW/j1g+dwuhh/rENMzMrnn0m6ohoB64GHgZWAvdFxApJN0p6b1rsLGCVpOeBicBN6fQPAO8ArpD0VPp3Un9vRH/q7Ax+/kITl542g+On+sc2zMysuAp6MllELAWW5k27ITN8P3B/N8vdDdx9gDEeUhu3tbGrI5g5fnixQzEzM/OTyfLlmncAMMm3YpmZWQlwos6Ta24FYPLoYUWOxMzMzIl6L41pop402i1qMzMrPifqPLnmVgZXi3EjaoodipmZmRN1vsbmHUwaPZSqKt87bWZmxedEnaehuZXJo3x+2szMSoMTdZ7G5lafnzYzs5LhRJ0RETQ2tzLZidrMzEqEE3XGpm1ttHV0ukVtZmYlw4k6w/dQm5lZqXGizmjcnajdojYzs9LgRJ3R9fhQJ2ozMysVTtQZueZWBlWJcSOHFDsUMzMzwIl6D43NrUwcNZRqP+zEzMxKhBN1Rs73UJuZWYlxos5obPE91GZmVlqcqFMRQcPmHU7UZmZWUpyoU5u372JneyeTfA+1mZmVECfqVM73UJuZWQkqKFFLWiBplaTVkq7tZv4MSY9Iqpf0mKRpmXkPSdosaUl/Bt7fGluSe6h9MZmZmZWSfSZqSdXALcBCYC6wWNLcvGJfAe6MiHnAjcDNmXn/D7isf8I9eNyiNjOzUlRIi/pUYHVEvBQRbcC9wAV5ZeYCj6TDj2bnR8QjwJZ+iPWgamxupbpKTKh1ojYzs9JRSKKeCqzJjK9Np2UtBy5Ohy8EaiWNKzQISVdKWiZpWVNTU6GL9auGza1MqB3ih52YmVlJKSRRd5e5Im/8GuBMSU8CZwKvA+2FBhERt0XE/IiYX1dXV+hi/aqxZYfPT5uZWckZVECZtcD0zPg0oCFbICIagIsAJI0ELo6I5v4K8lDINbdyzKTaYodhZma2h0Ja1I8DsyXNklQDXAI8kC0gabykrnVdB9zev2EeXBFBY3Mrk0b5HmozMyst+0zUEdEOXA08DKwE7ouIFZJulPTetNhZwCpJzwMTgZu6lpf0C+C7wNmS1ko6r5+34YC1tLazva2DKWPc9W1mZqWlkK5vImIpsDRv2g2Z4fuB+3tY9u0HEuCh0JjemuVz
1GZmVmr8ZDKgoTl52InvoTYzs1LjRE22Re1z1GZmVlqcqEmu+JZgQu2QYodiZma2BydqoB1QGdwAAAzrSURBVLF5B3UjhzC42tVhZmalxZmJpEU9eYy7vc3MrPQ4UZMm6lG+kMzMzEqPEzXJxWS+NcvMzEpRxSfqLa272Lqz3bdmmZlZSar4RO2HnZiZWSmr+ESdSxP1ZN9DbWZmJajiE3Xj7kTtFrWZmZWeik/UXY8Pneirvs3MrARVfKJubG5l/Mgh1Ayq+KowM7MSVPHZKdfc6m5vMzMrWRWfqH0PtZmZlbKKT9S55h1McaI2M7MSVdGJetvOdlpa2/3zlmZmVrIqOlHnfGuWmZmVuIpO1H4qmZmZlbqCErWkBZJWSVot6dpu5s+Q9IikekmPSZqWmXe5pBfSv8v7M/gDlUvvoXaL2szMStU+E7WkauAWYCEwF1gsaW5esa8Ad0bEPOBG4OZ02cOAzwOnAacCn5c0tv/CPzBdLWo/7MTMzEpVIS3qU4HVEfFSRLQB9wIX5JWZCzySDj+amX8e8OOI2BQRbwA/BhYceNj9I9fSyrgRNQwdXF3sUMzMzLpVSKKeCqzJjK9Np2UtBy5Ohy8EaiWNK3BZJF0paZmkZU1NTYXGfsBym3f4/LSZmZW0QhK1upkWeePXAGdKehI4E3gdaC9wWSLitoiYHxHz6+rqCgipf/ipZGZmVuoKSdRrgemZ8WlAQ7ZARDRExEURcTLwuXRacyHLFlNji59KZmZmpa2QRP04MFvSLEk1wCXAA9kCksZL6lrXdcDt6fDDwLmSxqYXkZ2bTiu6HW0dbN6+y79DbWZmJW2fiToi2oGrSRLsSuC+iFgh6UZJ702LnQWskvQ8MBG4KV12E/BFkmT/OHBjOq3oGlvSe6h9xbeZmZWwQYUUioilwNK8aTdkhu8H7u9h2dt5s4VdMnbfQz3GidrMzEpXxT6ZLLe56/Gh7vo2M7PSVbGJ2l3fZmZWDio2UeeadzBm+GCG1fhhJ2ZmVroqNlE3Nre6NW1mZiWvYhN1rrmVKWN8ftrMzEpbxSbqxmY/7MTMzEpfRSbq1l0dbNzWxmR3fZuZWYmryES9ruuKb7eozcysxFVkos41+x5qMzMrDxWZqBub3aI2M7PyUJGJ+s0WtRO1mZmVtgpN1DsYNXQQI4YU9KhzMzOzoqnQRN3q89NmZlYWKjJR+x5qMzMrFxWZqJMWtRO1mZmVvopL1G3tnWzYutMtajMzKwsVl6i7HnYyxeeozcysDFRcos75HmozMysjFZiodwC+h9rMzMpDQYla0gJJqyStlnRtN/MPl/SopCcl1Us6P51eI+nfJT0tabmks/o5/j7zU8nMzKyc7DNRS6oGbgEWAnOBxZLm5hW7HrgvIk4GLgFuTad/DCAiTgDeBXxVUlFb8bnmVkYOGUTt0MHFDMPMzKwghSTNU4HVEfFSRLQB9wIX5JUJYFQ6PBpoSIfnAo8ARMR6YDMw/0CDPhCNvjXLzMzKSCGJeiqwJjO+Np2W9QXgUklrgaXAJ9Ppy4ELJA2SNAt4CzA9/wUkXSlpmaRlTU1NfdyEvsm1+GEnZmZWPgpJ1OpmWuSNLwa+HRHTgPOBu9Iu7ttJEvsy4OvAr4D2vVYWcVtEzI+I+XV1dX2Jv89ym3e4RW1mZmWjkF+lWMuereBpvNm13eWjwAKAiPi1pKHA+LS7+1NdhST9CnjhgCI+ALs6OmnaupNJvofazMzKRCEt6seB2ZJmSaohuVjsgbwyrwFnA0g6FhgKNEkaLmlEOv1dQHtEPNtv0ffR+i07ifCtWWZmVj722aKOiHZJVwMPA9XA7RGxQtKNwLKIeAD4NPAvkj5F0i1+RUSEpAnAw5I6gdeByw7alhSgMb2H2ueozcysXBT0g8wRsZTkIrHstBsyw88CZ3Sz3CvAnAMLsf90PZXMjw81M7NyUVFPJstt9sNOzMysvFRWom5uZXhNNaOGFtSRYGZmVnQVlagbW3Yw
afRQpO7uODMzMys9FZWoc34qmZmZlZmKStSNza1MGuULyczMrHxUTKJu7+hk/ZadTBnjFrWZmZWPiknUTVt30tEZvuLbzMzKSsUk6q57qH2O2szMyknFJOrGNFH7HLWZmZWTiknUblGbmVk5qphE3di8g6GDqxgzfHCxQzEzMytYxSTq5B7qYX7YiZmZlZWKStSTRrnb28zMykvFJOpGP5XMzMzKUEUk6o7OYF1Lq++hNjOzslMRiXrj1p20d4Zb1GZmVnYqIlG/eWuW76E2M7PyUiGJegeAu77NzKzsFJSoJS2QtErSaknXdjP/cEmPSnpSUr2k89PpgyXdIelpSSslXdffG1AIP+zEzMzK1T4TtaRq4BZgITAXWCxpbl6x64H7IuJk4BLg1nT6+4EhEXEC8BbgzyTN7J/QC9fY3EpNdRWHjag51C9tZmZ2QAppUZ8KrI6IlyKiDbgXuCCvTACj0uHRQENm+ghJg4BhQBvQcsBR91GuObni2w87MTOzclNIop4KrMmMr02nZX0BuFTSWmAp8Ml0+v3ANiAHvAZ8JSI2HUjA+6Ox2bdmmZlZeSokUXfXDI288cXAtyNiGnA+cJekKpLWeAcwBZgFfFrSEXu9gHSlpGWSljU1NfVpAwqRa9nBFCdqMzMrQ4Uk6rXA9Mz4NN7s2u7yUeA+gIj4NTAUGA98EHgoInZFxHrgf4H5+S8QEbdFxPyImF9XV9f3rehFZ2ekLWrfmmVmZuWnkET9ODBb0ixJNSQXiz2QV+Y14GwASceSJOqmdPo7lRgB/AHwXH8FX4iN29rY1eGHnZiZWXnaZ6KOiHbgauBhYCXJ1d0rJN0o6b1psU8DH5O0HPgP4IqICJKrxUcCz5Ak/H+PiPqDsB09akxvzfI5ajMzK0eDCikUEUtJLhLLTrshM/wscEY3y20luUWraLoeduIWtZmZlaMB/2SyxhY/PtTMzMrXgE/UDZtbGVwtxvlhJ2ZmVoYGfKJubN7BxFFDqaryw07MzKz8DPhEnWtu9flpMzMrWwM+UTe2+B5qMzMrXwM6UUeEW9RmZlbWBnSi3tneyfnHT+Lk6WOKHYqZmdl+Keg+6nI1dHA1X7/k5GKHYWZmtt8GdIvazMys3DlRm5mZlTAnajMzsxLmRG1mZlbCnKjNzMxKmBO1mZlZCXOiNjMzK2FO1GZmZiVMEVHsGPYgqQl4tZci44ENhyicgcp12D9cj/3D9dg/XI/9o1j1OCMi6rqbUXKJel8kLYuI+cWOo5y5DvuH67F/uB77h+uxf5RiPbrr28zMrIQ5UZuZmZWwckzUtxU7gAHAddg/XI/9w/XYP1yP/aPk6rHszlGbmZlVknJsUZuZmVUMJ2ozM7MSVjaJWtICSaskrZZ0bbHjKSeSXpH0tKSnJC1Lpx0m6ceSXkj/jy12nKVG0u2S1kt6JjOt23pT4h/T/bNe0inFi7y09FCPX5D0erpPPiXp/My869J6XCXpvOJEXVokTZf0qKSVklZI+ot0uvfHPuilHkt6fyyLRC2pGrgFWAjMBRZLmlvcqMrOH0XESZn7A68FHomI2cAj6bjt6dvAgrxpPdXbQmB2+ncl8K1DFGM5+DZ71yPA19J98qSIWAqQfq4vAY5Ll7k1/fxXunbg0xFxLPAHwFVpXXl/7Jue6hFKeH8si0QNnAqsjoiXIqINuBe4oMgxlbsLgDvS4TuA9xUxlpIUET8HNuVN7qneLgDujMRvgDGSJh+aSEtbD/XYkwuAeyNiZ0S8DKwm+fxXtIjIRcTv0+EtwEpgKt4f+6SXeuxJSeyP5ZKopwJrMuNr6b1ybU8B/I+kJyRdmU6bGBE5SHZeYELRoisvPdWb99G+uzrtlr09c+rF9bgPkmYCJwO/xfvjfsurRyjh/bFcErW6meb7ygp3RkScQtIddpWkdxQ7oAHI+2jffAs4EjgJyAFfTae7HnshaSTwPeAvI6Klt6LdTHM9prqpx5LeH8slUa8FpmfGpwEN
RYql7EREQ/p/PfADkq6bdV1dYen/9cWLsKz0VG/eR/sgItZFREdEdAL/wpvdia7HHkgaTJJc7omI76eTvT/2UXf1WOr7Y7kk6seB2ZJmSaohObn/QJFjKguSRkiq7RoGzgWeIam/y9NilwP/XZwIy05P9fYA8OH0ats/AJq7uiRtb3nnSy8k2SchqcdLJA2RNIvkYqjfHer4So0kAf8GrIyIf8jM8v7YBz3VY6nvj4MO9Qvuj4hol3Q18DBQDdweESuKHFa5mAj8INk/GQR8JyIekvQ4cJ+kjwKvAe8vYowlSdJ/AGcB4yWtBT4PfJnu620pcD7JxSbbgY8c8oBLVA/1eJakk0i6EV8B/gwgIlZIug94luQK3asioqMYcZeYM4DLgKclPZVO+2u8P/ZVT/W4uJT3Rz9C1MzMrISVS9e3mZlZRXKiNjMzK2FO1GZmZiXMidrMzKyEOVGbmZmVMCdqMzOzEuZEbWZmVsL+P8rExRtmh1dpAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "nn_one_neurons = [\n", " (8,),\n", " (16,),\n", " (32,),\n", " (64,),\n", " (128,),\n", " (256,)\n", "]\n", "nn_one_accuracies = []\n", "\n", "for n in nn_one_neurons:\n", " nn_accuracies = cross_validate(n)\n", " nn_mean_accuracy = np.mean(nn_accuracies)\n", " nn_one_accuracies.append(nn_mean_accuracy)\n", "\n", "plt.figure(figsize=(8,4))\n", "plt.title(\"Mean Accuracy vs. Neurons In Single Hidden Layer\")\n", "\n", "x = [i[0] for i in nn_one_neurons]\n", "plt.plot(x, nn_one_accuracies)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary\n", "\n", "It looks like adding more neurons to the single hidden layer massively improved simple accuracy, from approximately `86%` to approximately `94%`. Simple accuracy is the fraction of classifications the model got right; it doesn't tell us anything about true or false positives or true or false negatives.\n", "\n", "Given that k-nearest neighbors achieved approximately `96%` accuracy, there doesn't seem to be any advantage to using a single hidden layer neural network for this problem." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural Network With Two Hidden Layers" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "nn_two_neurons = [\n", " (64,64),\n", " (128, 128),\n", " (256, 256)\n", "]\n", "nn_two_accuracies = []\n", "\n", "for n in nn_two_neurons:\n", " nn_accuracies = cross_validate(n)\n", " nn_mean_accuracy = np.mean(nn_accuracies)\n", " nn_two_accuracies.append(nn_mean_accuracy)\n", "\n", "plt.figure(figsize=(8,4))\n", "plt.title(\"Mean Accuracy vs. Neurons In Two Hidden Layers\")\n", "\n", "x = [i[0] for i in nn_two_neurons]\n", "plt.plot(x, nn_two_accuracies)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.9398861667903984, 0.9571492204899779, 0.9515837663944569]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nn_two_accuracies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary\n", "\n", "Using 2 hidden layers improved our simple accuracy to `95%`. While I'd traditionally be worried about overfitting, using 4-fold cross validation also gives me a bit more assurance that the model is generalizing to achieve the extra `1%` in simple accuracy over the single hidden layer networks we tried earlier." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural Network With Three Hidden Layers" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
 "from sklearn.neighbors import KNeighborsClassifier\n",
 "from sklearn.model_selection import KFold\n",
 "from sklearn.neural_network import MLPClassifier\n",
 "\n",
 "# Helpers for k-fold cross-validation of MLPClassifier architectures on `data` / `labels`.\n",
 "def train_nn(neuron_arch, train_features, train_labels):\n",
 "    \"\"\"Fit and return an MLPClassifier whose hidden layer sizes are `neuron_arch`.\"\"\"\n",
 "    mlp = MLPClassifier(hidden_layer_sizes=neuron_arch)\n",
 "    mlp.fit(train_features, train_labels)\n",
 "    return mlp\n",
 "\n",
 "def test(model, test_features, test_labels):\n",
 "    \"\"\"Return simple accuracy: the fraction of rows whose predicted label equals the true label.\"\"\"\n",
 "    predictions = model.predict(test_features)\n",
 "    return np.mean(predictions == np.asarray(test_labels))\n",
 "\n",
 "def cross_validate_six(neuron_arch):\n",
 "    \"\"\"Return the per-fold simple accuracies of `neuron_arch` under 6-fold cross-validation.\"\"\"\n",
 "    fold_accuracies = []\n",
 "    # Folds are contiguous (no shuffling), exactly as before. `random_state` only has an effect\n",
 "    # with shuffle=True, and recent scikit-learn versions reject it when shuffle is False.\n",
 "    kf = KFold(n_splits=6)\n",
 "    for train_index, test_index in kf.split(data):\n",
 "        # kf.split yields positional indices, so select rows with .iloc rather than .loc.\n",
 "        train_features, test_features = data.iloc[train_index], data.iloc[test_index]\n",
 "        train_labels, test_labels = labels.iloc[train_index], labels.iloc[test_index]\n",
 "\n",
 "        model = train_nn(neuron_arch, train_features, train_labels)\n",
 "        fold_accuracies.append(test(model, test_features, test_labels))\n",
 "    return fold_accuracies"
] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: 
ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n", "/home/yunoac/anaconda3/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:566: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n", " % self.max_iter, ConvergenceWarning)\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "nn_three_neurons = [\n", " (10, 10, 10),\n", " (64, 64, 64),\n", " (128, 128, 128)\n", "]\n", "\n", "nn_three_accuracies = []\n", "\n", "for n in nn_three_neurons:\n", " nn_accuracies = cross_validate_six(n)\n", " nn_mean_accuracy = np.mean(nn_accuracies)\n", " nn_three_accuracies.append(nn_mean_accuracy)\n", "\n", "plt.figure(figsize=(8,4))\n", "plt.title(\"Mean Accuracy vs. Neurons In Three Hidden Layers\")\n", "\n", "x = [i[0] for i in nn_three_neurons]\n", "plt.plot(x, nn_three_accuracies)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.9154124860646601, 0.950479375696767, 0.9582552954292085]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nn_three_accuracies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Summary\n", "\n", "Using 3 hidden layers brought our simple accuracy to about `96%` (`95.8%` for `(128, 128, 128)`), even with 6-fold cross validation. That is only marginally above the best two-layer network (`95.7%` for `(128, 128)`), so the third layer adds little on this dataset. The general trend still seems to be in line with the research literature out there about deep neural networks for computer vision: having more layers and more neurons tends to improve the network's performance." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }