Hello People! In this article, we’ll build and train a neural network to classify images of clothing, like sneakers and shirts. This implementation uses tf.keras, a high-level API to build and train models in TensorFlow.

Importing dependencies

import tensorflow as tf
# Import TensorFlow Datasets
import tensorflow_datasets as tfds
# Helper libraries
import math
import numpy as np
import matplotlib.pyplot as plt
import logging

Importing the Fashion MNIST dataset

I am going to use the Fashion MNIST dataset, which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 × 28 pixels).

Fashion MNIST samples

I’ll use 60,000 images to train the network and 10,000 images to evaluate how accurately the network learned to classify images.

dataset, metadata = tfds.load('fashion_mnist', as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

Loading the dataset returns the metadata as well as the training and test datasets.

  • The model is trained using train_dataset.
  • The model is tested against test_dataset.

The images are 28 × 28 arrays, with pixel values in the range [0, 255]. The labels are an array of integers, in the range [0, 9]. These correspond to the class of clothing the image represents:

Label  Class
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot

Each image is mapped to a single label. Since the class names are not included with the dataset, store them here to use later when plotting the images.

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
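Because each label is simply an index into class_names, decoding a batch of labels is a list lookup. A small sketch with made-up label values (not taken from the real dataset):

```python
import numpy as np

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Hypothetical batch of integer labels, as the dataset would yield them
labels = np.array([9, 0, 0, 3, 7])

# Labels must lie in [0, 9] to be valid indices into class_names
assert labels.min() >= 0 and labels.max() <= 9

names = [class_names[label] for label in labels]
print(names)  # ['Ankle boot', 'T-shirt/top', 'T-shirt/top', 'Dress', 'Sneaker']
```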

Exploring the data

Let’s explore the format of the dataset before training the model. The following shows there are 60,000 images in the training set and 10,000 images in the test set:

num_train_examples = metadata.splits['train'].num_examples
num_test_examples = metadata.splits['test'].num_examples
print("Number of training examples: {}".format(num_train_examples))
print("Number of test examples:     {}".format(num_test_examples))

Preprocessing the data

The value of each pixel in the image data is an integer in the range [0,255]. For the model to work properly, these values need to be normalized to the range [0,1]. So here we create a normalization function, and then apply it to each image in the test and train datasets.

def normalize(images, labels):
    images = tf.cast(images, tf.float32)
    images /= 255
    return images, labels

# The map function applies the normalize function to each element in the train
# and test datasets
train_dataset =  train_dataset.map(normalize)
test_dataset  =  test_dataset.map(normalize)

# The first time you use the dataset, the images will be loaded from disk
# Caching will keep them in memory, making training faster
train_dataset =  train_dataset.cache()
test_dataset  =  test_dataset.cache()
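A NumPy analogue of normalize() shows the effect on a single image. This is a sketch on a synthetic array; normalize_np is a hypothetical helper that mirrors the TensorFlow function above (cast to float32 first, then divide by 255).

```python
import numpy as np

def normalize_np(image):
    """Hypothetical NumPy analogue of normalize(): uint8 [0, 255] -> float32 [0, 1]."""
    return image.astype(np.float32) / 255.0

# Synthetic 28x28x1 image: mostly 0, with one pixel at 255 and one at 128
image = np.zeros((28, 28, 1), dtype=np.uint8)
image[0, 0, 0] = 255
image[14, 14, 0] = 128

scaled = normalize_np(image)
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
print(round(float(scaled[14, 14, 0]), 3))        # 0.502
```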

Building the model

Building the neural network requires configuring the layers of the model, then compiling the model.

Setting up the layers

This network has three layers:

  • Input layer tf.keras.layers.Flatten: This layer transforms the images from a 2D array of 28 × 28 pixels to a 1D array of 784 pixels (28 × 28). Think of this layer as unstacking rows of pixels in the image and lining them up. This layer has no parameters to learn, as it only reformats the data.
  • “Hidden” layer tf.keras.layers.Dense: A densely connected layer of 128 neurons with a ReLU activation. Each neuron (or node) takes input from all 784 nodes in the previous layer, weights that input according to hidden parameters that are learned during training, and outputs a single value to the next layer.
  • Output layer tf.keras.layers.Dense: A 10-node softmax layer. Each node represents a class of clothing. Like the previous layer, the final layer takes input from all 128 nodes in the layer before it, and each node outputs a value in the range [0, 1] representing the probability that the image belongs to that class. The sum of all 10 node values is 1.
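To make the three layer descriptions concrete, here is a NumPy sketch of the forward pass they compute. The weights are random and untrained (hypothetical values with an arbitrary initialization scale, not what Keras would learn); the point is the shapes, the ReLU and softmax steps, and the parameter count.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical, untrained weights with the same shapes the Keras layers create
W1, b1 = rng.normal(0, 0.05, size=(784, 128)), np.zeros(128)  # hidden layer
W2, b2 = rng.normal(0, 0.05, size=(128, 10)), np.zeros(10)    # output layer

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def forward(images):
    x = images.reshape(len(images), -1)  # Flatten: (N, 28, 28, 1) -> (N, 784), row by row
    h = np.maximum(0, x @ W1 + b1)       # Dense(128) followed by ReLU
    return softmax(h @ W2 + b2)          # Dense(10) followed by softmax

batch = rng.random((4, 28, 28, 1))       # 4 random "images" already scaled to [0, 1]
probs = forward(batch)
print(probs.shape)                          # (4, 10)
print(np.allclose(probs.sum(axis=1), 1.0))  # True

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                             # 101770
```

The 101,770 parameters (784 × 128 + 128 for the hidden layer, 128 × 10 + 10 for the output layer) should match the total that model.summary() reports for the Keras model.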
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

Compiling the model

Before the model is ready for training, it needs a few more settings. These are added during the model’s compile step:

  • Loss function: An algorithm for measuring how far the model’s outputs are from the desired output. The goal of training is to minimize this loss.
  • Optimizer: An algorithm for adjusting the inner parameters of the model in order to minimize the loss.
  • Metrics: Used to monitor the training and testing steps. Here we use accuracy, the fraction of the images that are correctly classified.
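Expressed as code, the compile step could look like the minimal sketch below. It assumes the Adam optimizer and sparse categorical cross-entropy, a common pairing for integer labels and a softmax output; other optimizers would also work. So that the snippet runs on its own, it rebuilds the same three layers, using tf.keras.Input for the input shape, which newer Keras versions prefer over Flatten(input_shape=...).

```python
import numpy as np
import tensorflow as tf

# Same architecture as above, rebuilt so this snippet is self-contained
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])

# Assumed settings: Adam + sparse categorical cross-entropy (labels are integers 0-9)
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Smoke test on two random images: one 10-way probability vector per image
dummy = np.random.rand(2, 28, 28, 1).astype(np.float32)
print(model.predict(dummy, verbose=0).shape)  # (2, 10)
```

In the article’s flow, only the model.compile(...) call is needed, applied to the model defined earlier; without it, model.fit cannot run.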

Training the model

First, we define the iteration behavior for the train dataset:

  1. Repeat forever by specifying dataset.repeat() (the epochs parameter described below limits how long we perform training).
  2. The dataset.shuffle(60000) randomizes the order so our model cannot learn anything from the order of the examples.
  3. And dataset.batch(32) tells model.fit to use batches of 32 images and labels when updating the model variables.
BATCH_SIZE = 32
train_dataset = train_dataset.cache().repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.cache().batch(BATCH_SIZE)
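To see what repeat(), shuffle() and batch() do to batch sizes, here is a sketch on a small synthetic dataset (100 random images and labels, not Fashion MNIST). The finite pipeline behaves like test_dataset, where the last batch is smaller; the repeated one behaves like train_dataset, which never runs out, so every batch is full and only steps_per_epoch ends an epoch.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 32
num_examples = 100  # toy size, not the real 60,000

# Synthetic stand-in data: random uint8 images and integer labels
images = np.random.randint(0, 256, size=(num_examples, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(num_examples,), dtype=np.int64)
toy = tf.data.Dataset.from_tensor_slices((images, labels))

# Finite pipeline: 100 = 3 * 32 + 4, so the final batch holds 4 examples
finite = toy.batch(BATCH_SIZE)
finite_sizes = [int(x.shape[0]) for x, _ in finite]
print(finite_sizes)  # [32, 32, 32, 4]

# Infinite pipeline: repeat() before batch() keeps every batch full
infinite = toy.repeat().shuffle(num_examples).batch(BATCH_SIZE)
infinite_sizes = [int(x.shape[0]) for x, _ in infinite.take(5)]
print(infinite_sizes)  # [32, 32, 32, 32, 32]
```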

Training is performed by calling the model.fit method:

  1. Feed the training data to the model using train_dataset.
  2. The model learns to associate images and labels.
model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))

The epochs=5 parameter limits training to 5 full iterations of the training dataset, so a total of 5 * 60000 = 300000 examples.
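The step counts passed to model.fit and model.evaluate follow from simple arithmetic. A quick check in plain Python, using the split sizes and the batch size of 32 from above:

```python
import math

num_train_examples = 60000  # Fashion MNIST training split
num_test_examples = 10000   # Fashion MNIST test split
BATCH_SIZE = 32
EPOCHS = 5

steps_per_epoch = math.ceil(num_train_examples / BATCH_SIZE)
test_steps = math.ceil(num_test_examples / BATCH_SIZE)
examples_seen = EPOCHS * steps_per_epoch * BATCH_SIZE
last_test_batch = num_test_examples - (test_steps - 1) * BATCH_SIZE

print(steps_per_epoch)  # 1875
print(test_steps)       # 313
print(examples_seen)    # 300000
print(last_test_batch)  # 16
```

Because 60,000 divides evenly by 32, each training epoch consumes exactly 60,000 examples; 10,000 does not, so the final test batch holds only 16 images.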

As the model trains, the loss and accuracy metrics are displayed. This model reaches an accuracy of about 0.88 (or 88%) on the training data.
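Both displayed numbers can be computed by hand from the softmax outputs. The NumPy sketch below assumes the loss is sparse categorical cross-entropy, a common choice for integer labels with a softmax output, and uses made-up probabilities for three images.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean negative log-probability assigned to each image's true class."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(true_class_probs)))

def accuracy(labels, probs):
    """Fraction of images whose highest-probability class equals the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))

# Hypothetical softmax outputs: each row is 0.02 everywhere except one 0.82 (sums to 1)
probs = np.full((3, 10), 0.02)
probs[0, 9] = 0.82  # predicts Ankle boot, label 9: correct
probs[1, 2] = 0.82  # predicts Pullover, but label is 6 (Shirt): wrong
probs[2, 1] = 0.82  # predicts Trouser, label 1: correct
labels = np.array([9, 6, 1])

print(accuracy(labels, probs))                         # 2 of 3 correct
print(sparse_categorical_crossentropy(labels, probs))  # dominated by the -log(0.02) term
```

The single wrong prediction contributes -log(0.02) ≈ 3.91 to the sum, far more than the -log(0.82) ≈ 0.20 of each correct one, which is why loss can keep falling even when accuracy barely moves.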

Evaluating accuracy

Next, we compare how the model performs on the test dataset, using all of its examples to assess accuracy.

test_loss, test_accuracy = model.evaluate(test_dataset, steps=math.ceil(num_test_examples/BATCH_SIZE))
print('Accuracy on test dataset:', test_accuracy)

As it turns out, the accuracy on the test dataset is lower than the accuracy on the training dataset. This is completely normal, since the model was trained on train_dataset. When the model sees images it has never seen during training (that is, images from test_dataset), we can expect performance to go down.

Making predictions

After the model is trained, we can use it to make predictions about images.

for test_images, test_labels in test_dataset.take(1):
  test_images = test_images.numpy()
  test_labels = test_labels.numpy()
  predictions = model.predict(test_images)

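def plot_image(i, predictions_array, true_labels, images):
  predictions_array, true_label, img = predictions_array[i], true_labels[i], images[i]
  plt.xticks([])
  plt.yticks([])
  plt.imshow(img[...,0], cmap=plt.cm.binary)

  predicted_label = np.argmax(predictions_array)
  if predicted_label == true_label:
    color = 'blue'   # correct prediction
  else:
    color = 'red'    # wrong prediction
  # Caption: predicted class, its confidence, and (in parentheses) the true class
  plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                       100*np.max(predictions_array),
                                       class_names[true_label]),
             color=color)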

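def plot_value_array(i, predictions_array, true_label):
  predictions_array, true_label = predictions_array[i], true_label[i]
  thisplot = plt.bar(range(10), predictions_array, color="#777777")
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)
  # Highlight the predicted class in red and the true class in blue
  # (when they coincide, the bar ends up blue)
  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')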

num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
  plt.subplot(num_rows, 2*num_cols, 2*i+1)
  plot_image(i, predictions, test_labels, test_images)
  plt.subplot(num_rows, 2*num_cols, 2*i+2)
  plot_value_array(i, predictions, test_labels)
plt.show()
Predicted output
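Each row of predictions is a vector of 10 probabilities, and plot_image turns it into a caption by taking the argmax as the predicted class and the max as its confidence. A minimal sketch of that step on a single hypothetical probability vector (made-up values, not real model output):

```python
import numpy as np

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Hypothetical softmax output for one image: 10 probabilities summing to 1
prediction = np.array([0.01, 0.00, 0.02, 0.01, 0.01, 0.05, 0.02, 0.08, 0.00, 0.80])

predicted_label = int(np.argmax(prediction))  # index of the most probable class
confidence = float(np.max(prediction))        # probability of that class

print(predicted_label, class_names[predicted_label], f"{100 * confidence:.0f}%")
# 9 Ankle boot 80%
```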


In this blog, you learned how to train a simple fully connected neural network on the Fashion MNIST dataset using Keras.

However, this model cannot be used directly in real-world fashion classification tasks unless you preprocess your images in exactly the same manner as Fashion MNIST (segmentation, thresholding, grayscale conversion, resizing, etc.). In most real-world fashion applications, mimicking the Fashion MNIST preprocessing steps will be nearly impossible.
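The mechanical part of that preprocessing (grayscale conversion, resizing, scaling to [0, 1], and matching the dark background) can be sketched with tf.image. This is a minimal sketch under stated assumptions: the helper name to_fashion_mnist_format is hypothetical, the inversion step assumes a light background as in typical product photos, and it does not attempt segmentation or thresholding, which is where most of the real difficulty lies.

```python
import numpy as np
import tensorflow as tf

def to_fashion_mnist_format(rgb_image, invert=True):
    """Hypothetical helper: RGB uint8 photo (H, W, 3) -> (1, 28, 28, 1) float batch."""
    img = tf.convert_to_tensor(rgb_image)
    img = tf.image.rgb_to_grayscale(img)  # (H, W, 1), still uint8
    img = tf.image.resize(img, (28, 28))  # bilinear by default, returns float32
    img = img / 255.0                     # same [0, 1] scaling as normalize()
    if invert:
        # Fashion MNIST shows light garments on a black background;
        # assumed light-background photos are inverted to match.
        img = 1.0 - img
    return tf.expand_dims(img, axis=0)    # add the batch dimension

photo = np.random.randint(0, 256, size=(300, 200, 3), dtype=np.uint8)  # stand-in for a photo
batch = to_fashion_mnist_format(photo)
print(batch.shape)  # (1, 28, 28, 1)
```

The resulting batch has the shape the model expects, but a real photo would still differ from Fashion MNIST in framing and contrast, so predictions on it should be treated with caution.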

Let us know in the comments section about other ways of building a more robust fashion classification system.

Hope you enjoyed my work! Please share your views in the comments section.

Krishna Pal Deora

