Copyright © 2017-2021 ABBYY Production LLC

[1]:
#@title
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Neural network for CIFAR-10


In this tutorial, we will use NeoML to create a neural network that classifies the images of the CIFAR-10 dataset.

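The tutorial includes the following steps:

  • Download the dataset

  • Prepare the dataset

  • Build the network

  • Train the network on the dataset

  • Prepare the network for inference

  • Serialize the network

  • Evaluate the performance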

Download the dataset

Note: This section doesn’t have any NeoML-specific code. It just downloads the dataset from the internet. If you are not running this notebook, you may skip this section.

[2]:
import os

def calc_md5(file_name):
    """Calculates md5 hash of an existing file"""
    import hashlib
    curr_hash = hashlib.md5()
    with open(file_name, 'rb') as file_in:
        chunk = file_in.read(8192)
        while chunk:
            curr_hash.update(chunk)
            chunk = file_in.read(8192)
    return curr_hash.hexdigest()


# Download data
url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
file_name = url[url.rfind('/')+1:]
ARCHIVE_SIZE = 170498071
ARCHIVE_MD5 = 'c58f30108f718f92721af3b95e74349a'

# Download when archive is missing or broken
if (not os.path.isfile(file_name)) \
        or os.path.getsize(file_name) != ARCHIVE_SIZE \
        or calc_md5(file_name) != ARCHIVE_MD5:
    import requests
    with requests.get(url, stream=True) as url_stream:
        url_stream.raise_for_status()
        with open(file_name, 'wb') as file_out:
            for chunk in url_stream.iter_content(chunk_size=8192):
                file_out.write(chunk)

# Unpack data
import tarfile
with tarfile.open(file_name, 'r:gz') as tar:
    tar.extractall()

Prepare the dataset

In this section we load the data from files into numpy arrays and preprocess it. Preprocessing includes:

  • Data type conversion, because NeoML takes 32-bit types for both integer and float data

  • Normalization

  • Image format conversion, because NeoML works with channel-last images

We’ll also lump the 5 training batches of the original dataset together, because when training we’ll use a much smaller batch size.

[3]:
import numpy as np

np.random.seed(666)

def load_batch_file(file_name):
    """Loads data from one of the batch files"""
    import pickle
    with open(file_name, 'rb') as file_in:
        result = pickle.load(file_in, encoding='bytes')
    return result

def transform_data(X):
    """Normalizes and transposes data for NeoML"""
    X = X.astype(np.float32)
    X = (X - 127.5) / 255.
    X = X.reshape((X.shape[0], 3, 32, 32))
    X = X.transpose((0, 2, 3, 1))  # NeoML uses channel-last pack
    return X

# Preparing data
batch_name = 'cifar-10-batches-py/data_batch_{0}'
train_data = [load_batch_file(batch_name.format(i)) for i in range(1, 6)]
X_train = np.concatenate(list(x[b'data'] for x in train_data), axis=0)
X_train = transform_data(X_train)
y_train = np.concatenate(list(x[b'labels'] for x in train_data), axis=0)
y_train = y_train.astype(np.int32)

test_data = load_batch_file('cifar-10-batches-py/test_batch')
X_test = test_data[b'data']
X_test = transform_data(X_test)
y_test = np.array(test_data[b'labels'], dtype=np.int32)
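It helps to recall how the raw CIFAR-10 rows are laid out, because that explains why transform_data reshapes before it transposes. Each row stores 3072 bytes as three 32x32 planes: all red values, then all green, then all blue. The sketch below runs on a small synthetic array instead of the real data. It checks that after the reshape and transpose, pixel (i, j) of channel c comes from flat index c*1024 + i*32 + j. It also checks that the normalization keeps values within [-0.5, 0.5]. `to_channel_last` is a hypothetical helper that mirrors transform_data.

```python
import numpy as np

def to_channel_last(rows):
    """Mirrors transform_data: normalize, then (N, 3072) -> (N, 32, 32, 3)."""
    x = rows.astype(np.float32)
    x = (x - 127.5) / 255.
    x = x.reshape((x.shape[0], 3, 32, 32))  # three channel-first planes
    return x.transpose((0, 2, 3, 1))        # channel-last, as NeoML expects

# Synthetic batch of 2 "images" with byte values 0..255
rng = np.random.default_rng(0)
rows = rng.integers(0, 256, size=(2, 3 * 32 * 32)).astype(np.uint8)
images = to_channel_last(rows)

# Check the layout mapping for one pixel in every channel
n, i, j = 1, 5, 7
for c in range(3):
    expected = (float(rows[n, c * 1024 + i * 32 + j]) - 127.5) / 255.
    assert np.isclose(images[n, i, j, c], expected)

print(images.shape, images.min(), images.max())
```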

Build the network

Choose the device

We need to create a math engine that will perform all calculations and allocate data for the neural network. The math engine is tied to the processing device.

For faster training, this tutorial creates a math engine that runs on a GPU.

Note: If NeoML doesn’t manage to find a compatible GPU, it’ll create a CPU math engine. You may check which math engine was created by looking at its info attribute.

[4]:
import neoml

# If you'd prefer to use a CPU, call neoml.MathEngine.CpuMathEngine() instead
math_engine = neoml.MathEngine.GpuMathEngine(0)
print('Device: ', math_engine.info)
Device:  CUDA: GeForce RTX 2060

Create the network and connect layers

Let’s create a neoml.Dnn.Dnn object that represents a neural network (a directed graph of layers). The network requires a math engine to perform its operations; it must be specified at creation and can’t be changed later.

[5]:
dnn = neoml.Dnn.Dnn(math_engine)

A neoml.Dnn.Source layer feeds the data into the network.

[6]:
data = neoml.Dnn.Source(dnn, 'data')  # source for data

The network in this tutorial consists of several convolutional blocks. Each block contains a dropout layer to randomly zero out some of the inputs, a convolution layer with trainable coefficients, a batch normalization layer, and a ReLU activation capped at 6 (ReLU6).

Each layer gets its own name, so that it can be found if needed, and is connected to the output of the previous layer.

[7]:
class ConvBlock:
    """Block of dropout->conv->batch_norm->relu6"""
    def __init__(self, inputs, filter_count, name):
        self.dropout = neoml.Dnn.Dropout(inputs, rate=0.1, spatial=True,
                                         batchwise=True, name=name+'_dropout')
        self.conv = neoml.Dnn.Conv(self.dropout, filter_count=filter_count,
                                   filter_size=(3, 3), stride_size=(2, 2),
                                   padding_size=(1, 1), name=name+'_conv')
        self.bn = neoml.Dnn.BatchNormalization(self.conv, channel_based=True,
                                               name=name+'_bn')
        self.output = neoml.Dnn.ReLU(self.bn, threshold=6., name=name+'_relu6')


# Add a few convolutional blocks
# First convolutional block takes source layer's data as input
block1 = ConvBlock(data, filter_count=16, name='block1')  # -> (16, 16)
# Next convolutional blocks each take as input the output of the previous block
block2 = ConvBlock(block1.output, filter_count=32, name='block2')  # -> (8, 8)
block3 = ConvBlock(block2.output, filter_count=64, name='block3')  # -> (4, 4)
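The spatial sizes in the comments above come from the usual convolution output-size formula, out = floor((in + 2*padding - filter) / stride) + 1. Dropout, batch normalization, and ReLU do not change the shape, so only the convolution matters. This minimal sketch reproduces the 32 -> 16 -> 8 -> 4 progression for the 3x3, stride-2, padding-1 convolutions. `conv_output_size` is a hypothetical helper, not part of NeoML.

```python
def conv_output_size(size, filter_size=3, stride=2, padding=1):
    """Spatial output size of a convolution along one axis."""
    return (size + 2 * padding - filter_size) // stride + 1

sizes = []
size = 32  # CIFAR-10 images are 32x32
for block in ('block1', 'block2', 'block3'):
    size = conv_output_size(size)
    sizes.append(size)
    print(block, (size, size))
```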

Afterwards we’ll use a fully-connected layer to generate logits (unnormalized log-probabilities) over the classes.

[8]:
# Fully-connected layer flattens its input automatically
n_classes = 10  # the number of classes in CIFAR-10 dataset
fc = neoml.Dnn.FullyConnected(block3.output, n_classes, name='fc')
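Here is what "flattens its input automatically" means in concrete terms. block3 outputs a 4x4 map with 64 channels, so each image reaches the fully-connected layer as 4*4*64 = 1024 features. The numpy sketch below uses random weights in place of trained ones and assumes the layer has a bias (free) term. It shows the flatten-then-multiply computation and the resulting parameter count. It does not model NeoML's internal weight layout.

```python
import numpy as np

batch, height, width, channels, n_classes = 50, 4, 4, 64, 10
rng = np.random.default_rng(0)
features = rng.standard_normal((batch, height, width, channels)).astype(np.float32)

# Flatten everything except the batch axis: (50, 4, 4, 64) -> (50, 1024)
flat = features.reshape(batch, -1)

# Hypothetical weights and bias standing in for the trained layer
weights = rng.standard_normal((flat.shape[1], n_classes)).astype(np.float32)
bias = np.zeros(n_classes, dtype=np.float32)

logits = flat @ weights + bias
n_params = weights.size + bias.size  # 1024 * 10 weights + 10 biases
print(flat.shape, logits.shape, n_params)
```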

To train the network, we also need to define a loss function to be optimized. In NeoML this is done by adding one or several loss layers.

In this tutorial we’ll be optimizing cross-entropy loss.

A loss function needs to compare the network output with the correct labels, so we’ll add another source layer to pass the correct labels in.

Note: if you use multiple loss layers, you can balance them by setting the loss_weight property (neoml.Dnn.Loss.loss_weight) of each layer.

[9]:
# Before loss layer itself we need to create source layer for correct labels
labels = neoml.Dnn.Source(dnn, 'labels')
# Here you can see how to create a layer with multiple inputs
# Softmax will be applied within cross-entropy (no need for explicit softmax layer here)
loss = neoml.Dnn.CrossEntropyLoss((fc, labels), name='loss')
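CrossEntropyLoss applies softmax internally, so the fc layer only needs to produce raw logits. The sketch below is the textbook definition of softmax cross-entropy with integer labels, written in numpy and using the log-sum-exp trick for numerical stability. It illustrates the quantity being minimized. It is not taken from NeoML's source, which may, for example, average over the batch differently.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    # Subtract the row maximum so that exp() cannot overflow
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Take the log-probability of the correct class in every row
    return -log_probs[np.arange(len(labels)), labels].mean()

# All-equal logits give every class probability 1/10, so the loss is ln(10)
uniform = np.zeros((4, 10), dtype=np.float32)
example_labels = np.array([0, 3, 7, 9], dtype=np.int32)
uniform_loss = softmax_cross_entropy(uniform, example_labels)

# Huge logits do not overflow thanks to the max subtraction
confident = np.full((1, 10), -1000.0)
confident[0, 2] = 1000.0
confident_loss = softmax_cross_entropy(confident, np.array([2]))

print(uniform_loss, np.log(10), confident_loss)
```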

NeoML also provides a neoml.Dnn.Accuracy layer to calculate network accuracy. Let’s connect this layer and create an additional neoml.Dnn.Sink layer for extracting its output.

[10]:
# Auxiliary layers needed to get statistics
accuracy = neoml.Dnn.Accuracy((fc, labels), name='accuracy')
# The accuracy layer writes its result to its output
# We need additional sink layer to extract it
accuracy_sink = neoml.Dnn.Sink(accuracy, name='accuracy_sink')

Create a solver

A solver is an object that optimizes the weights using the gradient values; it is required for training the network. In this tutorial we’ll use a neoml.Dnn.AdaptiveGradient solver, which is NeoML's implementation of Adam.

[11]:
lr = 1e-3  # Learning rate

# Create solver
dnn.solver = neoml.Dnn.AdaptiveGradient(math_engine, learning_rate=lr,
                                        l1=0., l2=0.,  # no regularization
                                        max_gradient_norm=1.,  # clip gradients
                                        moment_decay_rate=0.9,
                                        second_moment_decay_rate=0.999)
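For reference, here is roughly what one Adam step does with the settings above. If the gradient's L2 norm exceeds max_gradient_norm, the gradient is rescaled. Moving averages of the gradient and of its square (decay rates 0.9 and 0.999) then produce a per-weight step. The sketch is the standard Adam update in plain numpy. The epsilon value, the bias correction, and the `adam_step` name are assumptions for illustration and may not match NeoML's AdaptiveGradient exactly.

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
              max_grad_norm=1.0, eps=1e-8):
    """One textbook Adam update with gradient-norm clipping (illustrative)."""
    norm = np.linalg.norm(grad)
    if norm > max_grad_norm:
        grad = grad * (max_grad_norm / norm)  # rescale, keep the direction
    m = beta1 * m + (1 - beta1) * grad        # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias correction for step t
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([0.5, -0.2, 1.0])
grad = np.array([3.0, -4.0, 0.0])  # L2 norm is 5, so it is clipped to norm 1
m = np.zeros_like(w)
v = np.zeros_like(w)
w_new, m, v = adam_step(w, grad, m, v, t=1)
# On the first step each weight with a nonzero gradient moves by about lr
print(w_new - w)
```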

Train the network on the dataset

NeoML networks accept data only as neoml.Blob.Blob.

Blobs are 7-dimensional arrays located in device memory. Each dimension has a specific purpose:

  1. BatchLength - temporal axis (used in recurrent layers)

  2. BatchWidth - classic batch

  3. ListSize - list axis, used when objects are related to the same entity, but without ordering (unlike BatchLength)

  4. Height - height of the image

  5. Width - width of the image

  6. Depth - depth of the 3-dimensional image

  7. Channels - channels of the image (also used when object is a 1-dimensional vector)
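The dimension order above is (BatchLength, BatchWidth, ListSize, Height, Width, Depth, Channels). The plain-numpy sketch below builds the 7-D shapes that the make_blob helper further down uses for an image batch and for a label batch. It also checks that reshaping to these shapes only inserts size-1 axes and never reorders values. No NeoML calls are involved, and `blob_shape` is a hypothetical helper.

```python
import numpy as np

# Order of NeoML blob dimensions, as listed above
DIMS = ('BatchLength', 'BatchWidth', 'ListSize',
        'Height', 'Width', 'Depth', 'Channels')

def blob_shape(array):
    """7-D shape for a channel-last image batch or a 1-D label batch."""
    if array.ndim == 4:  # (N, H, W, C)
        n, h, w, c = array.shape
        return (1, n, 1, h, w, 1, c)
    if array.ndim == 1:  # (N,)
        return (1, array.shape[0], 1, 1, 1, 1, 1)
    raise ValueError(f'unsupported shape {array.shape}')

image_batch = np.zeros((50, 32, 32, 3), dtype=np.float32)
label_batch = np.arange(50, dtype=np.int32)

for name, arr in (('images', image_batch), ('labels', label_batch)):
    shape = blob_shape(arr)
    print(name, dict(zip(DIMS, shape)))
    # Only size-1 axes are inserted, so element order is unchanged
    assert np.array_equal(arr.reshape(shape).ravel(), arr.ravel())
```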

We’ll slice the numpy arrays into batches and create blobs from these batches right before feeding them into the network.

[12]:
def make_blob(data, math_engine):
    """Wraps numpy data into a NeoML blob"""
    shape = data.shape
    if len(shape) == 4:  # images
        # Data is a batch of 2-dimensional multi-channel images
        # Wrap it into (BatchWidth, Height, Width, Channels) blob
        blob_shape = (1, shape[0], 1, shape[1], shape[2], 1, shape[3])
        return neoml.Blob.asblob(math_engine, data, blob_shape)
    elif len(shape) == 1:  # dense labels
        # Data contains dense labels (batch of integers)
        # Wrap it into blob of (BatchWidth,) shape
        return neoml.Blob.asblob(math_engine, data,
                                 (1, shape[0], 1, 1, 1, 1, 1))
    else:
        raise ValueError(f'Unsupported data shape: {shape}')


def cifar10_array_iter(X, y, batch_size):
    """Slices numpy arrays into batches"""
    start = 0
    data_size = y.shape[0]
    while start < data_size:
        yield X[start : start+batch_size], y[start : start+batch_size]
        start += batch_size


def cifar10_blob_iter(X, y, batch_size, math_engine):
    """Slices numpy arrays into batches and wraps them in blobs"""
    for X_b, y_b in cifar10_array_iter(X, y, batch_size):
        yield make_blob(X_b, math_engine), make_blob(y_b, math_engine)

To train the network, call dnn.learn with data as its argument.

To run the network without training, call dnn.run with data as its argument.

The input data is a dict where each key is a neoml.Dnn.Source layer name and the corresponding value is the neoml.Blob.Blob that should be passed in to this layer.

[13]:
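import time

def run_net(X, y, batch_size, dnn, is_train):
    """Runs dnn on the data in batches (training it if is_train is True)
    Returns the average loss, the accuracy, and the run time in seconds
    """
    start = time.time()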
    total_loss = 0.
    run_iter = dnn.learn if is_train else dnn.run
    math_engine = dnn.math_engine
    layers = dnn.layers
    loss = layers['loss']
    accuracy = layers['accuracy']
    sink = layers['accuracy_sink']

    accuracy.reset = True  # Reset previous statistics
    # Iterate over batches
    for X_batch, y_batch in cifar10_blob_iter(X, y, batch_size, math_engine):
        # Run the network on the batch data
        run_iter({'data': X_batch, 'labels': y_batch})
        total_loss += loss.last_loss * y_batch.batch_width  # Update epoch loss
        accuracy.reset = False  # Don't reset statistics within one epoch

    avg_loss = total_loss / y.shape[0]
    avg_acc = sink.get_blob().asarray()[0]
    run_time = time.time() - start
    return avg_loss, avg_acc, run_time
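run_net multiplies each batch's loss by the batch size before dividing by the total number of samples. This matters when the dataset size is not a multiple of batch_size. In that case the last batch from cifar10_array_iter is smaller, and a plain mean of per-batch losses would give it too much weight. With batch_size = 50 both CIFAR-10 splits (50,000 and 10,000 images) divide evenly, so here the weighting is a safeguard rather than a correction. The check below uses synthetic per-sample losses as stand-ins and assumes that loss.last_loss is a per-batch mean.

```python
import numpy as np

def batches(n_samples, batch_size):
    """Yield (start, stop) bounds the same way cifar10_array_iter slices data."""
    start = 0
    while start < n_samples:
        yield start, min(start + batch_size, n_samples)
        start += batch_size

rng = np.random.default_rng(0)
per_sample_loss = rng.random(130)  # 130 samples -> batches of 50, 50, 30

weighted_total = 0.
batch_means = []
for start, stop in batches(len(per_sample_loss), 50):
    batch_mean = per_sample_loss[start:stop].mean()  # like loss.last_loss
    batch_means.append(batch_mean)
    weighted_total += batch_mean * (stop - start)    # like * batch_width

weighted_avg = weighted_total / len(per_sample_loss)
print(weighted_avg, per_sample_loss.mean(), np.mean(batch_means))
```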

In this tutorial, we’ll also demonstrate how to store and load progress during training.

Store training progress using the dnn.store_checkpoint method. Resume training from a checkpoint by calling dnn.load_checkpoint.

Important: NeoML checkpoints contain all the information required for training, including the network architecture. This means you can call load_checkpoint on any neoml.Dnn.Dnn object without re-creating the architecture or the solver beforehand. However, each call to dnn.load_checkpoint creates new layer, solver, and blob objects. If any Python variables pointed to objects of the network before loading (like the solver and data variables here), you must re-initialize them with the new objects.
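The stale-reference pitfall is ordinary Python object semantics. The toy classes below are hypothetical and are not NeoML code. They mimic only the relevant behavior: loading replaces the layer objects held by the same network object, so a variable captured earlier keeps pointing at the old layer.

```python
class ToyLayer:
    def __init__(self, name):
        self.name = name

class ToyNet:
    """Hypothetical stand-in for a network whose loading rebuilds its layers."""
    def __init__(self):
        self.layers = {'sink': ToyLayer('sink')}

    def load_checkpoint(self):
        # Loading creates brand-new layer objects
        self.layers = {'sink': ToyLayer('sink')}

toy_net = ToyNet()
toy_sink = toy_net.layers['sink']  # reference taken before loading
toy_net.load_checkpoint()

is_stale = toy_sink is not toy_net.layers['sink']  # old object, no longer in the net
toy_sink = toy_net.layers['sink']                  # re-fetch by name after loading
print(is_stale, toy_sink is toy_net.layers['sink'])
```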

[14]:
import time

# Network params
batch_size = 50

n_epoch = 10
for epoch in range(n_epoch):
    # Train
    avg_loss, acc, run_time = run_net(X_train, y_train, batch_size,
                                      dnn, is_train=True)
    print(f'Train #{epoch}\tLoss: {avg_loss:.4f}\t'
          f'Accuracy: {acc:.4f}\tTime: {run_time:.2f} sec')
    # Test
    avg_loss, acc, run_time = run_net(X_test, y_test, batch_size,
                                      dnn, is_train=False)
    print(f'Test  #{epoch}\tLoss: {avg_loss:.4f}\t'
          f'Accuracy: {acc:.4f}\tTime: {run_time:.2f} sec')
    if epoch == 1:
        # If you want to save training progress you can do it via checkpoints
        # that store dnn weights and other training data (solver stats, etc.)
        print('Creating checkpoint...')
        dnn.store_checkpoint('cifar10_sample.checkpoint')
    if epoch == 5:
        # Resume training from the checkpoint
        print('Loading checkpoint... (this will roll dnn back to epoch #1)')
        dnn.load_checkpoint('cifar10_sample.checkpoint')
        # Be careful! dnn now holds new layer and solver objects,
        # but variables that referenced the old ones still point to the old net!
Train #0        Loss: 1.5371    Accuracy: 0.4499        Time: 6.12 sec
Test  #0        Loss: 1.2951    Accuracy: 0.5341        Time: 0.71 sec
Train #1        Loss: 1.2288    Accuracy: 0.5630        Time: 5.87 sec
Test  #1        Loss: 1.1361    Accuracy: 0.5951        Time: 0.71 sec
Creating checkpoint...
Train #2        Loss: 1.1138    Accuracy: 0.6064        Time: 5.90 sec
Test  #2        Loss: 1.2091    Accuracy: 0.5761        Time: 0.70 sec
Train #3        Loss: 1.0385    Accuracy: 0.6321        Time: 5.91 sec
Test  #3        Loss: 1.0687    Accuracy: 0.6177        Time: 0.71 sec
Train #4        Loss: 0.9907    Accuracy: 0.6520        Time: 5.92 sec
Test  #4        Loss: 1.0566    Accuracy: 0.6293        Time: 0.70 sec
Train #5        Loss: 0.9494    Accuracy: 0.6647        Time: 5.93 sec
Test  #5        Loss: 1.0407    Accuracy: 0.6361        Time: 0.70 sec
Loading checkpoint... (this will roll dnn back to epoch #1)
Train #6        Loss: 1.1142    Accuracy: 0.6035        Time: 5.91 sec
Test  #6        Loss: 1.0858    Accuracy: 0.6211        Time: 0.70 sec
Train #7        Loss: 1.0434    Accuracy: 0.6333        Time: 5.95 sec
Test  #7        Loss: 1.1010    Accuracy: 0.6111        Time: 0.71 sec
Train #8        Loss: 0.9904    Accuracy: 0.6490        Time: 5.89 sec
Test  #8        Loss: 1.0438    Accuracy: 0.6331        Time: 0.71 sec
Train #9        Loss: 0.9491    Accuracy: 0.6673        Time: 5.94 sec
Test  #9        Loss: 1.0083    Accuracy: 0.6476        Time: 0.71 sec

Prepare the network for inference

Before using the trained network for inference (that is, only to process data, with no further changes to the network itself), we must delete the training-only layers: the labels source, every layer that receives the correct labels as input, and the sink attached to the accuracy layer.

[15]:
# Remove training-only layers
dnn.delete_layer('labels')
dnn.delete_layer('loss')
dnn.delete_layer('accuracy')
dnn.delete_layer('accuracy_sink')
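One way to double-check which layers are training-only is to walk the layer graph forward from the labels source, because everything reachable from it depends on the correct labels. In the sketch below, the tutorial's architecture is written by hand as an adjacency dict that maps each layer name to the layers consuming its output. This dict is typed in for illustration and is not read from NeoML.

```python
from collections import deque

# Hand-written consumer graph for the tutorial network (hypothetical)
consumers = {
    'data': ['block1_dropout'],
    'block1_dropout': ['block1_conv'], 'block1_conv': ['block1_bn'],
    'block1_bn': ['block1_relu6'], 'block1_relu6': ['block2_dropout'],
    'block2_dropout': ['block2_conv'], 'block2_conv': ['block2_bn'],
    'block2_bn': ['block2_relu6'], 'block2_relu6': ['block3_dropout'],
    'block3_dropout': ['block3_conv'], 'block3_conv': ['block3_bn'],
    'block3_bn': ['block3_relu6'], 'block3_relu6': ['fc'],
    'fc': ['loss', 'accuracy'],
    'labels': ['loss', 'accuracy'],
    'accuracy': ['accuracy_sink'],
    'loss': [], 'accuracy_sink': [],
}

def reachable_from(start, graph):
    """Breadth-first search: start plus every layer downstream of it."""
    seen, queue = {start}, deque([start])
    while queue:
        for nxt in graph[queue.popleft()]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen

training_only = reachable_from('labels', consumers)
print(sorted(training_only))
```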

We still need a sink layer to extract the logits. If you need normalized probabilities, add a neoml.Dnn.Softmax layer before the sink. Here we’re only interested in the index of the most probable class, so we’ll omit softmax.

[16]:
# Add sink for dnn output
sink = neoml.Dnn.Sink(dnn.layers['fc'], name='sink')
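Skipping softmax is safe when all we need is the most probable class. Softmax divides exp(logit) by the same row sum for every class, and exp is strictly increasing, so the order of classes within a row is preserved. Here is a quick numpy check on random logits, using synthetic values with no trained network involved:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with max subtraction for numerical stability."""
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.standard_normal((1000, 10)) * 5  # 1000 samples, 10 classes
probs = softmax(logits)

same_prediction = np.array_equal(np.argmax(logits, axis=1),
                                 np.argmax(probs, axis=1))
print(same_prediction, np.allclose(probs.sum(axis=1), 1.0))
```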

We’ll also fuse each batch normalization layer with the previous convolution or fully-connected layer, to reduce the number of operations during inference.

[17]:
def fuse_batch_norm(dnn, block_name):
    """Fuses batch_norm into convolution to reduce inference time
    Should be used after training
    """
    bn_name = block_name + '_bn'
    if not dnn.has_layer(bn_name):
        # Batch norm has already been fused
        return
    bn_layer = dnn.layers[bn_name]
    conv_name = block_name + '_conv'
    conv_layer = dnn.layers[conv_name]
    # Fuse batch normalization
    conv_layer.apply_batch_normalization(bn_layer)
    # Delete layer from net (conv already 'contains' it)
    dnn.delete_layer(bn_name)
    # Connect layer after batchnorm to convolution
    # because batchnorm was removed from the dnn
    output_name = block_name + '_relu6'
    dnn.layers[output_name].connect(conv_layer)


# Fuse batchnorms into convolutions
fuse_batch_norm(dnn, 'block1')
fuse_batch_norm(dnn, 'block2')
fuse_batch_norm(dnn, 'block3')
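The idea behind apply_batch_normalization is simple. At inference time, batch normalization is a fixed per-channel affine map, so it can be folded into the preceding layer's weights and bias. The numpy sketch below shows the standard folding on a 1x1 convolution written as a matrix product; a 3x3 convolution folds the same way, one output channel at a time. The epsilon and the parameter names gamma, beta, mean, and var are assumptions for illustration, and NeoML's internal parameterization may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
in_ch, out_ch, n = 8, 4, 16
eps = 1e-5

# Convolution as a matrix product (1x1 kernel), plus bias
W = rng.standard_normal((in_ch, out_ch))
b = rng.standard_normal(out_ch)
# Batch norm statistics and trained scale/shift, one per output channel
mean = rng.standard_normal(out_ch)
var = rng.random(out_ch) + 0.5
gamma = rng.standard_normal(out_ch)
beta = rng.standard_normal(out_ch)

x = rng.standard_normal((n, in_ch))

# Unfused: convolution followed by inference-mode batch norm
conv_out = x @ W + b
bn_out = gamma * (conv_out - mean) / np.sqrt(var + eps) + beta

# Fused: scale each output channel's weights and adjust its bias
scale = gamma / np.sqrt(var + eps)
W_fused = W * scale                 # broadcasts over the output-channel axis
b_fused = (b - mean) * scale + beta
fused_out = x @ W_fused + b_fused

print(np.allclose(bn_out, fused_out))
```

Because the fused layer computes the same function, predictions should match the unfused network up to floating-point rounding.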

Serialize the network

Use dnn.store to save the trained network. This method stores all the info required for inference. Later you may load the saved network by calling dnn.load on any neoml.Dnn.Dnn object.

[18]:
# Store trained network
# store/load methods work best for this because, unlike checkpoints,
# they don't save training-related data and take up less disk space
dnn.store('cifar10_sample.dnn')
[19]:
# Load the trained network
dnn.load('cifar10_sample.dnn')

Once again, load creates new layer objects. The sink variable created before loading still points to the old sink layer object, which doesn’t belong to the new network. Let’s fix that:

[20]:
sink = dnn.layers['sink']

Evaluate the performance

We can get the blob with the results of the last dnn.run by calling sink.get_blob. Then we’ll convert it into a numpy array via blob.asarray and calculate the accuracy with numpy.

We expect to get the same value for accuracy as during the test of the last epoch.

[21]:
# Evaluate inference
inference_acc = 0.
for X_b, y_b in cifar10_array_iter(X_test, y_test, batch_size):
    dnn.run({'data': make_blob(X_b, math_engine)})
    # Extract data from sink
    # Logits (unnormalized log-probabilities) of shape (batch_size, n_classes)
    logits = sink.get_blob().asarray()
    # Calculate accuracy
    inference_acc += (np.argmax(logits, axis=1) == y_b).sum()
inference_acc /= len(X_test)

# This number is expected to equal the test accuracy of the last epoch
print(f'Inference net test accuracy: {inference_acc:.4f}')
Inference net test accuracy: 0.6476