Copyright © 2017-2023 ABBYY

[1]:
#@title
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Neural network with custom loss function


In this tutorial, we will use NeoML to train a network with a custom loss function. User-defined loss functions must be built from operations supported by the NeoML automatic differentiation module (neoml.AutoDiff).

The tutorial includes the following steps:

  • Build the network

  • Create custom loss

  • Train the network

Build the network

We’ll fix the random seed and use a single-threaded CPU math engine to make the experiments deterministic and reproducible.

[2]:
import neoml
import numpy as np

np.random.seed(666)
math_engine = neoml.MathEngine.CpuMathEngine(1)  # 1 thread; the default (0) uses all CPU cores

The network architecture will be:

  • Source for data

  • FullyConnected with 1024 elements

  • Tanh activation

  • FullyConnected with 1 element (for binary classification)

[3]:
dnn = neoml.Dnn.Dnn(math_engine)
data = neoml.Dnn.Source(dnn, name='data')
fc1 = neoml.Dnn.FullyConnected(data, 1024, name='fc1')
tanh = neoml.Dnn.Tanh(fc1, name='tanh')
fc2 = neoml.Dnn.FullyConnected(tanh, 1, name='fc2')
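To make the shapes concrete, here is a rough NumPy equivalent of the forward pass these layers compute for a batch of 32 objects with 128 features each (the sizes used later in this tutorial). It is only a sketch: the function name forward and the random weight initialization are hypothetical and do not reproduce NeoML's initializer, and each fully connected layer is assumed to include a bias (free) term.

```python
import numpy as np

# Hypothetical NumPy re-creation of the fc1 -> tanh -> fc2 forward pass.
# The weight initialization below is arbitrary, NOT NeoML's initializer.
rng = np.random.default_rng(0)

batch_size, channels, hidden = 32, 128, 1024

def forward(x, w1, b1, w2, b2):
    """x: (batch, channels) -> (batch, 1) raw scores."""
    h = np.tanh(x @ w1 + b1)  # fc1 with 1024 outputs, followed by tanh
    return h @ w2 + b2        # fc2 with 1 output: one score per object

w1 = rng.normal(0.0, 0.01, (channels, hidden))
b1 = np.zeros(hidden)
w2 = rng.normal(0.0, 0.01, (hidden, 1))
b2 = np.zeros(1)

x = rng.normal(0.0, 1.0, (batch_size, channels))
scores = forward(x, w1, b1, w2, b2)
print(scores.shape)  # (32, 1)
```

With hinge loss, the sign of each score is typically read as the predicted class (+1 or -1).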

Create custom loss

To use a custom loss function in NeoML, implement a class derived from neoml.Dnn.CustomLossCalculatorBase that calculates the function. The class must implement the abstract calc(self, data, labels) method. Its input parameters, data and labels, are float blobs of size (batchLength, batchWidth, listSize, height, width, depth, channels) that contain the network response and the correct labels, respectively.

The calc method must return a blob of size (batchLength, batchWidth, listSize, 1, 1, 1, 1) with the loss function value for each object in the batch. Object weighting, total loss calculation, and gradient calculation are then performed automatically.
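The shape contract above means that calc has to collapse the height, width, depth, and channels dimensions into a single value per object. The sketch below illustrates this in NumPy with a hypothetical per-object squared-error loss over 3 output channels; NumPy stands in for neoml.AutoDiff here, so in a real calc you would use the corresponding neoml.AutoDiff operations, whose exact names and arguments should be checked in the NeoML documentation.

```python
import numpy as np

# Hypothetical per-object squared-error loss for a network with several output
# channels. The 7-D layout follows the text:
# (batchLength, batchWidth, listSize, height, width, depth, channels)
batch_width, n_channels = 4, 3
shape = (1, batch_width, 1, 1, 1, 1, n_channels)

rng = np.random.default_rng(1)
data = rng.normal(size=shape).astype(np.float32)
labels = rng.normal(size=shape).astype(np.float32)

def per_object_squared_error(data, labels):
    # Sum over height, width, depth and channels (axes 3..6), keeping those axes,
    # so the result has shape (batchLength, batchWidth, listSize, 1, 1, 1, 1).
    diff = data - labels
    return np.sum(diff * diff, axis=(3, 4, 5, 6), keepdims=True)

loss = per_object_squared_error(data, labels)
print(loss.shape)  # (1, 4, 1, 1, 1, 1, 1)
```

For a binary classification loss such as hinge loss, channels is 1, so the per-object values already have the required shape and no reduction is needed.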

The following functions may be used in your custom losses:

  • the arithmetic operators +, -, * and / between blobs and floats

  • neoml.AutoDiff.* functions like neoml.AutoDiff.max, neoml.AutoDiff.top_k etc.

  • neoml.AutoDiff.const for creating additional blobs with given values

In this example, we’ll implement hinge loss for binary classification, purely for demonstration (NeoML already provides a HingeLoss layer with this loss function).
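For reference, here is the hinge loss for a prediction y and a correct label t ∈ {-1, +1} (the notation used in the code comments of the next cell), together with the subgradient with respect to y that the automatic differentiation has to produce:

```latex
L(y, t) = \max(0,\; 1 - y t), \qquad t \in \{-1, +1\}

\frac{\partial L}{\partial y} =
\begin{cases}
  -t, & y t < 1 \\
  0,  & y t > 1
\end{cases}
```

The loss is zero once an object is classified correctly with a margin of at least 1 (y t ≥ 1). At y t = 1 the function is not differentiable, and which subgradient value the autodifferentiation uses there is an implementation detail.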

[4]:
class HingeLossCalculator(neoml.Dnn.CustomLossCalculatorBase):
    def calc(self, data, labels):
        # data contains the network outputs (float, [batchLength, batchWidth, listSize, 1, 1, 1, 1])
        # labels contains the correct answers (float, [batchLength, batchWidth, listSize, 1, 1, 1, 1], +/-1)
        # (channels is equal to 1 because this is a binary classification loss)
        # The formula is max(0, 1 - y * t), where t is the correct label and y is the prediction
        return neoml.AutoDiff.max(0., 1. - data * labels)
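NeoML derives the gradient of calc automatically. The NumPy reference below, with hypothetical helper names hinge and hinge_grad, shows what that gradient amounts to for the hinge loss and checks it against central finite differences. Returning 0 at the non-differentiable point y * t == 1 is an assumption of this sketch; NeoML's autodifferentiation may choose a different subgradient there.

```python
import numpy as np

def hinge(data, labels):
    # Same formula as HingeLossCalculator.calc: max(0, 1 - y * t)
    return np.maximum(0.0, 1.0 - data * labels)

def hinge_grad(data, labels):
    # Subgradient w.r.t. data: -t where the margin is violated (y * t < 1), else 0
    return np.where(data * labels < 1.0, -labels, 0.0)

shape = (1, 4, 1, 1, 1, 1, 1)  # batchWidth = 4, one output channel
data = np.array([2.0, 0.3, -0.5, -1.5]).reshape(shape)
labels = np.array([1.0, 1.0, 1.0, -1.0]).reshape(shape)

print(hinge(data, labels).ravel())       # values: 0, 0.7, 1.5, 0
print(hinge_grad(data, labels).ravel())  # values: 0, -1, -1, 0

# Central finite differences agree with the subgradient away from y * t == 1
eps = 1e-6
numeric = (hinge(data + eps, labels) - hinge(data - eps, labels)) / (2 * eps)
print(np.allclose(numeric, hinge_grad(data, labels), atol=1e-4))  # True
```

Only the objects with a violated margin (the second and third) contribute a gradient, which is why the loss drops to zero once every object in the batch is classified with a margin of at least 1.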

Now we’ll create the custom loss layer with neoml.Dnn.CustomLoss(...), passing a HingeLossCalculator() instance as the loss_calculator parameter. The layer has two inputs: the network output (fc2) and an additional source layer that supplies the correct labels.

[5]:
# Additional source for class labels
label = neoml.Dnn.Source(dnn, name='label')
# Custom loss layer with HingeLossCalculator
loss = neoml.Dnn.CustomLoss((fc2, label), name='loss',
                             loss_calculator=HingeLossCalculator())

Train the network

Hinge loss is used for binary classification with class labels +1 and -1.
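If your own dataset stores binary labels as 0 and 1, they need to be mapped to -1 and +1 before being passed to a hinge loss. A minimal sketch, using a hypothetical helper named to_plus_minus_one:

```python
import numpy as np

def to_plus_minus_one(labels01):
    """Map {0, 1} class labels to the {-1, +1} labels expected by hinge loss."""
    labels01 = np.asarray(labels01)
    # 0 -> 2*0 - 1 = -1, 1 -> 2*1 - 1 = +1; float32 matches the blobs used below
    return (2 * labels01 - 1).astype(np.float32)

print(to_plus_minus_one([0, 1, 1, 0]))  # values: -1, 1, 1, -1
```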

Let’s generate some random data for classification and train the network on it.

The data is generated in the following way:

  • each object is a vector of 128 elements

  • +1 class objects are vectors whose elements are drawn from the N(0.25, 1) distribution

  • -1 class objects are vectors whose elements are drawn from the N(-0.25, 1) distribution

[6]:
batch_size = 32
channels = 128

print('Loss values (per iter)')
for _ in range(10):
    data_shape = (1, batch_size, 1, 1, 1, 1, channels)
    # Each class gets half of the batch
    data_ndarr = np.vstack((np.random.normal(0.25, 1., (batch_size // 2, channels)),
                            np.random.normal(-0.25, 1., (batch_size // 2, channels))))
    data_blob = neoml.Blob.asblob(math_engine, data_ndarr.astype(np.float32), data_shape)
    label_shape = (1, batch_size, 1, 1, 1, 1, 1)
    # Each class gets half of the batch
    label_ndarr = np.vstack((np.ones(batch_size//2),
                             -np.ones(batch_size//2)))
    label_blob = neoml.Blob.asblob(math_engine, label_ndarr.astype(np.float32), label_shape)
    # Train the network on the generated data
    dnn.learn({'data': data_blob, 'label': label_blob})
    print(loss.last_loss)
Loss values (per iter)
1.3969335556030273
1.175158977508545
0.37655696272850037
0.37330472469329834
0.15725426375865936
0.10323527455329895
0.044364433735609055
0.010994041338562965
0.0
0.021620113402605057