
Writing your own loss function/module for PyTorch

Yes, I am switching to PyTorch, and so far I am very happy with it.

Recently, I have been working on a multilabel classification problem where the evaluation metric is the macro F1 score. Ideally, then, the loss function should be aligned with the evaluation metric instead of being the standard binary cross-entropy (BCE).
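Macro F1 is computed from thresholded (0/1) predictions, so it provides no useful gradient. One common workaround, assumed here rather than stated in this post, is a "soft" F1 that replaces hard counts with sigmoid probabilities. With $p_{ij}$ the predicted probability and $y_{ij} \in \{0, 1\}$ the target for sample $i$ and label $j$, $N$ samples and $C$ labels:

```latex
\mathrm{TP}_j = \sum_{i=1}^{N} p_{ij}\, y_{ij}, \qquad
\mathrm{FP}_j = \sum_{i=1}^{N} p_{ij}\,(1 - y_{ij}), \qquad
\mathrm{FN}_j = \sum_{i=1}^{N} (1 - p_{ij})\, y_{ij}

\mathrm{F1}_j = \frac{2\,\mathrm{TP}_j}{2\,\mathrm{TP}_j + \mathrm{FP}_j + \mathrm{FN}_j},
\qquad
\mathcal{L} = 1 - \frac{1}{C} \sum_{j=1}^{C} \mathrm{F1}_j
```

The form $2\mathrm{TP}/(2\mathrm{TP}+\mathrm{FP}+\mathrm{FN})$ is the usual harmonic mean of precision and recall rewritten in terms of counts, and averaging over labels before subtracting from 1 is what makes the loss "macro".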

Initially, I was using the following function:
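A minimal sketch of such a function, not necessarily the author's original: the name `f1_loss` comes from this post, while the signature, the use of raw logits as input, and the `eps` parameter are assumptions. It computes the soft per-label counts over the batch and returns 1 minus the macro-averaged soft F1.

```python
import torch


def f1_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Soft macro F1 loss for multilabel classification (sketch).

    logits:  (batch, num_labels) raw model scores
    targets: (batch, num_labels) binary labels (0 or 1)
    Returns 1 - mean over labels of the soft F1 score.
    """
    probs = torch.sigmoid(logits)   # probabilities in (0, 1); stays differentiable
    targets = targets.float()

    # Soft counts per label, summed over the batch dimension
    tp = (probs * targets).sum(dim=0)
    fp = (probs * (1 - targets)).sum(dim=0)
    fn = ((1 - probs) * targets).sum(dim=0)

    # Per-label soft F1; eps guards against 0/0 for labels absent from the batch
    f1 = 2 * tp / (2 * tp + fp + fn + eps)

    # Macro average over labels, turned into a quantity to minimise
    return 1 - f1.mean()
```

Confident, correct logits drive the loss toward 0, and confident, wrong logits drive it toward 1.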

It works perfectly well as a loss function in typical training code:
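A hedged example of such a training loop, using a hypothetical toy dataset (random features with labels from a random linear rule) and a single `nn.Linear` layer. The soft macro F1 loss is defined inline so the snippet runs on its own; the model, data sizes, and learning rate are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)


def f1_loss(logits, targets, eps=1e-8):
    # Soft macro F1 loss: 1 minus the mean per-label soft F1
    probs = torch.sigmoid(logits)
    targets = targets.float()
    tp = (probs * targets).sum(dim=0)
    fp = (probs * (1 - targets)).sum(dim=0)
    fn = ((1 - probs) * targets).sum(dim=0)
    return 1 - (2 * tp / (2 * tp + fp + fn + eps)).mean()


# Hypothetical toy data: 256 samples, 20 features, 5 binary labels
num_features, num_labels = 20, 5
x = torch.randn(256, num_features)
true_weights = torch.randn(num_features, num_labels)
y = (x @ true_weights > 0).float()

model = nn.Linear(num_features, num_labels)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for epoch in range(100):
    optimizer.zero_grad()
    logits = model(x)
    loss = f1_loss(logits, y)   # called just like a built-in loss
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

The only difference from a BCE setup is the line that computes `loss`; `backward()` and the optimizer step are unchanged.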

Even better, we can wrap it in a PyTorch module, so that it is used like a typical PyTorch loss:
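A minimal sketch of such a module, assuming the hypothetical class name `F1Loss` and the same soft macro F1 computation; the loss logic lives entirely in `forward`, and options such as `eps` are set in the constructor, as with built-in losses like `nn.BCEWithLogitsLoss`.

```python
import torch
from torch import nn


class F1Loss(nn.Module):
    """Soft macro F1 loss as an nn.Module (hypothetical name).

    Construct once, then call with (logits, targets) like a built-in loss.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)
        targets = targets.float()
        # Soft per-label counts over the batch
        tp = (probs * targets).sum(dim=0)
        fp = (probs * (1 - targets)).sum(dim=0)
        fn = ((1 - probs) * targets).sum(dim=0)
        f1 = 2 * tp / (2 * tp + fp + fn + self.eps)
        return 1 - f1.mean()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The module has no parameters or buffers, so .to() leaves it unchanged,
# but it keeps the code consistent with how other modules are handled.
criterion = F1Loss().to(device)

logits = torch.randn(8, 5, device=device, requires_grad=True)
targets = torch.randint(0, 2, (8, 5), device=device)
loss = criterion(logits, targets)
loss.backward()
print(loss.item())
```

In a training loop, `criterion(logits, y)` then replaces the direct `f1_loss(logits, y)` call.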

This simply moves the original f1_loss function into the forward pass of a simple module. As a result, I can explicitly put the module on the GPU, just like any other PyTorch module.