Displaying a prediction in TensorFlow Lite on Android

I am trying to process an image through a TFLite model on my Android tablet in real time; however, the output is not correct. My application contains a SurfaceView that runs in the foreground and displays the portion of the screen directly "behind" it. Each frame's bitmap is then shrunk to 128 x 128 pixels and passed to my model, whose input shape is 1 x 128 x 128 x 3.
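A 1 x 128 x 128 x 3 input is a single image stored row by row in NHWC order (TensorFlow Lite tensors are row-major), so every pixel contributes three consecutive values. As an illustration, this sketch computes where element [0][y][x][c] lands in the flat buffer; `NhwcIndex` and `flatIndex` are hypothetical names, not part of any TensorFlow Lite API.

```java
/**
 * Illustrative sketch (hypothetical helper): flat offset of element [0][y][x][c]
 * in a row-major 1 x H x W x C tensor, e.g. a 1 x 128 x 128 x 3 model input.
 */
final class NhwcIndex {
    static int flatIndex(int y, int x, int c, int width, int channels) {
        // The batch index is always 0, so it adds nothing to the offset
        return (y * width + x) * channels + c;
    }

    public static void main(String[] args) {
        int width = 128, channels = 3;
        System.out.println(flatIndex(0, 0, 2, width, channels));     // 2: blue of the first pixel
        System.out.println(flatIndex(0, 1, 0, width, channels));     // 3: red of the second pixel
        System.out.println(flatIndex(127, 127, 2, width, channels)); // 49151: last element
    }
}
```

The last index, 49151, means the tensor holds 49152 floats, which at 4 bytes each is the 196608-byte flat size reported by the TensorBuffer.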

I already had issues passing the bitmap's ByteBuffer to the model because the array lengths did not match: the TensorBuffer requires a flat size of 128 x 128 x 3 x 4 = 196608 bytes (three 4-byte FLOAT32 channels per pixel), while the bitmap's buffer is only 128 x 128 x 4 = 65536 bytes (one 4-byte ARGB_8888 value per pixel). As a quick hack, I allocated a ByteBuffer three times as large and copied the pixels into it but, as expected, everything from byte 65536 onward is zero. The result is shown in the image below.

[Image: screenshot of the incorrect model output]
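The mismatch is more than a length problem: `copyPixelsToBuffer` writes each pixel as four raw bytes (one packed ARGB_8888 value), whereas a FLOAT32 input expects three 4-byte floats per pixel, so even the first 65536 bytes are misread by the model. A minimal sketch of the usual conversion follows, assuming the model was trained on RGB values scaled to [0, 1] (check the model's actual preprocessing); `PixelPacker` and `argbToFloatRgb` are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * Sketch (hypothetical helper): pack ARGB_8888 pixels, as returned by
 * Bitmap.getPixels, into the float32 RGB layout of a 1 x H x W x 3 input.
 * Assumption: the model expects channels in [0, 1]; change NORMALIZE
 * (or the arithmetic) if it was trained on [0, 255] or [-1, 1].
 */
final class PixelPacker {
    static final float NORMALIZE = 255.0f;

    static ByteBuffer argbToFloatRgb(int[] pixels, int width, int height) {
        // 3 channels x 4 bytes per float for every pixel: 128 * 128 * 3 * 4 = 196608
        ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 3 * 4);
        buffer.order(ByteOrder.nativeOrder());
        for (int i = 0; i < width * height; i++) {
            int p = pixels[i];
            // Drop alpha and extract R, G, B from the packed 0xAARRGGBB int
            buffer.putFloat(((p >> 16) & 0xFF) / NORMALIZE);
            buffer.putFloat(((p >> 8) & 0xFF) / NORMALIZE);
            buffer.putFloat((p & 0xFF) / NORMALIZE);
        }
        buffer.rewind();
        return buffer;
    }

    public static void main(String[] args) {
        int[] pixels = {0xFFFF0000, 0xFF00FF00}; // one red pixel, one green pixel
        ByteBuffer b = argbToFloatRgb(pixels, 2, 1);
        System.out.println(b.capacity());  // 24
        System.out.println(b.getFloat(0)); // 1.0 (red of pixel 0)
        System.out.println(b.getFloat(16)); // 1.0 (green of pixel 1)
    }
}
```

On Android, the `pixels` array could be filled with `cropImg.getPixels(pixels, 0, 128, 0, 0, 128, 128)` and the returned buffer passed to `inputFeature0.loadBuffer(...)` in place of `byteBuffer`. A direct buffer in native byte order is the form TensorFlow Lite typically works with for input data.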

I am not sure where the issue truly lies.

public void onImageAvailable(ImageReader imageReader) {
    final long startTime = System.currentTimeMillis();
    Image image = null;
    Bitmap bitmap = null;
    Canvas canvas = this.window.getScreenShot().getHolder().lockCanvas();
    Paint paint = new Paint();

    try {
      image = imageReader.acquireLatestImage();
      if (image != null) {
        Image.Plane[] planes = image.getPlanes();
        ByteBuffer buffer = planes[0].getBuffer();
        int pixelStride = planes[0].getPixelStride();
        int rowStride = planes[0].getRowStride();
        int rowPadding = rowStride - pixelStride * config.width;

        // create bitmap
        bitmap = Bitmap.createBitmap(config.width + rowPadding / pixelStride,
                config.height, Bitmap.Config.ARGB_8888);
        bitmap.copyPixelsFromBuffer(buffer);

        //create new bitmap of area where the surfaceview resides
        int[] point = new int[2];
        this.window.getScreenShot().getLocationOnScreen(point);
        Point surfaceCoords = new Point();
        surfaceCoords.set(point[0],point[1]);
        int surfaceWidth = this.window.getScreenShot().getMeasuredWidth();
        int surfaceHeight = this.window.getScreenShot().getMeasuredHeight();
        Bitmap cropImg = Bitmap.createBitmap(bitmap, surfaceCoords.x, surfaceCoords.y, surfaceWidth, surfaceHeight);

        //scale bitmap down to 128 by 128 pixels
        cropImg = getResizedBitmap(cropImg, 128, 128);
        int size = cropImg.getRowBytes() * cropImg.getHeight() * 3;
        ByteBuffer byteBuffer = ByteBuffer.allocate(size);
        cropImg.copyPixelsToBuffer(byteBuffer);

        //INFERENCE
        TensorBuffer outputFeature0 = null;
        try {
          Model model = Model.newInstance(this.ctxt);

          // Creates inputs for reference.
          TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 128, 128, 3}, DataType.FLOAT32);
          inputFeature0.loadBuffer(byteBuffer);

          // Runs model inference and gets result.
          Model.Outputs outputs = model.process(inputFeature0);
          outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();

          // Releases model resources if no longer used.
          model.close();
        } catch (IOException e) {
          // TODO Handle the exception
        }

        //Convert output byte array to bitmap and scale it back up
        Bitmap.Config configBmp = Bitmap.Config.valueOf(cropImg.getConfig().name());
        Bitmap bitmap_tmp = Bitmap.createBitmap(cropImg.getWidth(), cropImg.getHeight(), configBmp);
        ByteBuffer new_bbf = ByteBuffer.wrap(outputFeature0.getBuffer().array());
        new_bbf.rewind();
        bitmap_tmp.copyPixelsFromBuffer(new_bbf);
        Bitmap scaledUp = getResizedBitmap(bitmap_tmp, this.window.getScreenShot().getMeasuredWidth(), this.window.getScreenShot().getMeasuredHeight());
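        canvas.drawBitmap(scaledUp, 0, 0, paint);
      }
      final long endTime = System.currentTimeMillis();
      // Math.max avoids a division by zero when a frame takes less than 1 ms
      Log.i("FPS", "Frame Rate: " + 1000 / Math.max(1, endTime - startTime) + " FPS");
    } catch (Exception e) {
      e.printStackTrace();
    } finally {
      // Release the Image and post the locked canvas on every path,
      // including when acquireLatestImage() returned null
      if (image != null) {
        image.close();
      }
      if (canvas != null) {
        this.window.getScreenShot().getHolder().unlockCanvasAndPost(canvas);
      }
    }
}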
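The output side has the mirror-image problem: `bitmap_tmp.copyPixelsFromBuffer(new_bbf)` reinterprets each 4-byte float as one ARGB pixel. Assuming the model returns a 1 x 128 x 128 x 3 FLOAT32 image with values in [0, 1] (the output shape is not stated in the question), a sketch of converting it back to ARGB ints might look like this; `PixelUnpacker` and its methods are hypothetical names.

```java
/**
 * Sketch (hypothetical helper): convert a float32 RGB output (1 x H x W x 3)
 * into opaque ARGB_8888 ints suitable for Bitmap.createBitmap(int[], w, h, Config).
 * Assumption: the model outputs values in [0, 1]; anything outside is clamped.
 */
final class PixelUnpacker {
    static int toByte(float v) {
        // Scale [0, 1] to [0, 255], rounding and clamping out-of-range predictions
        int c = Math.round(v * 255.0f);
        return Math.max(0, Math.min(255, c));
    }

    static int[] floatRgbToArgb(float[] rgb, int width, int height) {
        int[] pixels = new int[width * height];
        for (int i = 0; i < pixels.length; i++) {
            int r = toByte(rgb[3 * i]);
            int g = toByte(rgb[3 * i + 1]);
            int b = toByte(rgb[3 * i + 2]);
            pixels[i] = 0xFF000000 | (r << 16) | (g << 8) | b; // fully opaque
        }
        return pixels;
    }

    public static void main(String[] args) {
        float[] rgb = {1f, 0f, 0f, 0f, 0.5f, 1.2f};
        int[] argb = floatRgbToArgb(rgb, 2, 1);
        System.out.printf("%08X %08X%n", argb[0], argb[1]); // FFFF0000 FF0080FF
    }
}
```

On Android, the float array could come from `outputFeature0.getFloatArray()`, and `Bitmap.createBitmap(pixels, 128, 128, Bitmap.Config.ARGB_8888)` would turn the result into a bitmap before it is scaled up. Separately, the listing calls `Model.newInstance(...)` on every frame; creating the model once and closing it when the view is torn down is likely to help the frame rate.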


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
