Training StarDist on MultiChannel images

Thanks for developing StarDist, Uwe Schmidt et al.

Currently, cell segmentation in our centre is done manually by experts, who use the context of multiple channels (usually dystrophin and VDAC) to identify 1) all the viable cells in the sample and 2) the boundaries of these cells.

I am trying to build a machine learning model that can segment all the “viable” cells in IMC (imaging mass cytometry) and EM images of muscle biopsies.

For this, I trained the 2D StarDist model on 29 single-channel IMC images of muscle fibres plus 470 images from DSB2018. We found that this model was good at predicting cell boundaries but not at filtering out non-viable cells.

The obvious reason seems to be that the model was not trained on multiple channels. I understand that StarDist currently doesn’t support multi-channel/stacked images as input, but do you have any advice on how to deal with this?

If need be, I am ready to spend time modifying StarDist to achieve this objective.

Many thanks,
PS: I am a trained machine learning engineer.

Hi @Atif_Khan

I am not entirely sure what is meant by “viable cells” (maybe you could show an image demonstrating the problem?), but StarDist does support multi-channel input when training a custom model. So maybe training with multi-channel input would help?

All the best,
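As a rough illustration of what multi-channel input means in practice, here is a minimal NumPy sketch. It assumes each training image has its channels on the last axis (YXC in 2D, ZYXC in 3D) and that the label masks stay single-channel. Each channel is normalized independently with percentiles, similar in spirit to csbdeep.utils.normalize as used in the StarDist example notebooks; the helper name normalize_per_channel and its defaults are assumptions, not StarDist API.

```python
import numpy as np


def normalize_per_channel(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Percentile-normalize each channel of an image whose channel axis is last.

    Works for YXC (2D) and ZYXC (3D) images: percentiles are computed over all
    axes except the last, so channels with very different intensity ranges
    (e.g. dystrophin vs. VDAC) end up on a comparable scale.
    """
    img = np.asarray(img, dtype=np.float32)
    spatial_axes = tuple(range(img.ndim - 1))  # every axis except the channel axis
    lo = np.percentile(img, pmin, axis=spatial_axes, keepdims=True)
    hi = np.percentile(img, pmax, axis=spatial_axes, keepdims=True)
    return ((img - lo) / (hi - lo + eps)).astype(np.float32)


# With StarDist installed, the number of input channels is then declared in the
# model config, e.g. Config2D(n_rays=32, n_channel_in=2), while the label
# masks Y stay single-channel (YX, no channel axis).
```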

Dear Martin

I have a similar issue: I am working with fluorescence images, and I also need to train 3D StarDist on multi-channel images. In my case, I need to segment the nuclei of the cells in fluorescence microscopy images of organoids.
My question is: to perform this training, do I need multi-channel annotations as well, or can I use the multi-channel fluorescence images as input and the single-channel annotation stack as the target? Does StarDist 3D support multi-channel training?

Thank you very much

Hi @xgalindo,




Thanks so much for all your help :smiley:

Dear Uwe

I already tried to run my 3D StarDist training on my 2-channel fluorescence images; however, the plots are not displayed correctly and the training fails with an error. I have attached some pictures illustrating the error. I am using a 16-bit 2-channel image of size 340x310 with 85 frames and another 16-bit image of size 197x247 with 80 frames. I have already run a training on the 1-channel versions of these images without any problem with their size or depth. I hope you can help me; thanks in advance.

Something’s clearly wrong, since the image isn’t displayed properly. Our notebooks expect input images with semantic axes ZYXC, if I remember correctly. Please run print(X[0].shape, Y[0].shape) to check your image dimensions and axes.

The error message comes from line 94: “images and masks should have corresponding shapes/dimensions”. I suspect this is again caused by the channel axis not being last, as our notebooks expect.

(The error message isn’t correctly displayed here due to a small bug, which has been fixed in version 0.6.2. Please update StarDist.)
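To catch this kind of mismatch before training, a small check along these lines may help. It is a sketch that assumes images are channel-last (ZYXC) and label masks have no channel axis (ZYX); the function name is hypothetical and only approximates the consistency check the notebook performs.

```python
import numpy as np


def check_image_mask_pair(x, y, n_channel_in):
    """Hypothetical helper: True if image x and label mask y have compatible shapes.

    Assumes x is ZYXC (channel axis last) and y is ZYX. A single-channel image
    may also omit the channel axis, in which case x and y must match exactly.
    """
    x, y = np.asarray(x), np.asarray(y)
    if n_channel_in == 1 and x.ndim == y.ndim:
        return x.shape == y.shape
    return (
        x.ndim == y.ndim + 1
        and x.shape[:-1] == y.shape  # Z, Y, X must agree with the mask
        and x.shape[-1] == n_channel_in  # channel axis last, with the expected size
    )


# Example: a 2-channel stack stored as ZCYX fails, the same data as ZYXC passes.
```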


Dear @uschmidt83
I printed X[0].shape and Y[0].shape, and it is clear that the shape of X[0] is not correct. So I am wondering whether I built the 2-channel images correctly, because I do not understand why the shape is ZCYX instead of ZYXC.

In my case, I have 2 image stacks acquired with different lasers (561 nm and 647 nm), each showing different information. Both images are 16-bit, 85 frames, 340x310. To build the 2-channel stack I used Fiji → Image → Color → Merge Channels with C1 = 647 image and C2 = 561 image, with “keep originals” and “Create composite” checked.

Is this the correct way to build a multi-channel image for 3D StarDist in Fiji, or is there another way to get the correct shape?


Thanks in advance for your help


You can change the axes of your images after loading them in Python, i.e. you don’t have to change your images on disk:

from csbdeep.utils import move_image_axes
X = [move_image_axes(x, 'ZCYX', 'ZYXC') for x in X]
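To see what that reordering does (or to avoid the csbdeep dependency for this one step), the same permutation can be sketched with plain NumPy. This assumes both axis strings contain the same letters; move_axes is a hypothetical name, not part of csbdeep.

```python
import numpy as np


def move_axes(x, src, dst):
    """Permute the axes of x from the order named in src to the order in dst.

    E.g. src='ZCYX', dst='ZYXC' moves the channel axis to the end.
    """
    if sorted(src) != sorted(dst) or len(src) != x.ndim:
        raise ValueError(f"cannot map axes {src!r} -> {dst!r} for ndim={x.ndim}")
    # For each target axis, look up where it currently sits in x.
    return np.transpose(x, [src.index(a) for a in dst])


# Usage, analogous to the csbdeep call above:
#   X = [move_axes(x, 'ZCYX', 'ZYXC') for x in X]
```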

Sorry, can’t help you with how to correctly save your images with Fiji. Maybe someone else can.
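One alternative to merging in Fiji is to combine the two single-channel stacks directly in Python, so the channel axis is last from the start. A minimal sketch, assuming both stacks are already loaded as (Z, Y, X) arrays of identical shape, for example with tifffile.imread (the filenames in the comment are placeholders):

```python
import numpy as np


def stack_channels(*stacks):
    """Combine same-shaped single-channel stacks (Z, Y, X) into one (Z, Y, X, C) array.

    The channel order follows the argument order, so keep it identical for
    training and prediction (e.g. 647 nm first, 561 nm second).
    """
    shapes = {np.shape(s) for s in stacks}
    if len(shapes) != 1:
        raise ValueError(f"all stacks must have the same shape, got {shapes}")
    return np.stack(stacks, axis=-1)


# Loading is assumed to happen elsewhere, e.g. with tifffile (placeholder names):
#   import tifffile
#   img = stack_channels(tifffile.imread("ch647.tif"), tifffile.imread("ch561.tif"))
```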


Dear @uschmidt83

Thank you for your reply. I was able to run my training, but now I am having problems running the prediction Jupyter notebook; the error I get is shown in the picture below. The image I am using in this notebook is a 16-bit fluorescence image stack of 85 frames and size 1398x1532x3.

I hope you can help me :smiley: !

This just means your GPU ran out of memory. For such images you can use the n_tiles option for the model prediction:

labels, details = model.predict_instances(img, n_tiles = (2,2,1))

Maybe check that all other notebook kernels are shut down before restarting and running this one?

Ah, you said 85 frames; maybe that is the issue. If you trained on a set of 2D 3-channel images, you have to apply the prediction frame by frame rather than feeding in all the time points at once. You can do this by:

import numpy as np
# img is (frames, Y, X, C); each frame gets its own 2D label image (no channel axis)
labels = np.zeros(img.shape[:3], dtype=np.int32)
for i in range(img.shape[0]):
    labels[i], _ = model.predict_instances(img[i])

Thanks for your quick answer. Sadly, changing the number of tiles did not work; when I tried it, this error appeared (attached image). Before running the notebook again, I verified that all the other notebooks were shut down.

Following the conversation :smiley:

I did my training with the 3D StarDist Jupyter notebook, using the same type of images: 85-frame fluorescence image stacks of size 1398x1532x3. Does your solution still apply in this case?

Thanks for all your help


@kapoorlab had the right idea, but please modify the command to use more tiles, and set the number of tiles for the channel axis to 1.

labels, polys = model.predict_instances(img, axes='ZCYX', n_tiles=(1,1,8,8))
# OR, if your images have been converted to this axis order:
# labels, polys = model.predict_instances(img, axes='ZYXC', n_tiles=(1,8,8,1))

Dear @uschmidt83

Thanks for your reply. In my case, the image shape is X[0].shape = (85, 1398, 1532, 3), so I decided to use

labels, polys = model.predict_instances(img, axes='ZYXC', n_tiles=(1,8,8,1))

and I got the following error:


What could be happening?
Thanks for your help.


It seems you still need to increase the number of tiles; try (1,16,16,1).

Thanks for the answer. This seems to have worked for the prediction, but I got the same error in the following cells. What could be causing this problem if the prediction itself apparently completed correctly?


It is the missing n_tiles in the prediction function again. In all of your notebooks, you have to add that keyword argument to every prediction call to avoid this error.
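One way to avoid forgetting the keyword in later cells is to fix axes and n_tiles once and reuse the same callable everywhere. A sketch using functools.partial; it assumes `model` is the loaded StarDist model and that the chosen n_tiles fits in memory, and make_predictor is a hypothetical name.

```python
from functools import partial


def make_predictor(model, axes="ZYXC", n_tiles=(1, 16, 16, 1)):
    """Return predict(img) that always calls model.predict_instances with the
    same axes and n_tiles, so no notebook cell can accidentally omit them."""
    if len(axes) != len(n_tiles):
        raise ValueError("n_tiles needs exactly one entry per axis")
    return partial(model.predict_instances, axes=axes, n_tiles=n_tiles)


# Usage (assuming a trained StarDist3D model and a ZYXC image `img`):
#   predict = make_predictor(model)
#   labels, polys = predict(img)
```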

Thanks for all your answers @uschmidt83 and @kapoorlab. My prediction now runs using

labels, polys = model.predict_instances(img, axes='ZYXC', n_tiles=(1,16,16,1))

However, the prediction is very poor, mostly nonexistent (1st image attached). I have a question: why does the training notebook show a prediction example (2nd image attached) that is much better than the prediction from the prediction Jupyter notebook? If both use the same model, why are the results so different? Shouldn’t the predictions be the same?

Prediction example taken from the prediction notebook

GT/prediction example plot taken from the training notebook

Thanks for all the help.