|
| 1 | +######################################################################## |
| 2 | +# |
| 3 | +# Downloads the MNIST data-set for recognizing hand-written digits. |
| 4 | +# |
| 5 | +# Implemented in Python 3.6 |
| 6 | +# |
| 7 | +# Usage: |
| 8 | +# 1) Create a new object instance: data = MNIST(data_dir="data/MNIST/") |
| 9 | +# This automatically downloads the files to the given dir. |
| 10 | +# 2) Use the training-set as data.x_train, data.y_train and data.y_train_cls |
| 11 | +# 3) Get random batches of training data using data.random_batch() |
| 12 | +# 4) Use the test-set as data.x_test, data.y_test and data.y_test_cls |
| 13 | +# |
| 14 | +######################################################################## |
| 15 | +# |
| 16 | +# This file is part of the TensorFlow Tutorials available at: |
| 17 | +# |
| 18 | +# https://github.com/Hvass-Labs/TensorFlow-Tutorials |
| 19 | +# |
| 20 | +# Published under the MIT License. See the file LICENSE for details. |
| 21 | +# |
| 22 | +# Copyright 2016-18 by Magnus Erik Hvass Pedersen |
| 23 | +# |
| 24 | +######################################################################## |
| 25 | + |
| 26 | +import numpy as np |
| 27 | +import gzip |
| 28 | +import os |
| 29 | +from dataset import one_hot_encoded |
| 30 | +from download import download |
| 31 | + |
| 32 | +######################################################################## |
| 33 | + |
# Remote location that hosts the gzip-compressed MNIST data-files.
# Each filename below is appended to this URL when downloading.
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

# Data-files for the training-set (images and class-labels).
filename_x_train = "train-images-idx3-ubyte.gz"
filename_y_train = "train-labels-idx1-ubyte.gz"

# Data-files for the test-set (images and class-labels).
filename_x_test = "t10k-images-idx3-ubyte.gz"
filename_y_test = "t10k-labels-idx1-ubyte.gz"
| 42 | + |
| 43 | +######################################################################## |
| 44 | + |
| 45 | + |
class MNIST:
    """
    The MNIST data-set for recognizing hand-written digits.
    This automatically downloads the data-files if they do
    not already exist in the local data_dir.

    The first num_train images of the training-file are used as the
    training-set and the remaining images as the validation-set.

    Note: Pixel-values are floats between 0.0 and 1.0.
    """

    # The images are 28 pixels in each dimension.
    img_size = 28

    # The images are stored in one-dimensional arrays of this length.
    img_size_flat = img_size * img_size

    # Tuple with height and width of images used to reshape arrays.
    img_shape = (img_size, img_size)

    # Number of colour channels for the images: 1 channel for gray-scale.
    num_channels = 1

    # Tuple with height, width and depth used to reshape arrays.
    # This is used for reshaping in Keras.
    img_shape_full = (img_size, img_size, num_channels)

    # Number of classes, one class for each of 10 digits.
    num_classes = 10

    def __init__(self, data_dir="data/MNIST/"):
        """
        Load the MNIST data-set. Automatically downloads the files
        if they do not already exist locally.

        After construction the object holds, for each of the sub-sets
        train / val / test, the flattened images (x_*), the integer
        class-numbers (y_*_cls) and the one-hot encoded labels (y_*).

        :param data_dir: Base-directory for downloading files.
        """

        # Copy args to self.
        self.data_dir = data_dir

        # Number of images in each sub-set.
        self.num_train = 55000
        self.num_val = 5000
        self.num_test = 10000

        # Download / load the training-set.
        x_train = self._load_images(filename=filename_x_train)
        y_train_cls = self._load_cls(filename=filename_y_train)

        # Split the training-set into train / validation.
        # Pixel-values are converted from ints between 0 and 255
        # to floats between 0.0 and 1.0.
        self.x_train = x_train[0:self.num_train] / 255.0
        self.x_val = x_train[self.num_train:] / 255.0

        # The class-numbers are converted from bytes to ints as that is
        # needed some places in TensorFlow. The builtin int is used
        # because the np.int alias has been removed from recent NumPy
        # versions; it was merely another name for the builtin int.
        self.y_train_cls = y_train_cls[0:self.num_train].astype(int)
        self.y_val_cls = y_train_cls[self.num_train:].astype(int)

        # Download / load the test-set, converted the same way.
        self.x_test = self._load_images(filename=filename_x_test) / 255.0
        self.y_test_cls = self._load_cls(filename=filename_y_test).astype(int)

        # Convert the integer class-numbers into one-hot encoded arrays.
        self.y_train = one_hot_encoded(class_numbers=self.y_train_cls,
                                       num_classes=self.num_classes)
        self.y_val = one_hot_encoded(class_numbers=self.y_val_cls,
                                     num_classes=self.num_classes)
        self.y_test = one_hot_encoded(class_numbers=self.y_test_cls,
                                      num_classes=self.num_classes)

    def _load_data(self, filename, offset):
        """
        Load the data in the given file. Automatically downloads the file
        if it does not already exist in the data_dir.

        :param filename: Name of the data-file.
        :param offset: Start offset in bytes when reading the data-file.
        :return: The data after the offset as a numpy array of uint8.
        """

        # Download the file from the internet if it does not exist locally.
        download(base_url=base_url,
                 filename=filename,
                 download_dir=self.data_dir)

        # Decompress the whole file into memory and interpret every byte
        # after the offset as an unsigned 8-bit integer.
        path = os.path.join(self.data_dir, filename)
        with gzip.open(path, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=offset)

        return data

    def _load_images(self, filename):
        """
        Load image-data from the given file.
        Automatically downloads the file if it does not exist locally.

        :param filename: Name of the data-file.
        :return: Numpy array with shape (num_images, img_size_flat).
        """

        # Read the data as one long array of bytes,
        # skipping the first 16 bytes of the file.
        data = self._load_data(filename=filename, offset=16)

        # Reshape to 2-dim array with shape (num_images, img_size_flat).
        images_flat = data.reshape(-1, self.img_size_flat)

        return images_flat

    def _load_cls(self, filename):
        """
        Load class-numbers from the given file.
        Automatically downloads the file if it does not exist locally.

        :param filename: Name of the data-file.
        :return: 1-dim numpy array of uint8 with the class-numbers.
        """

        # The class-numbers start after the first 8 bytes of the file.
        return self._load_data(filename=filename, offset=8)

    def random_batch(self, batch_size=32):
        """
        Create a random batch of training-data.

        The images are sampled with replacement, so the same image
        may occur more than once in a batch.

        :param batch_size: Number of images in the batch.
        :return: 3 numpy arrays (x, y, y_cls)
        """

        # Create a random index into the training-set.
        idx = np.random.randint(low=0, high=self.num_train, size=batch_size)

        # Use the index to lookup random training-data.
        x_batch = self.x_train[idx]
        y_batch = self.y_train[idx]
        y_batch_cls = self.y_train_cls[idx]

        return x_batch, y_batch, y_batch_cls
| 184 | + |
| 185 | + |
| 186 | +######################################################################## |
0 commit comments