Multi-GPU usage of CLIJ

Following on from a thread on Twitter here, I’d be very interested to learn more about using multiple GPUs with @haesleinhuepf’s CLIJ plugin.

The context is that we have access to an HPC system with some nodes that have 4 GPUs per node. In principle we can request multiple nodes as well - would it work across nodes too?

At the moment I’m mostly playing around with our visualisation nodes and running Fiji via the GUI, but headless batch-processing would also be of interest! :slight_smile:

I’d be very pleased to hear anybody’s advice or experience with this!

Cheers,
Martin

2 Likes

Hey @Martin_Jones,

thanks for reaching out. Multi-GPU usage of CLIJ appears to me like an ambitious goal, especially for tile-able and/or multi-channel timelapse data. Let me give you a short tour of what has happened on the development side. Spoiler: advanced coding skills are still necessary - for now!

Let me answer your questions first:

- Multiple GPUs: Yes. I would actually recommend running two threads per GPU. It depends a bit on how long your workflow is and how small the images are: long workflows and small images mean that many threads per GPU make sense. Side note: CLIJ's multi-threaded GPU support was nicely explored and made possible by @maarzt.
- Multiple cluster nodes: There is no support for that yet. I think this should be solved on a different level; outside CLIJ, and actually outside Java.
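To make the first point concrete, here is a minimal, purely illustrative Java sketch of distributing time points over `gpus * threadsPerGpu` worker threads. It uses no CLIJ calls at all; `MultiGpuScheduler` and `processFrame` are made up for this post, and `processFrame` merely stands in for a real GPU workflow bound to one CLIJ instance:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch (not CLIJ API): distribute time points over a fixed
// pool of workers, sized gpus * threadsPerGpu as suggested above.
public class MultiGpuScheduler {

    static int processFrame(int timepoint) {
        // placeholder for: push frame, run the GPU workflow, pull the result
        return timepoint * 2;
    }

    public static List<Integer> processAll(int frames, int gpus, int threadsPerGpu) {
        ExecutorService pool = Executors.newFixedThreadPool(gpus * threadsPerGpu);
        try {
            List<Future<Integer>> futures = new ArrayList<>();
            for (int t = 0; t < frames; t++) {
                final int tp = t;
                futures.add(pool.submit(() -> processFrame(tp)));
            }
            List<Integer> results = new ArrayList<>();
            for (Future<Integer> f : futures)
                results.add(f.get());
            return results;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // 2 GPUs, 2 threads per GPU, 5 time points
        System.out.println(processAll(5, 2, 2));
    }
}
```

The point of the fixed pool size is exactly the recommendation above: more threads than GPU slots would just queue up, fewer would leave GPUs idle between transfers.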

CLIJ is headless by default. It doesn't show any result image until you explicitly pull/show it. There is also no graphical user interface involved if you call methods from scripts. If you want to run CLIJ from the command line, I recommend taking a look at this example. If you prefer executing Docker containers with CLIJ inside, this Apeer example might be of interest. Both use CLIJ and were never updated to CLIJ2 because of limited community interest, but I'm happy to update them if necessary. It should be pretty straightforward.

Guided tour

In ImageJ, when you have a timelapse dataset open and process it with classical filters, the GUI asks "Do you want to process the whole stack?". If CLIJ worked like this (it doesn't), all individual time points would be pushed to the GPU, processed, and the results pulled back; for every single operation! That means the processing would be dominated by CPU-GPU memory transfer and thus wouldn't work efficiently. More details on this issue on YouTube. We need to go a different way.
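To illustrate why transfer-per-operation dominates, here is a back-of-the-envelope sketch. All numbers (frame count, operation count, per-transfer and per-operation timings) are made-up assumptions for illustration, not measurements:

```java
// Illustrative cost model (made-up numbers): transfer around every operation
// versus pushing each frame once, running all operations, pulling once.
public class TransferCostSketch {

    // naive: push + pull around every single operation of every frame
    static long naiveMs(int frames, int ops, int transferMs, int opMs) {
        return (long) frames * ops * (2L * transferMs + opMs);
    }

    // batched: one push and one pull per frame, all operations in between
    static long batchedMs(int frames, int ops, int transferMs, int opMs) {
        return (long) frames * (2L * transferMs + (long) ops * opMs);
    }

    public static void main(String[] args) {
        // 300 frames, 10 operations, 50 ms per transfer, 2 ms per operation
        System.out.println("naive:   " + naiveMs(300, 10, 50, 2) + " ms");
        System.out.println("batched: " + batchedMs(300, 10, 50, 2) + " ms");
    }
}
```

With these assumed timings, the naive scheme spends almost all of its time on memory transfer, which is why keeping intermediate results on the GPU matters.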

Level 1: Scripting

In order to exploit multiple GPUs efficiently, we need to define the GPU-accelerated processing workflow as a thing/object/class and deploy it to the GPUs. “Dear GPUs, here is a workflow, please execute it in parallel.” The Jython scripting way of doing it looks like this:

Level 2: Java

Programming classes in Jython feels unnatural to me (maybe it's just me). I prefer programming classes in Java and then just calling them from Jython. If you feel the same, this code example (similar to the Jython version above) might be interesting for you as well:

In both cases above, you still need to organize what is processed by which Processor on which GPU. In the case of multi-timepoint/channel data, that tends to be the same boilerplate code every time. Thus, I implemented a way of generalizing it:

Level 3: Framor

Assume you have a Java interface (FrameProcessor) which defines a procedure as "this thing can be executed on frames independently". You can then implement workflows (mini examples here) which can be automatically deployed to multiple GPUs. Framor takes care of organizing what is processed when and where.
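The idea can be sketched in a few lines of plain Java. Note that the interface and deployer below are illustrative stand-ins invented for this post, not Framor's actual API; frames are simple float arrays here:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch of the FrameProcessor idea: an operation that can run
// on frames independently, plus a tiny deployer handing frames to workers.
public class FramorSketch {

    interface FrameProcessor {
        float[] process(float[] frame);
    }

    static float[][] deploy(float[][] frames, FrameProcessor processor, int workers) {
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            List<Future<float[]>> futures = new ArrayList<>();
            for (float[] frame : frames)
                futures.add(pool.submit(() -> processor.process(frame)));
            float[][] results = new float[frames.length][];
            for (int i = 0; i < frames.length; i++)
                results[i] = futures.get(i).get();
            return results;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        float[][] frames = {{1, 2}, {3, 4}};
        // example processor: add 1 to every pixel of every frame
        float[][] out = deploy(frames, frame -> {
            float[] r = frame.clone();
            for (int j = 0; j < r.length; j++) r[j] += 1;
            return r;
        }, 2);
        System.out.println(out[1][0]);
    }
}
```

In the real setting, each worker would hold its own GPU context and the processor would run a CLIJ workflow instead of touching pixels on the CPU.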

Level 4: Tilor

What we can do with frames can obviously also be done with tiles. I programmed it (actually before Framor) and it turned out that this doesn't work efficiently with ImageJ data structures. To manage the tiles more efficiently, we need to go for imglib2. I discussed that with @axtimwalde here. I'm pretty convinced that this will work but haven't found the time yet to implement it. Instead, I focused on:

Level 0: Auto-generating GPU-accelerated code

Obviously, for all of the above strategies you need a person who knows which functions CLIJ offers and how to assemble them into a workflow in advanced scripting or object-oriented programming languages. This is the major bottleneck. Thus, I programmed an interactive workflow designer which comes with an expert system suggesting which operations are applicable/worthwhile for your images. This tool also allows you to generate code in various scripting languages which you could copy & paste into the Jython example above. Furthermore, you can generate Java code for Fiji plugins, minimizing effort on the coding side.

One future step is to generate Fiji plugins which implement the above-mentioned FrameProcessor interface. With that, you could design a workflow on screen and then deploy it together with a timelapse dataset to multiple GPUs. Would be awesome, no? :wink:

If you want to dive into this Level 0, I would be happy to see you at I2K! I'm introducing the public to Level 0 there :smiley:

If you have any questions or need more pointers, let me know!

Cheers,
Robert

7 Likes

Thanks for the detailed response @haesleinhuepf! I’ll definitely try to sign up for the I2K session and I’ll let you know how I get on trying things out.

Cheers,
Martin

1 Like

Hi @Martin_Jones, hi @axtimwalde,

FYI: I took @axtimwalde's tutorial code from #I2K2020 and started knitting a convenience layer around it to provide user-friendly multi-GPU support for processing big image data tile-by-tile; working title: "faCLon heavy".

With this, you can define an image processing workflow using clijx in this format (complete example):

@Override
public void accept(ClearCLBuffer input, ClearCLBuffer output) {
    System.out.println("Start processing on " + clijx.getGPUName() + " image dimensions " + Arrays.toString(input.getDimensions()));
    long start_time = System.nanoTime();

    // allocate temporary memory
    ClearCLBuffer temp = clijx.create(input);

    // process the image
    clijx.differenceOfGaussian(input, temp, 1, 2, 3, 4, 5, 6);
    clijx.thresholdOtsu(temp, output);

    // clean up
    temp.close();
    long duration_ms = (System.nanoTime() - start_time) / 1000000;
    System.out.println("Finished processing on " + clijx.getGPUName() + " after " + duration_ms + " ms");
}

On a computer with multiple GPUs, you can define a pool of CLIJx instances for parallel processing, for example to run one thread on an integrated Intel GPU in parallel with four threads on a dedicated/external NVidia RTX GPU:

CLIJxPool pool = CLIJxPool.fromDeviceNames(
                new String[]{"Intel(R) UHD Graphics 620", "RTX"},
                new int[]{                             1,     4});
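Conceptually, such a pool can be thought of as a blocking queue of GPU "slots" that worker threads borrow and return. The sketch below is my illustration of that idea only; `ContextPoolSketch` is invented for this post, the String "context" stands in for a real CLIJx instance, and this is not how CLIJxPool is necessarily implemented:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical sketch of the pooling idea: threads borrow a GPU context
// from a queue, use it, and hand it back; take() blocks when all slots
// on all devices are busy.
public class ContextPoolSketch {

    static BlockingQueue<String> makePool(String[] deviceNames, int[] threadsPerDevice) {
        BlockingQueue<String> pool = new LinkedBlockingQueue<>();
        for (int d = 0; d < deviceNames.length; d++)
            for (int i = 0; i < threadsPerDevice[d]; i++)
                pool.add(deviceNames[d]); // one slot per allowed concurrent user
        return pool;
    }

    static String processTile(BlockingQueue<String> pool, int tile) {
        try {
            String context = pool.take(); // blocks until a GPU slot is free
            try {
                return "tile " + tile + " on " + context; // placeholder for real work
            } finally {
                pool.put(context);        // hand the slot back
            }
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        BlockingQueue<String> pool = makePool(new String[]{"Intel", "RTX"}, new int[]{1, 4});
        System.out.println(pool.size()); // 5 slots in total
        System.out.println(processTile(pool, 0));
    }
}
```

The design choice here is that back-pressure comes for free: a slow device simply returns its slot less often, so faster devices naturally pick up more tiles, which matches the timing behaviour in the output below.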

You can then execute the workflow on an N5-compatible input image tile-by-tile:

// define data location. If you don't have this dataset, create it using the MakeBigData class in the same folder
final N5Reader n5 = N5Factory.openReader("C:/structure/data/n5/example.n5");
final RandomAccessibleInterval<FloatType> img = N5Utils.openVolatile(n5, "/volumes/raw");
// convert the input image from any type to FloatType
final RandomAccessibleInterval<FloatType> floats = Converters.convert(img, (a, b) -> b.set(a.getRealFloat()), new FloatType());

int margin = 20;
final CLIJxFilterOp<FloatType, FloatType> clijxFilter =
        new CLIJxFilterOp<>(Views.extendMirrorSingle(floats), pool, DummyFilter.class, margin, margin, margin);

// make a result image lazily
final RandomAccessibleInterval<FloatType> filtered = Lazy.generate(
        img,
        new int[] {256, 256, 256},
        new FloatType(),
        AccessFlags.setOf(AccessFlags.VOLATILE),
        clijxFilter);

Output:

Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 949 ms
Finished processing on GeForce RTX 2080 Ti after 476 ms
Finished processing on GeForce RTX 2080 Ti after 1229 ms
Finished processing on GeForce RTX 2080 Ti after 973 ms
Start processing on Intel(R) UHD Graphics 620 image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 429 ms
Finished processing on GeForce RTX 2080 Ti after 434 ms
Finished processing on GeForce RTX 2080 Ti after 460 ms
Finished processing on GeForce RTX 2080 Ti after 351 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 370 ms
Finished processing on GeForce RTX 2080 Ti after 376 ms
Finished processing on GeForce RTX 2080 Ti after 360 ms
Finished processing on GeForce RTX 2080 Ti after 395 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 339 ms
Finished processing on GeForce RTX 2080 Ti after 313 ms
Finished processing on GeForce RTX 2080 Ti after 342 ms
Finished processing on GeForce RTX 2080 Ti after 375 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 240 ms
Finished processing on GeForce RTX 2080 Ti after 206 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 243 ms
Finished processing on Intel(R) UHD Graphics 620 after 15825 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 151 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on Intel(R) UHD Graphics 620 image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on GeForce RTX 2080 Ti after 282 ms
Finished processing on GeForce RTX 2080 Ti after 503 ms
Finished processing on GeForce RTX 2080 Ti after 414 ms
Start processing on GeForce RTX 2080 Ti image dimensions [296, 296, 296]
Finished processing on Intel(R) UHD Graphics 620 after 1162 ms
Finished processing on GeForce RTX 2080 Ti after 478 ms

The current example code can be found here. Note that the real performance benefit of using GPUs will come when executing longer workflows than in the demonstrated example. Furthermore, Stephan's/my example code just uses BigDataViewer for visualising the processed image; writing it to disk instead should be straightforward. I plan to continue developing this convenience layer sporadically, because there is apparently community need - just saying, I have no practical use case at hand. I will keep you posted, and you're welcome to step in whenever you feel like it.

Use-cases / testing / feedback / PRs are welcome! :wink:

Cheers,
Robert

4 Likes

Neat! Now we need something that makes it worth doing. Is there any meaningfully complex, non-trivial algorithm available in CLIJ?

BTW, multi-threaded saving/processing of the output is indeed trivial:

N5Utils.save(filtered, n5Writer, ..., executorService);
1 Like

Well, clij is a toolbox for building such algorithms ad hoc when non-trivial challenges arise.

I could imagine workflows for processing slide-scanner data, such as this one for determining single-positive and double-positive cell counts and computing the likelihood of DNA damage:

Or label classification based on intensity/shape/neighbors:

(both from this preprint)

Alternatively, studying tissues, e.g. as shown in these examples, would be interesting tile-by-tile, because connected-component labeling is part of the workflows and the assembly of the resulting tiles is not trivial (but doable if we assemble parametric images and not label images):

Btw. almost everything in clij works in 3D even though most examples are demonstrated in 2D.
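To sketch why assembling labeled tiles is the tricky part: labels are numbered independently per tile, so objects crossing a tile boundary get different IDs that have to be merged afterwards, e.g. with union-find over the seam pixels. The sketch below is purely illustrative (class and method names invented for this post; it is not how clij or any released tool implements this):

```java
// Hypothetical sketch: stitch connected-component labels across a tile
// boundary with union-find; labels touching at the seam are merged.
public class LabelStitchSketch {

    static int find(int[] parent, int x) {
        while (parent[x] != x) { parent[x] = parent[parent[x]]; x = parent[x]; }
        return x;
    }

    static void union(int[] parent, int a, int b) {
        parent[find(parent, a)] = find(parent, b);
    }

    // leftSeam/rightSeam: label values along the shared edge of two
    // adjacent tiles (0 = background); returns a label -> root mapping
    static int[] mergeSeam(int[] leftSeam, int[] rightSeam, int maxLabel) {
        int[] parent = new int[maxLabel + 1];
        for (int i = 0; i <= maxLabel; i++) parent[i] = i;
        for (int i = 0; i < leftSeam.length; i++)
            if (leftSeam[i] != 0 && rightSeam[i] != 0)
                union(parent, leftSeam[i], rightSeam[i]);
        int[] mapping = new int[maxLabel + 1];
        for (int i = 0; i <= maxLabel; i++) mapping[i] = find(parent, i);
        return mapping;
    }

    public static void main(String[] args) {
        // labels 1 and 3 touch at the seam, so they belong to one object
        int[] mapping = mergeSeam(new int[]{1, 1, 0, 2}, new int[]{3, 0, 0, 0}, 3);
        System.out.println(mapping[1] == mapping[3]);
        System.out.println(mapping[2] == mapping[1]);
    }
}
```

Parametric images sidestep this entirely, which is why I called them "doable" above: per-pixel measurements do not need any ID reconciliation between tiles.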

But again, my data is typically not that big. Can you help out with this @axtimwalde ? :upside_down_face:

Hi Robert, the videos are not visible (for me).

1 Like

You can also download them from zenodo:

https://zenodo.org/record/4276076/files/mouse_brain_cell_classification_dna_damage.mp4?download=1

https://zenodo.org/record/4276076/files/weka_label_classifier_short.mp4?download=1

2 Likes