\(•ᴗ•)/ QuPath scripting (#7): clij multi-GPU processing

clij GPU processing in QuPath scripts has been investigated in
\(•ᴗ•)/ QuPath scripting (#1): Using CluPath to save smoothed image regions
and
\(•ᴗ•)/ QuPath scripting (#2): Using CluPath and IJ to show image with additional local threshold channel

Here I have studied how multiple GPUs and/or multiple GPU workers can be used in QuPath scripts.

@haesleinhuepf and the following post gave the impetus for this work:

Some background information

QuPath uses highly efficient multi-threaded processing in several places, e.g. in the OMEPyramidWriter.
This means that QuPath calls the readBufferedImage() function of an ImageServer multiple times from parallel threads.

In the above-mentioned posts I used a TransformingImageServer to create temporary image entries.
In readBufferedImage() I used clij GPU functions to create additional image data.

Now the question arises of how parallel calls of readBufferedImage(), each of which performs its GPU calls separately, are processed.
Do the individual calls of function readBufferedImage() have to wait for the GPU call of the previous readBufferedImage() call to finish?
Does a single GPU slow down the parallel calls of readBufferedImage()?

Interestingly, multi-threaded calls of readBufferedImage() use the capabilities of a single GPU in a multi-threaded manner.
A single GPU does not slow down multi-threaded readBufferedImage() calls by processing the GPU calls sequentially.
GPU execution scheduling appears to be controlled very efficiently even under these multi-threaded conditions.
And this scheduling is used automatically in the scripts of posts #1 and #2, without additional development effort.

(This can be shown by assigning an index number to each of the `readBufferedImage()` calls and displaying the index after each of the clij calls; see the commented-out print commands in the script code.)

So, using clij calls on a single GPU does not conflict with parallel processing in QuPath.

Using a single GPU can speed up the execution of a QuPath script.
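The multi-threaded behavior described above can be illustrated with a small plain-Java sketch (QuPath scripts are Groovy, which accepts the same constructs). All names here are hypothetical: `simulatedGpuCall()` merely stands in for a clij call inside `readBufferedImage()`, and the `AtomicInteger` implements the index-numbering trick used to trace the interleaving of parallel calls:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class CallIndexDemo {
    // Hypothetical stand-in for a clij call inside readBufferedImage()
    static void simulatedGpuCall(int callIndex, String step) {
        // In the real script this would be e.g. clijx.mean3DSphere(...)
        System.out.println("call " + callIndex + " finished " + step
                + " on thread " + Thread.currentThread().getName());
    }

    public static void main(String[] args) throws Exception {
        AtomicInteger nextIndex = new AtomicInteger(0);
        ExecutorService pool = Executors.newFixedThreadPool(4);
        for (int i = 0; i < 8; i++) {
            pool.submit(() -> {
                int idx = nextIndex.getAndIncrement();  // unique index per call
                simulatedGpuCall(idx, "smooth");
                simulatedGpuCall(idx, "threshold");
            });
        }
        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
        System.out.println("calls issued: " + nextIndex.get());
    }
}
```

Running this shows the "smooth" and "threshold" steps of different call indices interleaving across threads, which is exactly the pattern the commented-out print statements reveal in the real script.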

But would it also be possible to call clij functions on multiple GPUs, or on multiple workers on a single GPU?
Could the multi-threaded QuPath calls additionally be served by multiple GPUs/workers?

CLIJ multi-GPU / multi-worker support in a QuPath script

The following script supports multiple GPUs and, when using high-end GPUs like the NVIDIA RTX 2070 or RTX 2080, multiple workers per GPU:

showAdditionalChannels_Absorbance_LocalThreshold_2c.txt (22.3 KB)

Here is the relevant part of the script:

// ******  Parameters  *******

// Gauss smoothing parameter
double gaussSigma = 1.0

// Auto Local Threshold (Phansalkar method)
double k = 0.125
double r = 0.5
double radius = 10.0

// ******  Script  *******

Project project = QP.getProject()

def entry = QP.getProjectEntry()

if (entry == null){
    print 'No image is loaded!' + '\n'
    return
}

def name = entry.getImageName()
print name + '\n'

def imageData = entry.readImageData()
def currentServer = imageData.getServer()

// Create channel list from existing channels
def channels = []
channels.addAll(updateChannelNames(name, currentServer.getMetadata().getChannels()))
// and add additional channels
channels.add(ImageChannel.getInstance(name +"_absR", ColorBrown))
channels.add(ImageChannel.getInstance(name +"_LT", ColorMagenta))
println 'n channels: ' + channels.size() + '\n'

// Create the new server, the new image & add to the project
def server8bit = new TypeConvertServer_AddChannels_clij_ij_2b(currentServer, channels, gaussSigma, k, r, radius)
def imageDataCreated = new ImageData<BufferedImage>(server8bit)
imageDataCreated.setImageType(ImageData.ImageType.FLUORESCENCE)

Platform.runLater {

    // Create new project entry for the new channels ...
    // (use FLUORESCENCE image type to have access to the single channels directly)
    def newentry = ProjectCommands.addSingleImageToProject(project, currentServer, ImageData.ImageType.FLUORESCENCE)
    BufferedImage thumbnail = ProjectCommands.getThumbnailRGB(server8bit)
    newentry.setThumbnail(thumbnail)
    newentry.setImageName("My New Image")
    QPEx.getQuPath().openImageEntry(newentry)

    QPEx.getQuPath().refreshProject()

    QPEx.getCurrentViewer().setImageData(imageDataCreated)

    QPEx.getQuPath().openImageEntry(project.getEntry(imageDataCreated))

    QPEx.getQuPath().refreshProject()


    // .. or display the new channels in the current viewer of the original image instead
    //QPEx.getCurrentViewer().setImageData(imageDataCreated)
}

println 'Done!' + '\n'

// ****  end of script  ****



// Prepend a base name to channel names
List<ImageChannel> updateChannelNames(String name, Collection<ImageChannel> channels) {
    return channels
            .stream()
            .map( c -> {
                return ImageChannel.getInstance(name + '-' + c.getName(), c.getColor())
                }
            ).collect(Collectors.toList())
}


class TypeConvertServer_AddChannels_clij_ij_2b extends TransformingImageServer<BufferedImage> {
    //CLUPATH clupath

    CLIJx[] poolCLIJx
    boolean[] idle

    int extNP
    double gaussSigma
    float k, r, radius
    float[] absorbance

    ImageServer<BufferedImage> currentserver
    private List<ImageChannel> channels
    private ImageServerMetadata originalMetadata
    def cm = ColorModelFactory.getProbabilityColorModel32Bit(channels)

    ArrayList<ColorTransformer.ColorTransformMethod> colortransformation = new ArrayList<ColorTransformer.ColorTransformMethod>()

    TypeConvertServer_AddChannels_clij_ij_2b(ImageServer<BufferedImage> server, List<ImageChannel> channels,
                                       double gaussSigma, double k, double r, double radius) {
        super(server)
        currentserver = server
        this.channels = channels
        this.gaussSigma = gaussSigma
        this.k = k
        this.r = r
        this.radius = radius

        // Number of pixels of region extension
        extNP = Math.max(3, (int)(3 * gaussSigma))
        extNP = Math.max(extNP, Math.ceil(radius))

        this.originalMetadata = new ImageServerMetadata.Builder(currentserver.getMetadata())
               //.pixelType(PixelType.FLOAT32)
                .pixelType(PixelType.UINT8)
                .rgb(false)
                .channels(channels)
                .build()

        colortransformation.add(ColorTransformer.ColorTransformMethod.Red)
        colortransformation.add(ColorTransformer.ColorTransformMethod.Green)
        colortransformation.add(ColorTransformer.ColorTransformMethod.Blue)

        absorbance = new float[256]
        absorbance[0] = 255.0
        for (int n=1; n<256; n++) {
            absorbance[n] = -100.0*Math.log(n/255.0)   // to guarantee absorbance <= 255, use factor -46.0 instead
        }

        // CLIJX single instance
        //clijx = CLUPATH.getInstance("")
        // optional: Select a specific GPU
        //clupath = CLUPATH.getInstance("GeForce")
        //print(clupath.getGPUName() + '\n')

        // CLIJ Multi-GPU support: Get number of devices
        def deviceList = CLIJ.getAvailableDeviceNames()
        int devCnt = 0
        for (device in deviceList) {
            if (device.contains("UHD") || device.contains("GTX") || device.contains("RTX")) {
                print(device + '\n')
                if (device.contains("2070") )
                    devCnt += 2
                else if (device.contains("2080"))
                    devCnt += 4
                else
                    devCnt++
            }
        }

        // CLIJ Multi-GPU support: Create GPU device pool
        // Create list in reversed order to add NVidia first (in my case)
        // Approach not tested with nWorkers > 1 yet (no such GPU available during the test)
        poolCLIJx = new CLIJx[devCnt]
        idle = new boolean[poolCLIJx.length]
        devCnt--
        for (int i =0; i < deviceList.size(); i++) {
            String device = deviceList.get(i)
            int nWorkers = 1
            if (device.contains("UHD") || device.contains("GTX") || device.contains("RTX")) {
                if (device.contains("2070"))
                    nWorkers = 2
                if (device.contains("2080"))
                    nWorkers = 4
                for (int n = 0; n < nWorkers; n++ ) {
                    poolCLIJx[devCnt] = new CLIJx(new CLIJ(device))
                    poolCLIJx[devCnt].clear()
                    idle[devCnt] = true
                    devCnt--
                }
            }
        }
    }

    public ImageServerMetadata getOriginalMetadata() {
        return originalMetadata
    }

    @Override
    protected ImageServerBuilder.ServerBuilder<BufferedImage> createServerBuilder() {
        return currentserver.builder()
    }

    @Override
    protected String createID() {
        return UUID.randomUUID().toString()
    }

    @Override
    public String getServerType() {
        return "My 8bit Type converting image server"
    }

    public BufferedImage readBufferedImage(RegionRequest request) throws IOException {
        int idx, idxRoi

        String path = request.getPath()
        int ds = request.getDownsample()

        int z = request.getZ()
        int t = request.getT()

        int extNPminX, extNPmaxX, extNPminY, extNPmaxY
        extNPminX = extNPmaxX = extNPminY = extNPmaxY = extNP

        int xExt = request.x - extNPminX * ds
        int yExt = request.y - extNPminY * ds

        if (xExt < 0) {
            extNPminX = (int) (1.0 * request.x / ds)
            xExt = request.x - extNPminX * ds
        }
        if (yExt < 0) {
            extNPminY = (int) (1.0 * request.y / ds)
            yExt = request.y - extNPminY * ds
        }

        int wExt = request.width + (extNPminX + extNPmaxX) * ds
        int hExt = request.height + (extNPminY + extNPmaxY) * ds

        //print 'request: ' + ds + ' ' + request.x + ' ' + request.y + ' ' + request.width + ' ' + request.height + '\n'
        //print 'requestExt: ' + ds + ' ' + xExt + ' ' + yExt + ' ' + wExt + ' ' + hExt + '\n'

        // Create extended region to avoid edge effects of gaussian smoothing
        ImageRegion region = ImageRegion.createInstance(xExt, yExt, wExt, hExt, z, t)
        RegionRequest requestExt = RegionRequest.createInstance(path, ds, region)
        def img = getWrappedServer().readBufferedImage(requestExt)

        def raster = img.getRaster()

        int nBands = raster.getNumBands()
        int w = img.getWidth()
        int h = img.getHeight()
        int wRoi = w - extNPminX - extNPmaxX
        int hRoi = h - extNPminY - extNPmaxY

        SampleModel model = new BandedSampleModel(DataBuffer.TYPE_BYTE, wRoi, hRoi, nBands + 2)
        byte[][] bytes = new byte[nBands + 2][wRoi*hRoi]
        DataBufferByte buffer = new DataBufferByte(bytes, wRoi*hRoi)
        WritableRaster raster2 = Raster.createWritableRaster(model, buffer, null)

        GaussianBlur gb = new GaussianBlur()

        float[] pixelsRoi = null
        float[] pixels = new float[w*h]
        FloatProcessor fp = new FloatProcessor(w, h, pixels)

        int[] rgb = img.getRGB(0, 0, w, h, null, 0, w)

        // sequence B-G-R (descending band index)
        for (int b = nBands-1; b>=0; b--) {
        //-x-for (int b = 0; b<1; b++) {

            pixels = ColorTransformer.getSimpleTransformedPixels(rgb, colortransformation.get(b), null)

            // IJ blur
            fp.setPixels(pixels)
            gb.blurFloat(fp, gaussSigma, gaussSigma, 2.0E-4D)
            pixels = (float[]) fp.getPixels()

            if (pixelsRoi == null)
                pixelsRoi = new float[wRoi * hRoi]

            // Crop original region from extended region
            for (int y = extNPminY; y < h - extNPmaxY; y++) {
                idx = y * w + extNPminX
                idxRoi = (y - extNPminY) * wRoi
                System.arraycopy(pixels, idx, pixelsRoi, idxRoi, wRoi)
            }

            // Add the original RGB channels
            raster2.setSamples(0, 0, wRoi, hRoi, b, pixelsRoi)
        }

        // here: only Red channel is used to derive additional channels
        for (int b=0; b<1; b++) {
             // relevant content (smoothed Red channel) is still in fp, pixels and pixelsRoi

            // I: Create and add absorbance Red channel
            // relevant content is still in pixelsRoi
            for (int i=0; i<pixelsRoi.length; i++)
                pixelsRoi[i] = absorbance[(int)pixelsRoi[i]]

            raster2.setSamples(0, 0, wRoi, hRoi, nBands + 0, pixelsRoi)

            // relevant content (smoothed Red channel) is still in fp and pixels

            long startT = System.currentTimeMillis()

            // CLIJ Multi-GPU support: Find idle GPU device
            CLIJx clijx = null
            boolean getDev = true
            int poolIdx = -99
            while (getDev) {
                // Synchronize the scan so two threads cannot grab the same device
                synchronized (idle) {
                    for (int i = 0; i < poolCLIJx.length; i++) {
                        if (idle[i]) {
                            idle[i] = false
                            clijx = poolCLIJx[i]
                            getDev = false
                            poolIdx = i
                            break
                        }
                    }
                }
                //ToDo: This is a potential endless loop. Fix this.
                try{
                    if (getDev)
                        Thread.sleep(10)
                }
                catch(InterruptedException e){
                    e.printStackTrace()
                    return null
                }
            }
            //print 'run on: ' + poolIdx + ' :' + clijx.getGPUName() + '\n'

            //int randNum = Math.random()*1000

            // II: Create and Add Channel: Local Threshold
            // relevant content is still in pixels
            // CLIJX: generate GPU memory buffer
            ClearCLBuffer input = clijx.pushArray(pixels, w, h, 1)
            ClearCLBuffer input2 = clijx.create(input.getDimensions(), clijx.Float)
            ClearCLBuffer clbMean = clijx.create(input.getDimensions(), clijx.Float)
            ClearCLBuffer clb1 = clijx.create(input.getDimensions(), clijx.Float)
            ClearCLBuffer clb2 = clijx.create(input.getDimensions(), clijx.Float)
            ClearCLBuffer clbtmp = clijx.create(input.getDimensions(), clijx.Float)

            // Auto Local Threshold (Phansalkar method) see: https://imagej.net/Auto_Local_Threshold
            // see code in:
            // https://github.com/fiji/Auto_Local_Threshold/blob/59f319b59e000f70d348577c388fe02188250f39/src/main/java/fiji/threshold/Auto_Local_Threshold.java
            // t = mean * (1 + p * exp(-q * mean) + k * ((stdev / r) - 1))

            double p = 2.0
            double q = -10.0

            // Calculation of the Auto Local Threshold (Phansalkar method)
            // 'Smooth' and 'Normalize' Input
            clijx.multiplyImageAndScalar(input, input2, 1.0/255.0)
            //print 'rand: ' + randNum + '\n'

            //clijx.mean2DBox(input2, clbMean, radius, radius)
            //clijx.standardDeviationBox(input2, clb2, radius, radius, 0.0)
            clijx.mean3DSphere(input2, clbMean, radius, radius, 0.0)
            //clijx.mean3DSphere(input2, clbMean, radius, radius, radius)
            //print 'rand: ' + randNum + '\n'

            //clijx.standardDeviationSphere(input2, clb2, radius, radius, 0.0)
            //-clijx.standardDeviationSphere(input2, clb2, radius, radius, radius)
            //print 'rand: ' + randNum + '\n'

            // StdDev by sqrMean (  stddev = sqrt(sqrmean - mean*mean)  )
            clijx.power(clbMean, clb1, 2.0)
            clijx.power(input2, clbtmp, 2.0)
            clijx.mean3DSphere(clbtmp, clb2, radius, radius, 0.0)
            //clijx.mean3DSphere(clbtmp, clb2, radius, radius, radius)
            clijx.subtractImages(clb2, clb1, clbtmp)

            clijx.power(clbtmp, clb2, 0.5)

            // my modification
            //clijx.power(clb2b, clb2a, 0.5)

            // stddev part of sum
            clijx.multiplyImageAndScalar(clb2, clb1, 1.0/r)
            //print 'rand: ' + randNum + '\n'
            clijx.addImageAndScalar(clb1, clbtmp, -1)
            //print 'rand: ' + randNum + '\n'
            clijx.multiplyImageAndScalar(clbtmp, clb2, k)
            //print 'rand: ' + randNum + '\n'
            // mean part of sum
            clijx.multiplyImageAndScalar(clbMean, clbtmp, q)
            //print 'rand: ' + randNum + '\n'
            clijx.exponential(clbtmp, clb1)
            //print 'rand: ' + randNum + '\n'
            clijx.multiplyImageAndScalar(clb1, clbtmp, p)
            //print 'rand: ' + randNum + '\n'
            clijx.addImageAndScalar(clbtmp, clb1, 1)
            //print 'rand: ' + randNum + '\n'
            // combined both parts
            clijx.addImages(clb1, clb2, clbtmp)
            //print 'rand: ' + randNum + '\n'
            clijx.multiplyImages(clbMean, clbtmp, clb1) // clb1 is threshold
            //print 'rand: ' + randNum + '\n'

            // apply threshold
            //clijx.greaterOrEqual(input2, clb1, clb2)
            clijx.greaterOrEqual(clb1, input2, clbtmp)
            //print 'rand: ' + randNum + '\n'

            // Invert Result #2
            //clijx.binaryNot(clb2, clbtmp)
            //print 'rand: ' + randNum + '\n'
            clijx.multiplyImageAndScalar(clbtmp, input2, 255.0)
            //print 'rand: ' + randNum + '\n'

            // END of Calculation of the Auto Local Threshold (Phansalkar method)

            ImagePlus imp = clijx.pull(input2)
            //print 'rand: ' + randNum + '\n'

            long endT = System.currentTimeMillis()
            //print 'Time_v2c: ' + (endT - startT) + '\n'

            float[] pixelsCLIJX = (float[]) imp.getProcessor().getPixels()

            // Crop original region from extended region
            for (int y = extNPminY; y < h - extNPmaxY; y++) {
                idx = y * w + extNPminX
                idxRoi = (y - extNPminY) * wRoi
                System.arraycopy(pixelsCLIJX, idx, pixelsRoi, idxRoi, wRoi)
            }

            raster2.setSamples(0, 0, wRoi, hRoi, nBands + 1, pixelsRoi)
            //-x-raster2.setSamples(0, 0, wRoi, hRoi, 0, pixelsRoi)

            // END of Add Channel: Local Threshold

            // close GPU memory buffers
            input.close()
            input2.close()
            clbMean.close()
            clb1.close()
            clb2.close()
            clbtmp.close()

            // CLIJ Multi-GPU support: Set GPU device idle again
            synchronized (idle) {
                idle[poolIdx] = true
            }
        }

        return new BufferedImage(cm, Raster.createWritableRaster(model, buffer, null), false, null)
    }
}
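The chain of clij calls above implements the Phansalkar threshold formula and derives the standard deviation via the identity stddev = sqrt(sqrMean - mean²). As an illustrative CPU-side check of that arithmetic (not part of the script; the example statistics are values I picked, and note that the script writes q = -10 because it computes exp(q * mean) rather than exp(-q * mean)):

```java
import java.util.Locale;

public class PhansalkarCheck {
    public static void main(String[] args) {
        // Parameters as in the script (classic Phansalkar defaults p = 2, q = 10)
        double p = 2.0, q = 10.0, k = 0.125, r = 0.5;

        // Example local statistics of a normalized (0..1) neighborhood
        double mean = 0.4;
        double sqrMean = 0.2;  // local mean of the squared pixel values

        // stddev via the identity used in the script: stddev = sqrt(sqrMean - mean^2)
        double stddev = Math.sqrt(sqrMean - mean * mean);

        // Phansalkar threshold: t = mean * (1 + p*exp(-q*mean) + k*((stddev/r) - 1))
        double t = mean * (1.0 + p * Math.exp(-q * mean) + k * ((stddev / r) - 1.0));

        System.out.printf(Locale.ROOT, "stddev = %.6f%n", stddev);
        System.out.printf(Locale.ROOT, "threshold = %.6f%n", t);
    }
}
```

Each line of this check corresponds to one of the buffer-to-buffer clij operations in the script (power, mean3DSphere, subtractImages, ...), just collapsed to scalar arithmetic.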

The script is based on \(•ᴗ•)/ QuPath scripting (#2): Using CluPath and IJ to show image with additional local threshold channel - #3 by phaub

The extension of the script was essentially derived from clijx-faclon-heavy/CLIJxPool.java at 1f78d95e01c39ef4a750231d43dcaa461213065c · clij/clijx-faclon-heavy · GitHub

The relevant changes are marked in the code as // CLIJ Multi-GPU support:.

  • In the constructor of the TransformingImageServer a pool of CLIJ devices is created.
  • In the function readBufferedImage() the next free device is selected from this pool and released after execution of the function.
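
Regarding the ToDo in the script: the idle-flag scan with `Thread.sleep(10)` is a busy-wait and could, in principle, spin forever. A `BlockingQueue` would give the same acquire/release semantics without polling. A minimal sketch in plain Java, where the `GpuWorker` type is a hypothetical placeholder for a `CLIJx` instance:

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

public class WorkerPoolSketch {
    // Hypothetical placeholder for a CLIJx instance
    static class GpuWorker {
        final String name;
        GpuWorker(String name) { this.name = name; }
    }

    private final BlockingQueue<GpuWorker> pool;

    WorkerPoolSketch(int nWorkers) {
        pool = new ArrayBlockingQueue<>(nWorkers);
        for (int i = 0; i < nWorkers; i++)
            pool.add(new GpuWorker("worker-" + i));
    }

    // Blocks until a worker is free: no busy-wait, no lost updates
    GpuWorker acquire() throws InterruptedException { return pool.take(); }

    // Return the worker so the next readBufferedImage() call can use it
    void release(GpuWorker w) { pool.add(w); }

    public static void main(String[] args) throws Exception {
        WorkerPoolSketch p = new WorkerPoolSketch(2);
        GpuWorker w = p.acquire();
        System.out.println("acquired " + w.name);
        p.release(w);
        System.out.println("released " + w.name);
    }
}
```

`take()` also removes the need for the synchronization around the idle-flag array, since the queue hands each worker to exactly one caller.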

This article and the script posted here should not be considered a finished solution but rather a proof-of-concept study.

:slightly_smiling_face:

@haesleinhuepf Maybe faclon-heavy should be included in clupath in the future. (?)


Great idea, @phaub ! If I read it correctly, you’re not using faclon-heavy infrastructure, but you copied parts of it, right?

I need a better understanding of QuPath's image server infrastructure, but I agree, those two might work together nicely :slightly_smiling_face:


That’s correct.
I have replicated the pool in the QuPath script.

You will understand what I have done right away when you check the code at

and

and

It would be nice if functions for creating the pool, retrieving the next free worker, releasing a worker … could be provided by faclon-heavy in QuPath.

AbstractTileProcessor and TileProcessor are not needed in QuPath.

I think there will be no problem using faclon-heavy in QuPath.

… as in other cases: this is not a high priority. If someone wants to ‘play’ with multiple GPUs, everything that is needed is already there (here).
