# Wait... 24GB GPU memory is not enough? How to accumulate gradients in PyTorch

I was training the NASNet-A-Large network on 4-channel 512×512 images using PyTorch. Even with my beast of a GPU, a Titan RTX, I could only use a batch size of 8. Training is very volatile with a batch size that small, and I believe one way to combat that is to accumulate gradients over a few batches and then do a bigger update. Luckily, with PyTorch, this is very simple.

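So, let's say you have a standard training loop: for every batch, you run the forward pass, compute the loss, call `loss.backward()`, call `optimizer.step()`, and reset the gradients with `optimizer.zero_grad()`.

We would only need a small modification to accumulate gradients: keep calling `loss.backward()` on every batch, but call `optimizer.step()` and `optimizer.zero_grad()` only once every 8 batches. Because PyTorch adds new gradients to the existing `.grad` buffers until they are zeroed, this accumulates gradients for 8 batches before each update. Note that the backward pass is still done on each individual small batch; this is crucial.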
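A minimal sketch of the accumulation loop described above, assuming a `model`, a `loader` that yields `(inputs, targets)` pairs, an `optimizer` and a `loss_fn`. The function name `train_with_accumulation` and the division of the loss by `accum_steps` are my own illustrative choices, not necessarily the original code; moving tensors to the GPU with `.to("cuda")` is left out for brevity.

```python
import torch
from torch import nn


def train_with_accumulation(model, loader, optimizer, loss_fn, accum_steps=8):
    """One epoch in which the weights are updated once every `accum_steps` batches."""
    model.train()
    optimizer.zero_grad()
    pending = 0  # batches whose gradients have not been applied yet
    for inputs, targets in loader:
        outputs = model(inputs)
        # Dividing by accum_steps makes the accumulated gradient the average over
        # the effective batch (assumes a mean-reduced loss and equal-sized batches).
        loss = loss_fn(outputs, targets) / accum_steps
        # backward() runs on this small batch only: its gradients are added to the
        # .grad buffers and its activations can be freed right away.
        loss.backward()
        pending += 1
        if pending == accum_steps:
            optimizer.step()       # a single update for accum_steps batches
            optimizer.zero_grad()  # clear the accumulated gradients
            pending = 0
    if pending:
        # Apply what is left over when the batch count is not a multiple of accum_steps.
        optimizer.step()
        optimizer.zero_grad()
```

Without the division, the accumulated gradient is the sum rather than the mean over the 8 batches, which with plain SGD amounts to an 8× larger learning rate.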

I originally implemented it with a single backward pass over all 8 batches at once, which is not only uglier but also does not work: that backward pass covers batch_size * num_batches images, so the activations of every batch have to be kept in GPU memory until it runs, and the GPU goes OOM.

# Label Embedding in Multi-label classification

In the recent Kaggle competition, the Inclusive Images Challenge, I tried out the label embedding technique for training multilabel classifiers outlined in this paper by François Chollet.

The basic idea is to decompose the pointwise mutual information (PMI) matrix computed from the training labels and use it to guide the training of the neural network. The steps are as follows:

1. Encode the training labels as you would in a multilabel classification setting. Let $M$ (of size $n \times m$, i.e. $n$ training examples with $m$ labels) denote the matrix constructed by vertically stacking the label vectors.
2. The PMI matrix (of size $m \times m$) has entries $\mathrm{PMI}_{i,j}=\log\frac{P(i,j)}{P(i)\,P(j)}$, where $P(i,j)$ is the fraction of training examples that carry both labels $i$ and $j$, and $P(i)$ is the fraction that carry label $i$. Since $P(i,j) = (M^\top M)_{i,j}/n$, the whole matrix can be computed with vectorized operations, which is very efficient even on large datasets. See more explanation of the PMI here.
3. The embedding matrix $E$ (of size $m \times k$) is obtained by computing the singular value decomposition $\mathrm{PMI} = U \Sigma V^\top$ and then taking the matrix product of $U$ and the first $k$ columns of $\sqrt{\Sigma}$, which amounts to scaling the first $k$ columns of $U$ by the square roots of the top $k$ singular values.
4. We can then use the embedding matrix to transform the original sparse label vectors into dense $k$-dimensional vectors.
5. When training the deep learning model, instead of ending with $m$ sigmoid activations and a binary cross-entropy (BCE) loss, we can use $k$ linear activations with a cosine proximity loss.
6. At inference time, we take the model's prediction, compare it with the rows of the embedding matrix $E$, select the most similar vectors, and return their corresponding labels.
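The steps above can be sketched in NumPy as follows. This is a minimal illustration under a few assumptions of mine, not the paper's reference code: pairs of labels that never co-occur (where the PMI would be $\log 0 = -\infty$) are set to 0, and a training example's dense target is taken to be the sum of its labels' embeddings. Since the cosine proximity loss ignores vector length, summing and averaging give the same target direction. The helper names are hypothetical.

```python
import numpy as np


def pmi_matrix(M):
    """m x m PMI matrix from an n x m binary label matrix M (step 2)."""
    M = np.asarray(M, dtype=float)
    n = M.shape[0]
    p_joint = (M.T @ M) / n       # P(i, j): fraction of examples with labels i and j
    p_marg = M.sum(axis=0) / n    # P(i): fraction of examples with label i
    expected = np.outer(p_marg, p_marg)
    pmi = np.zeros_like(p_joint)
    seen = p_joint > 0
    # Assumption: never co-occurring pairs are left at 0 instead of -inf.
    pmi[seen] = np.log(p_joint[seen] / expected[seen])
    return pmi


def label_embedding(pmi, k):
    """E = U times the first k columns of sqrt(Sigma), an m x k matrix (step 3)."""
    U, s, _ = np.linalg.svd(pmi)
    return U[:, :k] * np.sqrt(s[:k])


def encode(M, E):
    """Dense n x k training targets: sum of each example's label embeddings (step 4, assumed)."""
    return np.asarray(M, dtype=float) @ E


def decode(pred, E, top=1):
    """Indices of the `top` labels whose embeddings are most cosine-similar to pred (step 6)."""
    E_unit = E / (np.linalg.norm(E, axis=1, keepdims=True) + 1e-12)
    p_unit = pred / (np.linalg.norm(pred) + 1e-12)
    sims = E_unit @ p_unit
    return np.argsort(-sims)[:top]
```

In a toy label matrix where labels 0 and 1 occur independently of each other, `pmi_matrix` gives $\mathrm{PMI}_{0,1} = \log 1 = 0$, while each diagonal entry is $\log(1/P(i))$.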

Below is a toy example calculation of the label embedding procedure. The two pictures show the pairwise cosine similarity between item labels in the embedding space and a 2D display of the items in the embedding space.

In my own experiments, I found that models trained on label embeddings are a bit more robust to label noise, converge faster, and achieve higher top-k precision than models with logistic outputs.

I believe this is due to the large number of labels in the competition ($m \approx 7000$) compared with the small batches the model is trained on. Because the label embedding is obtained by matrix factorization, it is similar to PCA: we keep the crucial information and discard some unnecessary detail and noise, except that we do so on the labels instead of the inputs.

# Kaggle Digit Recognizer Revisited (Using Convolutional NN with Keras)

Almost a year ago, I revisited the Kaggle version of the handwritten digit recognition problem; the link to that post is here. At that time my go-to language was R, since most of my friends used R as well. This summer, I switched back to Python as my primary language for almost everything, because it is just so efficient.

So, here is a convolutional neural network using Keras to tackle this problem again: in fewer than 100 lines of code you can build a convolutional neural network that obtains 99% accuracy on the Kaggle leaderboard.
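For readers who want something to start from, here is a small Keras CNN of the usual conv/pool/dropout kind. It is an illustrative sketch, not the exact network or hyperparameters from this post, and I have not verified that it reaches the 99% score. It assumes Kaggle's CSV layout (a `label` column followed by 784 pixel columns); the helper names `prepare_images` and `build_model` are hypothetical.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def prepare_images(pixels):
    """Reshape flat 784-pixel rows into 28x28x1 float32 images scaled to [0, 1]."""
    return np.asarray(pixels, dtype="float32").reshape(-1, 28, 28, 1) / 255.0


def build_model(num_classes=10):
    """Two conv blocks followed by a small dense classifier."""
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),   # 28x28 -> 26x26
        layers.Conv2D(32, 3, activation="relu"),   # 26x26 -> 24x24
        layers.MaxPooling2D(2),                    # 24x24 -> 12x12
        layers.Dropout(0.25),
        layers.Conv2D(64, 3, activation="relu"),   # 12x12 -> 10x10
        layers.MaxPooling2D(2),                    # 10x10 -> 5x5
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",  # integer labels 0-9
                  metrics=["accuracy"])
    return model

# Usage, assuming train.csv has been loaded into a pandas DataFrame `train`:
#   X = prepare_images(train.drop(columns="label").values)
#   y = train["label"].values
#   build_model().fit(X, y, batch_size=128, epochs=10, validation_split=0.1)
```

Using `sparse_categorical_crossentropy` lets the integer labels from the CSV be fed in directly, without one-hot encoding them first.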

A quick note about training time: it took close to 9 minutes to train on my laptop with a GeForce GTX 970M. You can increase the number of epochs and run it yourself; that should lead to better results.