CARE 3D, memory usage

Hello!

I’ve been working with CARE 3D denoising to accelerate the acquisition of a large tissue sample. First tests with ~20 z-stacks (~10GB) went very well, but it seems like some of the (rather rare) structures were not restored in the predicted images, so I decided to continue the training for longer on augmented data and see if that improves the results.
I rotated the training data to augment it (rotations of 90°, 180° and 270°) and ended up with a training-patches file of about 40GB. Since the machine has 120GB of RAM, I wasn’t too concerned and started the training (800 steps/epoch, 200 epochs).
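
Roughly, the augmentation looks something like the sketch below (not my exact code: the npz file name and the rotate_patches helper are placeholders, and I am assuming patches produced by create_patches with SZYXC axes):

import numpy as np

# Load the low/high-SNR patch arrays saved by create_patches
# (the file name is a placeholder).
data = np.load('my_training_data.npz')
X, Y = data['X'], data['Y']        # axes SZYXC, e.g. (n, 16, 64, 64, 1)

# Augment by 90/180/270 degree rotations in the YX plane (axes 2 and 3).
def rotate_patches(X, Y):
    X_aug = np.concatenate([np.rot90(X, k=k, axes=(2, 3)) for k in range(4)])
    Y_aug = np.concatenate([np.rot90(Y, k=k, axes=(2, 3)) for k in range(4)])
    return X_aug, Y_aug

# 4x the patches also means roughly 4x the file size and RAM.
X_aug, Y_aug = rotate_patches(X, Y)
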
The attached image shows the monitored memory usage of the machine. When starting the training (around 15:50), ~40GB of memory is rapidly taken up, as expected for loading the patches. After that, each new epoch uses ~4GB more RAM, with a drop of about 20GB every 4-5 epochs (I guess some automated garbage collection?). Unfortunately, this results in a steady net increase of used memory, and the training fails after ~150 epochs (out of the 200 planned) due to a lack of available memory.
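
In case it helps to reproduce the numbers without a system monitor: the resident memory of the Python process itself can be logged with a small psutil helper like the sketch below (psutil and the log_memory helper are not part of CSBDeep, just an illustration):

import os
import psutil

# Print the resident memory of the current Python process; call it manually
# between runs, or once per epoch from a Keras callback, to follow the
# growth described above.
def log_memory(tag=''):
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
    print('%s resident memory: %.1f GB' % (tag, rss_gb))

log_memory('before training')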

@uschmidt83 @mweigert Is this increase in memory usage expected, or could there be a memory leak somewhere? Am I doing something wrong, or is there a way I could avoid this issue? I guess I could double the RAM, but there may be a more clever solution :slight_smile:

I checked the TensorBoard display to make sure that it was useful to do that many epochs, and it seems like the model was still improving at the time of the crash (see the attached TensorBoard screenshot: there is still stuff to learn).

The code I’m using is derived from the example Jupyter notebooks shared on the CSBDeep GitHub; I can attach it if that would be useful. I’m running on a CentOS 7 machine with a GTX 1080 GPU in a conda environment (I’m attaching the conda env environment.yml (5.6 KB)).

Let me know if I can add anything!

Thanks!

Sebastien

I never investigated memory consumption thoroughly, but I’d say this shouldn’t happen.

The learning rate is very low when the crash happens, it probably won’t help much to continue training.

Yes, that might be helpful.

Best,
Uwe

Hi @sebherbert, if you run into situations like these, it is always a good idea to resume training from a checkpoint by loading the weights of the already-trained model; this is also helpful for transfer learning.

In the cell where you create the CARE model, you can add the lines below to train for, say, 50 epochs, stop, change the learning rate if you want, and then restart the training from a checkpoint rather than from scratch:

import os

# CSBDeep saves weights_now.h5, weights_last.h5 and weights_best.h5 in the
# model folder during training; each block below loads one of them if it
# exists (so the last file found is the one that ends up loaded).
if os.path.exists(model_dir + model_name + '/' + 'weights_now.h5'):
    print('Loading checkpoint model')
    model.load_weights(model_dir + model_name + '/' + 'weights_now.h5')

if os.path.exists(model_dir + model_name + '/' + 'weights_last.h5'):
    print('Loading checkpoint model')
    model.load_weights(model_dir + model_name + '/' + 'weights_last.h5')

if os.path.exists(model_dir + model_name + '/' + 'weights_best.h5'):
    print('Loading checkpoint model')
    model.load_weights(model_dir + model_name + '/' + 'weights_best.h5')
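
After the weights are loaded you can simply call model.train() again; the number of epochs and steps per epoch for the restarted run can be passed to train() to override the values stored in the config. A minimal sketch, assuming X, Y, X_val and Y_val come from your load_training_data cell (the numbers are just examples):

# Continue training from the loaded checkpoint; epochs and steps_per_epoch
# override the values saved in the model config.
history = model.train(X, Y, validation_data=(X_val, Y_val),
                      epochs=100, steps_per_epoch=800)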

Hi Uwe!
Thanks for the feedback!

That was my intuition too.

Ok, so until I can think of a cleaner solution, I’ll probably try Varun’s suggestion of 2x100 epochs and see what the result is.

Here’s the notebook SWITCHdrive
It’s really close to the example notebook you shared on your GitHub here.

Let me know if I can do something else

Sebastien

Thanks Varun!
I had some intuition that it should work but thanks for confirming, I’ll try that!
It’s not exactly a clean solution though :sweat_smile:

Sebastien

Haha yeah, it’s what I use when I want to do transfer learning, but it doesn’t address the memory leak, it just avoids it. So in a way it doesn’t solve the problem, it sidesteps it :slight_smile:

Cheers,

Varun

I restarted the process following your suggestion.
I realized something: by doing this, am I not reshuffling the validation dataset into the training dataset, and thereby messing up the overfitting estimation and the best-model selection?
Cheers,
Sebastien

Hi @sebherbert

This indeed should not be the case and I cannot directly reproduce this. Could you try to use only a very small validation set (e.g. 16 patches) and see whether the issue persists?

Hi Sebastien,

Well, you need the training and validation datasets if you are trying to do hyper-parameter optimization, but after you have found the best network, can’t you just put all the data into the training set, or at least strongly reduce your validation data for the second run? That way most of your patches go into training.

Hi Martin!

Thanks for your suggestion and test.
I wasn’t sure how to force a small validation set, so I’ve just restarted the training with a validation split of 0.001 (i.e. 0.1%, see the sketch after the summary below), which brings the training and validation patches to

number of training images: 77746
number of validation images: 78
image size (3D): (16, 64, 64)
axes: SZYXC
channels in / out: 1 / 1
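
Concretely, I just lowered the validation_split argument when loading the patches, along these lines (the npz file name is a placeholder for my training data file):

from csbdeep.io import load_training_data

# Load the patches with a much smaller validation split than before.
(X, Y), (X_val, Y_val), axes = load_training_data('my_training_data.npz',
                                                  validation_split=0.001, verbose=True)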

I hope that’s good enough? I got a “UserWarning: small number of validation images” when the training began. Each epoch takes a bit of time (~260 s); so far the memory usage seems more stable, but I’ll edit the post later, once I have a few more epochs, to reevaluate (I kept the 800 steps/epoch, 100 epochs configuration).

→ Update after ~50 min: RAM usage has only increased by 400MB, without any of the step increases I was seeing earlier, so this apparently solves the problem. (I’ll let it run overnight to be sure, but it seems quite clear already.)
Does that mean I should use fewer validation patches during my trainings and limit them to something between 100 and 500?

→ Overnight update: all good during the night and RAM usage was stable. So I guess it is indeed due to the number of validation patches; I’m wondering what could be accumulating during training?

Cheers,
Sebastien

Correct me if I’m wrong, but aren’t you selecting the lowest validation loss as your best model in any case?

Yes, but after 100 epochs, once the learning rate has been lowered by the rate scheduler, you should already have the model with the lowest loss, and you can then keep training that model with a very small validation set. Don’t your different models already show that by 100 epochs?

Ah, I think I see. Yes indeed, the learning rate is quite low after the 100th epoch. That’s what Uwe was also pointing at. It might be good enough to just stop there.

Yeah, so in this case I think you can choose the winner model after 100 epochs and then feed that network all the training data, keeping only 0.01% as validation data, just so that the program does not crash from getting None as validation data. Then, if you want to train it more, you can take your winner network, set its weights as described above, and change the learning rate to the value shown by the scheduler for the next run.

When you get more training data, you can just take your winner model and keep training it, keeping the validation data to a minimum and using model.load_weights rather than starting from scratch.
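
Putting it together, the second run could look roughly like the sketch below. It is only a sketch: the npz file name is a placeholder, the learning rate is an example value you would read off your scheduler/TensorBoard curve, and model_dir and model_name are assumed to be the same as in the first run.

from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE

# Reload the patches with a tiny validation split so that almost everything
# goes into training (the file name is a placeholder).
(X, Y), (X_val, Y_val), axes = load_training_data('my_training_data.npz',
                                                  validation_split=0.001, verbose=True)

# Recreate the model with the lower learning rate reached by the scheduler,
# pointing at the same folder as the first run.
config = Config(axes, n_channel_in=1, n_channel_out=1,
                train_learning_rate=1e-5,        # example value
                train_steps_per_epoch=800, train_epochs=100)
model = CARE(config, model_name, basedir=model_dir)

# Load the best weights from the first run (read from the model folder),
# then continue training.
model.load_weights('weights_best.h5')
history = model.train(X, Y, validation_data=(X_val, Y_val))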
