Random states in multiprocessing: a lesson learnt after wasting a week's GPU time

I was recently training a CNN on the openimages dataset using Keras. I was using a custom batch generator together with the .fit_generator() method in Keras, and observed extremely slow training progress.

My code looks something like this:

I wasted a lot of time debugging the model structure, loss, and optimizer, but the problem turned out to be much simpler. I eventually found it by printing out the indices being sampled.
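For anyone who wants to run the same check, here is a minimal sketch of that kind of debug print. The helper name describe_batch and the dataset size of 1000 are made up for illustration and are not from my training code. Tagging each line with the process ID makes it easy to spot different workers reporting identical indices.

```python
import os

import numpy as np


def describe_batch(indices, max_shown=5):
    """Format this process's ID together with the first few sampled indices."""
    shown = ", ".join(str(int(i)) for i in indices[:max_shown])
    return f"pid={os.getpid()} indices=[{shown}, ...]"


# Hypothetical sampling step: 32 indices from a dataset of 1000 samples.
indices = np.random.randint(0, 1000, size=32)
print(describe_batch(indices))
```

If lines from different process IDs show the same indices, the workers are most likely sharing a random state.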

The problem with the code is that when the generator gets duplicated across multiple workers, the random state gets copied too, so all 8 workers start from the same random state. As a result, during training, the model sees the exact same batch 8 times before seeing a new one. The fix is easy: call np.random.seed() just before sampling the indices. With no argument it reseeds from fresh OS entropy, so each worker ends up with a different random state.
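Here is a small sketch of the effect and the fix. It is not my actual generator: the function names, the dataset size, and the batch size are all made up, and instead of spawning real worker processes it mimics the copy by restoring the same saved NumPy state before each "worker" draws its batch.

```python
import numpy as np


def sample_batch_indices(num_samples, batch_size, reseed=False):
    """Hypothetical stand-in for the index-sampling step of a batch generator."""
    if reseed:
        # The fix: with no argument, np.random.seed() reseeds from fresh entropy.
        np.random.seed()
    return np.random.randint(0, num_samples, size=batch_size)


def simulate_workers(num_workers, reseed):
    """Mimic duplicated workers by giving each one a copy of the parent's state."""
    parent_state = np.random.get_state()
    batches = []
    for _ in range(num_workers):
        # Each simulated worker starts from the same copied random state.
        np.random.set_state(parent_state)
        batches.append(sample_batch_indices(1_000_000, 32, reseed=reseed))
    return batches


np.random.seed(0)
duplicated = simulate_workers(8, reseed=False)
distinct = simulate_workers(8, reseed=True)

# Count how many different batches the 8 simulated workers produced.
num_unique_without_fix = len({tuple(b) for b in duplicated})
num_unique_with_fix = len({tuple(b) for b in distinct})
print(num_unique_without_fix, num_unique_with_fix)
```

Without the reseed, all 8 simulated workers produce one and the same batch. With it, the batches should all differ, barring an astronomically unlikely collision.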



