Training stardist -- 3D dataset

Hi, I am training the 3D StarDist model on 20 volumes of 128x128x128 voxels. Each volume contains about 100 cells, and the mean cell bounding-box size is about 30 voxels. Here is an example:

I trained the model for 400 epochs. The loss functions reach a plateau by the end of training. Here are the scores on the training data, which are close to the scores I get on a small validation set. To get these scores I had to use a (1,1,1) subsampling grid; with (2,2,2) the scores were near zero, which makes sense because the dataset has low resolution to begin with. I would like to optimize the parameters so we can improve the scores on the training data. I just want to explore the fitting power of the model and tame overfitting later. Any suggestions for parameters to explore?

Thanks,
Abbas

I didn’t use augmentation yet. Here is the full list of parameters:

{'n_dim': 3,
 'axes': 'ZYXC',
 'n_channel_in': 1,
 'n_channel_out': 97,
 'train_checkpoint': 'weights_best.h5',
 'train_checkpoint_last': 'weights_last.h5',
 'train_checkpoint_epoch': 'weights_now.h5',
 'n_rays': 96,
 'grid': (1, 1, 1),
 'anisotropy': (1.0, 1.0, 1.0175438596491229),
 'backbone': 'resnet',
 'rays_json': {'name': 'Rays_GoldenSpiral',
  'kwargs': {'n': 96, 'anisotropy': (1.0, 1.0, 1.0175438596491229)}},
 'resnet_n_blocks': 4,
 'resnet_kernel_size': (3, 3, 3),
 'resnet_kernel_init': 'he_normal',
 'resnet_n_filter_base': 32,
 'resnet_n_conv_per_block': 3,
 'resnet_activation': 'relu',
 'resnet_batch_norm': False,
 'net_conv_after_resnet': 128,
 'net_input_shape': (None, None, None, 1),
 'net_mask_shape': (None, None, None, 1),
 'train_patch_size': (64, 64, 64),
 'train_background_reg': 0.0001,
 'train_foreground_only': 0.9,
 'train_dist_loss': 'mae',
 'train_loss_weights': (1, 0.2),
 'train_epochs': 400,
 'train_steps_per_epoch': 100,
 'train_learning_rate': 0.0003,
 'train_batch_size': 6,
 'train_n_val_patches': None,
 'train_tensorboard': True,
 'train_reduce_lr': {'factor': 0.5, 'patience': 40, 'min_delta': 0},
 'use_gpu': True}

Hi,

Things I would try (judging from your images):

  • switch to backend="unet"
  • use larger patch_size=(128,128,128) and batch_size=1
  • use grid=(2,2,2) (it’s strange that it didn’t work for you)
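Spelled out, those suggestions amount to a config roughly like this (a sketch, untested on your data; the anisotropy value is copied from the posted parameters):

```python
from stardist import Rays_GoldenSpiral
from stardist.models import Config3D

# Sketch of the suggested settings (unet backbone, full-size patches,
# batch size 1, subsampled grid); anisotropy as in the posted config.
anisotropy = (1.0, 1.0, 1.0175438596491229)
conf = Config3D(
    rays             = Rays_GoldenSpiral(96, anisotropy=anisotropy),
    grid             = (2, 2, 2),
    anisotropy       = anisotropy,
    backbone         = 'unet',
    train_patch_size = (128, 128, 128),
    train_batch_size = 1,
    use_gpu          = True,
)
```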

You could additionally use basic augmentation (lateral fliprot and random intensity scaling), which is already part of the 3D example notebook in the dev branch.
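A minimal numpy version of such an augmenter could look like this (function names are mine, not the notebook’s; flips and rotations are restricted to the lateral Y and X axes so the Z axis is untouched):

```python
import numpy as np

def random_fliprot(img, mask):
    """Random flips and 90-degree rotations in the lateral (Y, X) axes only,
    applied identically to image and label mask."""
    assert img.ndim == 3 and img.shape == mask.shape
    if np.random.rand() < 0.5:            # flip along Y
        img, mask = img[:, ::-1], mask[:, ::-1]
    if np.random.rand() < 0.5:            # flip along X
        img, mask = img[:, :, ::-1], mask[:, :, ::-1]
    k = np.random.randint(4)              # rotate k*90 degrees in the YX plane
    img  = np.rot90(img,  k, axes=(1, 2))
    mask = np.rot90(mask, k, axes=(1, 2))
    return img, mask

def random_intensity_scale(img, lo=0.8, hi=1.2):
    """Multiply intensities by a random factor; labels are left untouched."""
    return img * np.random.uniform(lo, hi)

def augmenter(img, mask):
    img, mask = random_fliprot(img, mask)
    img = random_intensity_scale(img)
    return img, mask
```

It can then be passed to training via `model.train(X, Y, validation_data=..., augmenter=augmenter)`.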

Hope that helps!

M

PS: Are these real or synthetic images?

@mweigert thanks for the suggestions. Did you mean “backbone”?

They are synthetic volumes.

Ah, well, yes 🙂

Great suggestion. With the unet backbone, scores improved after just 20 epochs.


I am trying to get the highest possible scores by overfitting a single image, which is itself the GT mask with the cell IDs removed. This should set the upper bound for the scores. I train for 400 epochs and the loss reaches a plateau at about 0.36. Here is an example:

Some cells may slightly deviate from the star-convex criterion. Would that prevent us from getting near-perfect scores at, say, IoU=0.7? My accuracy and recall stay below 0.8 at IoU=0.7.

Any suggestions?

Abbas

What config did you use?

Here are the parameters:

{'n_dim': 3,
 'axes': 'ZYXC',
 'n_channel_in': 1,
 'n_channel_out': 97,
 'train_checkpoint': 'weights_best.h5',
 'train_checkpoint_last': 'weights_last.h5',
 'train_checkpoint_epoch': 'weights_now.h5',
 'n_rays': 96,
 'grid': (2, 2, 2),
 'anisotropy': (1.0, 1.0, 1.0175438596491229),
 'backbone': 'unet',
 'rays_json': {'name': 'Rays_GoldenSpiral',
  'kwargs': {'n': 96, 'anisotropy': (1.0, 1.0, 1.0175438596491229)}},
 'unet_n_depth': 2,
 'unet_kernel_size': (3, 3, 3),
 'unet_n_filter_base': 32,
 'unet_n_conv_per_depth': 2,
 'unet_pool': (2, 2, 2),
 'unet_activation': 'relu',
 'unet_last_activation': 'relu',
 'unet_batch_norm': False,
 'unet_dropout': 0.0,
 'unet_prefix': '',
 'net_conv_after_unet': 128,
 'net_input_shape': (None, None, None, 1),
 'net_mask_shape': (None, None, None, 1),
 'train_patch_size': (128, 128, 128),
 'train_background_reg': 0.0001,
 'train_foreground_only': 0.9,
 'train_dist_loss': 'mae',
 'train_loss_weights': (1, 0.2),
 'train_epochs': 400,
 'train_steps_per_epoch': 100,
 'train_learning_rate': 0.0003,
 'train_batch_size': 4,
 'train_n_val_patches': None,
 'train_tensorboard': True,
 'train_reduce_lr': {'factor': 0.5, 'patience': 20, 'min_delta': 0},
 'use_gpu': True}

Did you run the first notebook?

Among other things, it produces a plot of the theoretically achievable accuracy as a function of the number of rays. So if this is <0.8 for 96 rays, then that is the best you can get…
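That check can be reproduced directly from the GT labels; a sketch, assuming `Y` is a list of GT label volumes and reusing the anisotropy from the posted config:

```python
from stardist import Rays_GoldenSpiral, relabel_image_stardist3D
from stardist.matching import matching_dataset

anisotropy = (1.0, 1.0, 1.0175438596491229)
n_rays = [8, 16, 32, 64, 96, 128, 256]

# Reconstruct each GT label image with its star-convex approximation and
# score it against the original: the mean IoU is the upper bound that a
# perfectly trained model with that many rays could reach.
scores = []
for n in n_rays:
    rays = Rays_GoldenSpiral(n, anisotropy=anisotropy)
    Y_reconstructed = [relabel_image_stardist3D(y, rays) for y in Y]
    stats = matching_dataset(Y, Y_reconstructed, thresh=0, by_image=True)
    scores.append(stats.mean_true_score)
```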

Yes, the score for 96 rays is close to 0.88, and it approaches 0.9 for 256 rays. I will train with a higher number of rays. How difficult is it to get close to the theoretical maximum accuracy?

@mweigert here are the results for 128 rays:

FN starts at 17 (green is GT, red is the model prediction, yellow is the green-red overlap). All the missed cells are ones not fully inside the volume; typically the smaller ones are missed. These cells were reconstructed by relabel_image_stardist3D with 128 rays:
GT_vs_pred

Why would the model catch some of the incomplete cells and miss others? Any way to reduce the FN?

Did you try to choose a smaller value for prob_thresh when you call model.predict_instances?
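For context, `prob_thresh` is the cutoff on object probability below which candidates are discarded before non-maximum suppression, so lowering it in `model.predict_instances(img, prob_thresh=...)` trades false negatives for false positives. A toy illustration of that trade-off (the numbers are made up):

```python
import numpy as np

# Candidate detections with predicted object probabilities: the first three
# are real cells, the last two are background clutter.
probs   = np.array([0.90, 0.65, 0.35, 0.25, 0.15])
is_cell = np.array([True, True, True, False, False])

def counts(prob_thresh):
    """Return (false negatives, false positives) at a given threshold."""
    kept = probs >= prob_thresh
    fn = int(np.sum(is_cell & ~kept))   # real cells discarded
    fp = int(np.sum(~is_cell & kept))   # clutter kept
    return fn, fp

print(counts(0.5))   # (1, 0): the weak real cell is lost
print(counts(0.2))   # (0, 1): it is recovered, at the cost of one FP
```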

Other things to try that might have an effect (not sure at all):

  • patch_size=(96,96,96)
  • unet_n_depth=3

Thanks @uschmidt83 for the suggestions. I was using optimize_thresholds to set the thresholds, but choosing a lower prob_thresh than the suggested value did decrease FN (while increasing FP).

Abbas