{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hello World with CNN: Detect handwritten digits with a Resnet-34 based CNN.\n",
    "\n",
    "![](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png \"MNIST Examples\")\n",
    "\n",
    "To train the CNN we will use the common [MNIST](http://yann.lecun.com/exdb/mnist/) dataset:\n",
    "- 70,000 grayscale images of handwritten digits (60,000 train / 10,000 test)\n",
    "- each image has a size of 28x28\n",
    "- dataset for teaching and benchmarking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Import Python Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torchvision import transforms,datasets,models\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Load MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define batch size\n",
    "batchsize=4096\n",
    "\n",
    "# seed torch's RNG so the shuffle order and the weight initialisation are repeatable\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# define some data transformations\n",
    "# Normalize runs while the image still has its single grayscale channel, so the\n",
    "# one-element mean/std match the channel count; afterwards the channel is copied\n",
    "# three times to get the 3-channel input the ResNet expects\n",
    "transformations= transforms.Compose([transforms.ToTensor(),\n",
    "                                     transforms.Normalize((0.1307,), (0.3081,)),\n",
    "                                     transforms.Lambda(lambda x: x.repeat(3, 1, 1))])\n",
    "\n",
    "\n",
    "# define train and test dataset\n",
    "mnist_train = datasets.MNIST(\".\",train=True,transform=transformations,download=True)\n",
    "mnist_test = datasets.MNIST(\".\",train=False,transform=transformations,download=True)\n",
    "\n",
    "# define dataloaders\n",
    "# the training data is shuffled so every epoch sees the samples in a new order\n",
    "train_dataloader = DataLoader(mnist_train, batch_size=batchsize, num_workers=7, shuffle=True)\n",
    "test_dataloader = DataLoader(mnist_test, batch_size=batchsize, num_workers=7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Show some data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create unnormalizer: inverts (x - 0.1307) / 0.3081, with one entry per image channel\n",
    "un_normalizer = transforms.Normalize((-0.1307/0.3081,)*3,(1/0.3081,)*3)\n",
    "\n",
    "# create plot grid\n",
    "fig, axes = plt.subplots(ncols=3,nrows=3,figsize=(10,10))\n",
    "\n",
    "# plot title for figure\n",
    "fig.suptitle('These are 9 examples of the MNIST Dataset', fontsize=16)\n",
    "\n",
    "for idx,ax in enumerate(axes.flatten()):\n",
    "    # fetch the sample once, every access runs the transformations again\n",
    "    image_tensor, label = mnist_train[idx]\n",
    "    # transform data back to original value range and shape for displaying\n",
    "    image = (np.array(un_normalizer(image_tensor).permute(1,2,0))*255).astype(int)\n",
    "    # show image\n",
    "    ax.imshow(image)\n",
    "    # plot title for every image\n",
    "    ax.set_title(\"Label: \" + str(label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Get a Resnet-34 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet34(num_classes=10).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Get a Loss Function / Optimizer."
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 6: Define some functions for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to calculate accuracy\n",
    "def accuracy(outputs, targets):\n",
    "    \"\"\"Return the share of samples whose highest-scoring class equals the target.\"\"\"\n",
    "    n = targets.shape[0]\n",
    "    predictions = outputs.argmax(dim=-1).view(n,-1)\n",
    "    return (predictions==targets.view(n,-1)).float().mean()\n",
    "\n",
    "# function to test the model on unseen data\n",
    "def test(x,y):\n",
    "    \"\"\"Evaluate one batch without changing the model; returns (loss, accuracy) as floats.\"\"\"\n",
    "    # eval mode: layers such as BatchNorm use their stored statistics and do not update them\n",
    "    model.eval()\n",
    "    # evaluation needs no gradients, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        y_hat = model(x)\n",
    "        loss = loss_function(y_hat, y)\n",
    "        acc = accuracy(y_hat, y)\n",
    "    return loss.item(), acc.item()\n",
    "\n",
    "# function to update the model weights\n",
    "def update(x,y):\n",
    "    \"\"\"Run one optimisation step on a batch; returns (loss, accuracy) as floats.\"\"\"\n",
    "    # switch back to training mode, test() leaves the model in eval mode\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    y_hat = model(x)\n",
    "    loss = loss_function(y_hat, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    acc = accuracy(y_hat, y)\n",
    "    return loss.item(), acc.item()\n",
    "\n",
    "# function to print one progress line\n",
    "def log_batch(split, idx, n_batches, loss, acc, end):\n",
    "    \"\"\"Print loss/accuracy of batch `idx` (0-based); `split` is \"Train\" or \"Test\".\"\"\"\n",
    "    print(\"Batch\",f'{idx+1:04d}',\"von\",f'{n_batches:04d}',\":\",\n",
    "          split+\" Loss\",f'{loss:.4f}',\n",
    "          split+\" Accuray\",f'{acc:.4f}', end=end)\n",
    "\n",
    "# function to run one full pass over a dataloader\n",
    "def run_epoch(dataloader, step, split):\n",
    "    \"\"\"Call `step` (update or test) on every batch and return the mean loss and accuracy.\n",
    "\n",
    "    The means are weighted by batch size, so a smaller final batch counts\n",
    "    according to the number of samples it holds.\n",
    "    \"\"\"\n",
    "    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0\n",
    "    n_batches = len(dataloader)\n",
    "    for idx,(image,label) in enumerate(dataloader):\n",
    "        loss, acc = step(image.cuda(),label.cuda())\n",
    "        n = label.shape[0]\n",
    "        loss_sum += loss*n\n",
    "        acc_sum += acc*n\n",
    "        n_samples += n\n",
    "        # the carriage return lets the next batch overwrite this line\n",
    "        log_batch(split, idx, n_batches, loss, acc, end=\"\\r\")\n",
    "    if n_samples == 0:\n",
    "        raise ValueError(\"dataloader yielded no batches\")\n",
    "    # print the last batch once more, ended by a newline so it stays visible\n",
    "    log_batch(split, idx, n_batches, loss, acc, end=\"\\n\")\n",
    "    return loss_sum/n_samples, acc_sum/n_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 7: Fit the model for 10 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_losses=[]\n",
    "test_losses=[]\n",
    "train_accuracys=[]\n",
    "test_accuracys=[]\n",
    "epochs=10\n",
    "\n",
    "# iterate over epochs\n",
    "for ep in range(epochs):\n",
    "\n",
    "    # iterate over training data and update the model weights\n",
    "    loss_train, accuracy_train = run_epoch(train_dataloader, update, \"Train\")\n",
    "\n",
    "    # iterate over test data and evaluate\n",
    "    loss_test, accuracy_test = run_epoch(test_dataloader, test, \"Test\")\n",
    "\n",
    "    # print results of epoch (means over all batches, not just the last one)\n",
    "    print(\"Epoch\",f'{ep+1:04d}',\"von\",f'{epochs:04d}',\":\",\n",
    "          \"Train Loss\",f'{loss_train:.4f}',\n",
    "          \"Test Loss\",f'{loss_test:.4f}',\n",
    "          \"Train Accuracy\",f'{accuracy_train:.4f}',\n",
    "          \"Test Accuracy\", f'{accuracy_test:.4f}')\n",
    "    print('-'*100)\n",
    "\n",
    "    train_losses.append(loss_train)\n",
    "    test_losses.append(loss_test)\n",
    "    train_accuracys.append(accuracy_train)\n",
    "    test_accuracys.append(accuracy_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 8: Plot losses and accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# losses on the left y-axis (blue), accuracies on a second y-axis on the right (red)\n",
    "fig, ax1 = plt.subplots(figsize=(10,6))\n",
    "color = 'blue'\n",
    "ax1.set_xlabel('epochs')\n",
    "ax1.set_ylabel('loss', color=color)\n",
    "ax1.tick_params(axis='y', labelcolor=color)\n",
    "ax1.plot(train_losses,color=color, ls=\"-\", label=\"train loss\")\n",
    "ax1.plot(test_losses,color=color, ls=\"--\", label=\"test loss\")\n",
    "ax1.legend(loc='upper left')\n",
    "\n",
    "# second axis sharing the same x-axis\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "color = 'red'\n",
    "ax2.set_ylabel('accuracy', color=color)\n",
    "ax2.tick_params(axis='y', labelcolor=color)\n",
    "ax2.plot(train_accuracys,color=color, ls=\"-\", label=\"train accuray\")\n",
    "ax2.plot(test_accuracys,color=color, ls=\"--\", label=\"test accuracy\")\n",
    "ax2.legend(loc='upper right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 9: Predict some Examples in Test Set"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 9.1 Predict Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get 9 random images and labels from test dataset\n",
    "# (loaded in the main process, worker processes are not worth starting for one small batch)\n",
    "images, labels = next(iter(DataLoader(mnist_test, batch_size=9, num_workers=0, shuffle=True)))\n",
    "# set model in evaluation mode\n",
    "model.eval()\n",
    "# disable gradient calculation\n",
    "with torch.no_grad():\n",
    "    predicted_digits = model(images.cuda()).argmax(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 9.2 Show Actual and Predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create plot grid\n",
    "fig, axes = plt.subplots(ncols=3,nrows=3,figsize=(10,10))\n",
    "\n",
    "# plot title for figure\n",
    "fig.suptitle('Prediction of 9 digits', fontsize=16)\n",
    "\n",
    "for idx,ax in enumerate(axes.flatten()):\n",
    "    # transform data back to original value range and shape for displaying\n",
    "    image = (np.array(un_normalizer(images[idx]).permute(1,2,0))*255).astype(int)\n",
    "    # show image\n",
    "    ax.imshow(image)\n",
    "    # plot title for every image\n",
    "    ax.set_title(\"Actual: \" + str(labels[idx].item())+ \" Predicted: \" + str(predicted_digits[idx].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 10: Save model to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only the weights (state dict) are stored, not the model class itself\n",
    "torch.save(model.state_dict(), \"mnist_model.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 11: Test Model in Production Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 11.1: Load model from disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_loaded = models.resnet34(num_classes=10).cpu()\n",
    "# the weights were saved from a model on the GPU; map_location moves them to the CPU\n",
    "# so loading also works on a machine without CUDA\n",
    "model_loaded.load_state_dict(torch.load(\"mnist_model.pth\", map_location=\"cpu\"))\n",
    "# set model in evaluation mode before predicting (trailing ';' hides the model printout)\n",
    "model_loaded.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 11.2: Predict loaded images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assumes the PNG files already look like MNIST samples\n",
    "# (28x28, light digit on dark background) - confirm before relying on the predictions\n",
    "image_files = [\"0.png\",\"1.png\",\"2.png\",\"4.png\",\"7.png\",\"9.png\"]\n",
    "\n",
    "image_tensors = []\n",
    "for f in image_files:\n",
    "    # the with-block closes the file once the image has been converted\n",
    "    with Image.open(f) as img:\n",
    "        image_tensors.append(transformations(img.convert(\"L\")))\n",
    "\n",
    "images=torch.stack(image_tensors)\n",
    "\n",
    "# disable gradient calculation, this is inference only\n",
    "with torch.no_grad():\n",
    "    predicted_digits = model_loaded(images).argmax(1)\n",
    "\n",
    "# create plot grid\n",
    "fig, axes = plt.subplots(ncols=3,nrows=2,figsize=(10,10))\n",
    "\n",
    "for idx,ax in enumerate(axes.flatten()):\n",
    "    # transform data back to original value range and shape for displaying\n",
    "    image = (np.array(un_normalizer(images[idx]).permute(1,2,0))*255).astype(int)\n",
    "    # show image\n",
    "    ax.imshow(image)\n",
    "    # plot title for every image\n",
    "    ax.set_title(\"Predicted: \" + str(predicted_digits[idx].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "fin"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}