Performance slowdown during batch inference with a StarDist2D model

I am using StarDist2D followed by a classification model in my pipeline to identify, count, and compute various features of artifacts.
The issue is that StarDist inference takes progressively longer over a batch. Our images are of very high resolution, and a single image takes ~350+ seconds. If I put multiple copies of the same image in the input and run through them iteratively (4 slices per image in a loop), the time taken per image is not constant or even close to it. The time for each successive image follows an upward trend and at one point even reaches 1200 seconds per image. (Some iterations take slightly less time than the previous image, but the overall trend is upward.)
Screenshot: processing time of the same image on the 24th to 31st iterations.

Screenshot: the same image on the 1st to 5th iterations in the batch (StarDist2D loading and inference highlighted).
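
A minimal timing harness along the lines of the benchmark described above might look like the sketch below. It is a sketch under assumptions, not the actual pipeline: run_inference is a hypothetical placeholder for the real per-slice call (with StarDist this would typically wrap model.predict_instances), and the slope of a least-squares line through the per-iteration times is used as a simple indicator of an upward trend.

```python
import time
from typing import Callable, List, Sequence


def time_iterations(run_inference: Callable[[object], object],
                    slices: Sequence[object],
                    n_iterations: int) -> List[float]:
    """Run every slice through run_inference, n_iterations times, and
    return the wall-clock seconds spent on each iteration (all slices)."""
    timings = []
    for _ in range(n_iterations):
        start = time.perf_counter()
        for s in slices:
            run_inference(s)  # hypothetical per-slice inference call
        timings.append(time.perf_counter() - start)
    return timings


def trend_slope(timings: Sequence[float]) -> float:
    """Least-squares slope of time vs. iteration index (seconds per iteration).
    A clearly positive slope means later iterations are getting slower."""
    n = len(timings)
    if n < 2:
        return 0.0
    mean_x = (n - 1) / 2
    mean_y = sum(timings) / n
    num = sum((i - mean_x) * (t - mean_y) for i, t in enumerate(timings))
    den = sum((i - mean_x) ** 2 for i in range(n))
    return num / den


if __name__ == "__main__":
    # Stand-in workload. With StarDist, an assumed setup would be something like:
    #   from stardist.models import StarDist2D
    #   from csbdeep.utils import normalize
    #   model = StarDist2D.from_pretrained("2D_versatile_fluo")
    #   run = lambda img: model.predict_instances(normalize(img))
    run = lambda s: sum(range(s))
    times = time_iterations(run, slices=[100_000] * 4, n_iterations=5)
    for i, t in enumerate(times, start=1):
        print(f"iteration {i}: {t:.4f} s")
    print(f"trend: {trend_slope(times):+.6f} s per iteration")
```

Because the same input is reused on every iteration, any consistent positive slope comes from state that builds up between calls rather than from the data itself.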

I have checked the following:
1. Memory usage peaks as the computation of each slice nears 100% completion, and the memory is released before the next slice starts.
2. No new processes are created.
3. The program releases memory after the run completes.
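
To check the memory observations above in a repeatable way, one option is to record memory after every iteration and look for steady growth. The sketch below uses the standard-library tracemalloc, which only sees memory allocated through Python's allocator; TensorFlow's native and GPU buffers will not show up, so for those a process-level measure (for example psutil.Process().memory_info().rss) would be needed. run_inference is again a hypothetical stand-in for the real call.

```python
import gc
import tracemalloc
from typing import Callable, List, Sequence


def memory_after_each_iteration(run_inference: Callable[[object], object],
                                slices: Sequence[object],
                                n_iterations: int) -> List[int]:
    """Return the Python-level traced memory (bytes) still allocated after
    each iteration, measured after a forced garbage collection."""
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    try:
        readings = []
        for _ in range(n_iterations):
            for s in slices:
                run_inference(s)  # hypothetical per-slice inference call
            gc.collect()  # drop unreachable objects before measuring
            current, _peak = tracemalloc.get_traced_memory()
            readings.append(current)
        return readings
    finally:
        # Only stop tracing if this function was the one that started it.
        if started_here:
            tracemalloc.stop()
```

Readings that keep climbing across iterations would point to Python objects accumulating between iterations (for example, results kept alive); flat readings alongside a rising runtime would suggest looking elsewhere, such as native TensorFlow state.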

After a little research, I found that people have run into performance issues when calling model.predict() in a loop with Keras/TensorFlow, and I have seen that the model.predict_instances() function here also uses model.predict(). I do not know whether this is related.
I have been stuck on this issue for quite a while now, and I would be really thankful if somebody could help.

Hi @talha_khwaja, welcome to the forum and sorry for the late reply!

In order to help, I really need to see the code that you’re using to benchmark this. Ideally, can you provide me with a Python script or Jupyter notebook that I can run myself to reproduce the issue?