
Wait... 24GB GPU memory is not enough? How to accumulate gradients in PyTorch

I was training the NASNet-A-Large network on 4-channel 512 by 512 images using PyTorch. Even with my beast of a GPU, a Titan RTX, I could only use a batch size of 8. Training is very volatile at that batch size, and I believe one way to combat this is to accumulate gradients over a few batches and then do a bigger update. Luckily, with PyTorch, this is very simple.

So, let's say below is your training loop:
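For concreteness, here is a minimal sketch of a typical loop that updates the weights after every batch. The toy linear model, the random data, and names such as `loader` are assumptions for illustration, not the original code.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup standing in for the real model and data:
# a tiny linear classifier on random 4-feature inputs, 16 batches of 8.
torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(16)]

num_updates = 0
for inputs, labels in loader:
    optimizer.zero_grad()              # clear gradients from the previous step
    outputs = model(inputs)            # forward pass on one small batch
    loss = criterion(outputs, labels)
    loss.backward()                    # compute gradients for this batch
    optimizer.step()                   # update the weights after every batch
    num_updates += 1
```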

We would only need a small modification to accumulate gradients:
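A minimal sketch of that modification, under the same kind of hypothetical toy setup. `accumulation_steps = 8` matches the 8 batches per update described in the post; dividing each loss by `accumulation_steps` is a common convention (an assumption here) that makes the accumulated gradient an average rather than a sum.

```python
import torch
import torch.nn as nn

# Hypothetical toy model and data: 16 batches of 8 samples each.
torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(16)]

accumulation_steps = 8  # number of small batches per weight update

num_updates = 0
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(loader):
    outputs = model(inputs)
    # Scale the loss so the summed gradients average over the group.
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()  # adds into .grad; this batch's graph is freed right away
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per accumulation_steps batches
        optimizer.zero_grad()  # reset the accumulated gradients
        num_updates += 1
```

If the number of batches is not a multiple of `accumulation_steps`, the gradients from the last partial group are never applied in this sketch; handle that case explicitly if it matters.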

The modified loop accumulates gradients over 8 batches and then does a single update, for an effective batch size of 8 × 8 = 64. Note that the backward pass is still done on each small batch individually; this is crucial.
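A quick sanity check of why this works: with a mean-reduced loss, accumulating the gradients of 8 batches of 8, each loss scaled by 1/8, gives the same gradient as a single batch of 64. The toy model and random data below are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()  # default reduction="mean"

inputs = torch.randn(64, 4)
labels = torch.randint(0, 2, (64,))

# Gradient from one big batch of 64 samples.
model.zero_grad()
criterion(model(inputs), labels).backward()
big_grad = model.weight.grad.clone()

# Gradient accumulated over 8 small batches of 8, each loss scaled by 1/8.
model.zero_grad()
for x, y in zip(inputs.split(8), labels.split(8)):
    (criterion(model(x), y) / 8).backward()
acc_grad = model.weight.grad.clone()

same = torch.allclose(big_grad, acc_grad, atol=1e-6)
```

The equivalence holds only for layers whose output for a sample does not depend on the rest of the batch. Networks with BatchNorm (such as NASNet) still compute batch statistics over each small batch of 8, so accumulation smooths the gradient but does not fully reproduce training at batch size 64.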

I originally implemented it as follows:

which is not only uglier but also doesn't work: the backward step runs the GPU out of memory (OOM), because the backward pass covers batch_size * num_batches images at once.
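For contrast, here is a hypothetical sketch of one way such a "single backward" version can look (not necessarily the exact original code). Keeping the loss tensors until one combined `backward()` call keeps every batch's computation graph and saved activations alive, so peak memory grows with `accumulation_steps * batch_size` instead of `batch_size`.

```python
import torch
import torch.nn as nn

# Hypothetical toy model and data, as in a typical loop.
torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(16)]
accumulation_steps = 8

# Memory-hungry pattern: each stored loss keeps its whole graph alive.
optimizer.zero_grad()
pending_losses = []
num_updates = 0
for i, (inputs, labels) in enumerate(loader):
    pending_losses.append(criterion(model(inputs), labels))
    if (i + 1) % accumulation_steps == 0:
        # One backward over all stored graphs: activations for every
        # image in the group must be held in memory until this point.
        torch.stack(pending_losses).mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        pending_losses = []
        num_updates += 1
```

On a toy model this runs fine; on NASNet-A-Large with 512 by 512 inputs, the extra saved activations are what exhaust the 24GB.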

Writing your own loss function/module for PyTorch

Yes, I am switching to PyTorch, and so far I am very happy with it.

Recently, I have been working on a multilabel classification problem where the evaluation metric is the macro F1 score. So, ideally, we would want the loss function to be aligned with our evaluation metric, instead of using the standard BCE.

Initially, I was using the following function:
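One common way to write such a function is a "soft" macro F1: predicted probabilities replace hard 0/1 predictions so that true positives, false positives, and false negatives become differentiable sums. A minimal sketch, assuming `y_true` and `y_pred` are `(batch, num_labels)` tensors and `y_pred` already holds probabilities; the name `f1_loss` and its signature are assumptions for illustration.

```python
import torch


def f1_loss(y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Soft macro F1 loss (hypothetical sketch).

    y_true: (batch, num_labels) tensor of 0/1 targets.
    y_pred: (batch, num_labels) tensor of probabilities in [0, 1].
    Returns 1 minus the per-label soft F1 averaged over labels.
    """
    y_true = y_true.float()
    # Soft counts per label, summed over the batch dimension.
    tp = (y_true * y_pred).sum(dim=0)
    fp = ((1 - y_true) * y_pred).sum(dim=0)
    fn = (y_true * (1 - y_pred)).sum(dim=0)
    # 2TP / (2TP + FP + FN) equals 2PR / (P + R); eps avoids division by zero.
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return 1 - f1.mean()  # macro: average F1 over labels
```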

It works perfectly well as a loss function in your typical training code:
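A minimal sketch of that usage: the model outputs logits, `torch.sigmoid` turns them into per-label probabilities, and the soft F1 loss is backpropagated like any other loss. The compact `f1_loss`, the toy linear model, and the random multilabel targets are all hypothetical stand-ins.

```python
import torch
import torch.nn as nn


def f1_loss(y_true, y_pred, eps=1e-7):
    # Compact soft macro F1 loss (hypothetical stand-in).
    tp = (y_true * y_pred).sum(dim=0)
    fp = ((1 - y_true) * y_pred).sum(dim=0)
    fn = (y_true * (1 - y_pred)).sum(dim=0)
    return 1 - (2 * tp / (2 * tp + fp + fn + eps)).mean()


torch.manual_seed(0)
num_labels = 3
model = nn.Linear(10, num_labels)  # toy multilabel classifier
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
inputs = torch.randn(32, 10)
targets = (torch.rand(32, num_labels) > 0.5).float()

losses = []
for step in range(50):
    optimizer.zero_grad()
    probs = torch.sigmoid(model(inputs))  # logits -> per-label probabilities
    loss = f1_loss(targets, probs)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```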

Even better, we can wrap it in a PyTorch module, so that it is used like any other PyTorch loss:

This simply calls the original f1_loss function in the forward pass of a simple module. As a result, I can explicitly move the module to the GPU.
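A minimal sketch of such a wrapper, assuming the soft macro F1 formulation; the class name `F1Loss` and the compact `f1_loss` are hypothetical. `forward` takes `(input, target)` in that order to match the convention of built-in losses such as `nn.BCELoss`.

```python
import torch
import torch.nn as nn


def f1_loss(y_true, y_pred, eps=1e-7):
    # Compact soft macro F1 loss (hypothetical stand-in).
    tp = (y_true * y_pred).sum(dim=0)
    fp = ((1 - y_true) * y_pred).sum(dim=0)
    fn = (y_true * (1 - y_pred)).sum(dim=0)
    return 1 - (2 * tp / (2 * tp + fp + fn + eps)).mean()


class F1Loss(nn.Module):
    """Hypothetical module wrapper so f1_loss is used like a built-in loss."""

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Same argument order as nn.BCELoss: (input, target).
        return f1_loss(y_true, y_pred, self.eps)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = F1Loss().to(device)

probs = torch.sigmoid(torch.randn(4, 3, device=device))
targets = (torch.rand(4, 3, device=device) > 0.5).float()
loss = criterion(probs, targets)
```

Calling `.to(device)` works as for any module; since `F1Loss` holds no parameters or buffers, the computation runs on whichever device the input tensors live on.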