Image produced is incomplete - Cannot copy to a TensorFlowLite tensor (input_1) with bytes

I am trying to load a tflite model and run it on an image.

My tflite model has the dimensions shown in this image:

[image: tflite]

Right now, I am receiving:

Cannot copy to a TensorFlowLite tensor (input_1) with 49152 bytes from a Java Buffer with 175584 bytes.
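Both numbers in the error fit FLOAT32 tensors, which use 4 bytes per value. 1 * 64 * 64 * 3 * 4 = 49152 is the model's input, and 1 * 118 * 124 * 3 * 4 = 175584 is the 124 x 118 RGB image that was never resized. Here is a minimal sketch of that arithmetic. The helper name `tensorByteSize` is hypothetical.

```kotlin
// Bytes per element for a FLOAT32 tensor
val FLOAT32_BYTES = 4

// Hypothetical helper: byte size of a dense tensor is the
// product of its dimensions times the bytes per element
fun tensorByteSize(shape: IntArray, bytesPerElement: Int = FLOAT32_BYTES): Int =
    shape.fold(1) { acc, dim -> acc * dim } * bytesPerElement
```

`tensorByteSize(intArrayOf(1, 64, 64, 3))` gives 49152 and `tensorByteSize(intArrayOf(1, 118, 124, 3))` gives 175584. The sizes can only match once the image is resized to 64 x 64, or the input tensor is resized to the image.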

I don't understand how to work with the input and output tensor sizes. Right now, I initialize the input using the input image size, and the output image size is the input size * 4.

At which point do I have to "add" the 1 * 64 * 64 * 3 dimensions, given that I need to handle input images of any size?
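The 1 * 64 * 64 * 3 shape has to be satisfied before `run()` is called. You can either scale the image to 64 x 64 first, or call `resizeInput` on the interpreter. The second option only works if the model accepts dynamic input shapes.

The sketch below fills a FLOAT32 buffer in [batch, height, width, channels] order. It assumes two things:

* The pixels already come from a bitmap of the target size, for example via `Bitmap.createScaledBitmap` followed by `Bitmap.getPixels`.
* The model takes unnormalized 0..255 values.

`pixelsToFloatBuffer` is a hypothetical name.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hypothetical helper: packs ARGB_8888 pixels (row-major, as from Bitmap.getPixels)
// into a FLOAT32 buffer of shape [1, height, width, 3] with values in 0..255.
// The pixels must already match the model's input size, so resizing happens first.
fun pixelsToFloatBuffer(pixels: IntArray, width: Int, height: Int): ByteBuffer {
    require(pixels.size == width * height) {
        "expected ${width * height} pixels, got ${pixels.size}"
    }
    // 1 (batch) * height * width * 3 (channels) * 4 (bytes per float)
    val buffer = ByteBuffer.allocateDirect(1 * height * width * 3 * 4)
        .order(ByteOrder.nativeOrder())
    for (p in pixels) {
        buffer.putFloat(((p shr 16) and 0xFF).toFloat()) // R
        buffer.putFloat(((p shr 8) and 0xFF).toFloat())  // G
        buffer.putFloat((p and 0xFF).toFloat())          // B
    }
    buffer.rewind() // position back to 0 so the interpreter reads from the start
    return buffer
}
```

For a 64 x 64 image the buffer holds exactly 49152 bytes, which is the size the error message asks for.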

try {
    tflitemodel = loadModelFile()
    tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
    Log.e(TAG, "Fail to load model", e)
}

val imageTensorIndex = 0
val imageShape: IntArray =
    tflite.getInputTensor(imageTensorIndex).shape()
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()
// Build a TensorImage object
var inputImageBuffer = TensorImage(imageDataType)

// Load the Bitmap
inputImageBuffer.load(bitmap)

// Preprocess image
val imgprocessor = ImageProcessor.Builder()
    .add(ResizeOp(inputImageBuffer.height,
        inputImageBuffer.width,
        ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
    //.add(NormalizeOp(127.5f, 127.5f))
    //.add(QuantizeOp(128.0f, 1 / 128.0f))
    .build()

// Process the image
val processedImage = imgprocessor.process(inputImageBuffer)

// Access the buffer ( byte[] ) of the processedImage
val imageBuffer = processedImage.buffer
val imageTensorBuffer = processedImage.tensorBuffer

// output result
val outputImageBuffer = TensorBuffer.createFixedSize(
    intArrayOf(inputImageBuffer.height * 4,
        inputImageBuffer.width * 4),
    DataType.FLOAT32)

// Normalize image
val tensorProcessor = TensorProcessor.Builder()
    // Normalize the tensor given the mean and the standard deviation
    .add(NormalizeOp(127.5f, 127.5f))
    .add(CastOp(DataType.FLOAT32))
    .build()
val processedOutputTensor = tensorProcessor.process(outputImageBuffer)

tflite.run(imageTensorBuffer.buffer, processedOutputTensor.buffer)

I tried casting the output tensor to both FLOAT32 and UINT8.

UPDATE

I also tried this:

try {
    tflitemodel = loadModelFile()
    tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
    Log.e(TAG, "Fail to load model", e)
}

val imageTensorIndex = 0
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()

val imgprocessor = ImageProcessor.Builder()
    .add(ResizeOp(64, 64, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
    .add(NormalizeOp(0.0f, 255.0f))
    .add(CastOp(DataType.FLOAT32))
    .build()

val inpIm = TensorImage(imageDataType)
inpIm.load(bitmap)

val processedImage = imgprocessor.process(inpIm)

val output = TensorBuffer.createFixedSize(
    intArrayOf(
        124 * 4,
        118 * 4,
        3,
        1
    ),
    DataType.FLOAT32
)

val tensorProcessor = TensorProcessor.Builder()
    .add(NormalizeOp(0.0f, 255.0f))
    .add(CastOp(DataType.FLOAT32))
    .build()

val processedOutputTensor = tensorProcessor.process(output)

tflite.run(processedImage.buffer, processedOutputTensor.buffer)

which produces the following (incomplete) image:

[image]

Note that the image I am currently using as input has dimensions 124 * 118 * 3.

The output image should have dimensions (124 * 4) * (118 * 4) * 3.

The model's input layer expects 64 * 64 * 3.
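This sketch assumes the output keeps the input's [batch, height, width, channels] layout, with only height and width multiplied by the upscale factor of 4. Under that assumption, the output is [1, 4H, 4W, 3]. That differs from the [124 * 4, 118 * 4, 3, 1] order used in the update above. It also assumes the 124 x 118 image is 124 pixels wide and 118 tall. `Nhwc`, `inputShape` and `outputShape` are hypothetical helpers.

```kotlin
// Hypothetical helper type for a [batch, height, width, channels] shape
data class Nhwc(val batch: Int, val height: Int, val width: Int, val channels: Int) {
    fun toIntArray(): IntArray = intArrayOf(batch, height, width, channels)
}

// Input shape for a single RGB image of the given size
fun inputShape(height: Int, width: Int): Nhwc = Nhwc(1, height, width, 3)

// Only height and width grow by the upscale factor; batch and channels are unchanged
fun outputShape(input: Nhwc, upscale: Int = 4): Nhwc =
    Nhwc(input.batch, input.height * upscale, input.width * upscale, input.channels)
```

With these helpers, `outputShape(inputShape(118, 124))` is `Nhwc(1, 472, 496, 3)`. The fixed 64 x 64 input gives `[1, 256, 256, 3]`.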



Solution 1:[1]

I took a look at your project. Your class should look something like this:

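import android.app.Activity
import android.content.ContentValues
import android.content.Context
import android.content.Intent
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Color
import android.net.Uri
import android.os.Build
import android.os.Bundle
import android.os.Environment
import android.provider.MediaStore
import android.provider.OpenableColumns
import android.util.Log
import android.view.View
import android.widget.Button
import android.widget.ImageView
import android.widget.Switch
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import java.io.File
import java.io.FileDescriptor
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

// R and SelectedImage come from the asker's own project
class MainActivity : AppCompatActivity() {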


    private val TAG = "SuperResolution"
    private val MODEL_NAME = "model_edsr.tflite"
    private val LR_IMAGE_HEIGHT = 24
    private val LR_IMAGE_WIDTH = 24
    private val UPSCALE_FACTOR = 4
    private val SR_IMAGE_HEIGHT = LR_IMAGE_HEIGHT * UPSCALE_FACTOR
    private val SR_IMAGE_WIDTH = LR_IMAGE_WIDTH * UPSCALE_FACTOR

    private lateinit var photoButton: Button
    private lateinit var srButton: Button
    private lateinit var colorizeButton: Button
    private var FILE_NAME = "photo.jpg"

    private lateinit var filename:String
    private var resultImg: Bitmap? = null

    private lateinit var gpuSwitch: Switch

    private lateinit var tflite: Interpreter
    private lateinit var tflitemodel: ByteBuffer

    private val INPUT_SIZE: Int = 96
    private val PIXEL_SIZE: Int = 3
    private val IMAGE_MEAN = 0
    private val IMAGE_STD = 255.0f


    private var bitmap: Bitmap? = null
    private var bitmapResult: Bitmap? = null

    /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input/output  */
    private lateinit var imgDataInput: ByteBuffer
    private lateinit var imgDataOutput: ByteBuffer

    /** Dimensions of inputs.  */
    private val DIM_BATCH_SIZE = 1

    private val DIM_PIXEL_SIZE = 3

    private val DIM_IMG_SIZE_X = 64
    private val DIM_IMG_SIZE_Y = 64
    private lateinit var catBitmap: Bitmap
    /* Preallocated buffers for storing image data in. */
    private val intValues = IntArray(DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y)
    private lateinit var superImage: ImageView

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        superImage = findViewById(R.id.super_resolution_image)

        //val assetManager = assets
        catBitmap = getBitmapFromAsset("cat.png")


        srButton = findViewById(R.id.super_resolution)
        srButton.setOnClickListener { view: View ->

            val intent = Intent(this, SelectedImage::class.java)
            getImageResult.launch(intent)
        }


    }

    private fun getBitmapFromAsset(filePath: String?): Bitmap {
        val assetManager = assets
        val istr: InputStream
        var bitmap: Bitmap? = null
        try {
            istr = assetManager.open(filePath!!)
            bitmap = BitmapFactory.decodeStream(istr)
        } catch (e: IOException) {
            // handle exception
            Log.e("Bitmap_except", e.toString())

        }

        if (bitmap != null) {
            bitmap = Bitmap.createScaledBitmap(bitmap,64,64,true)
        }

        return bitmap?: Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
    }

    private val getImageResult =
        registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result ->
            if (result.resultCode == Activity.RESULT_OK) {
                var theImageUri: Uri? = null
                theImageUri = result.data?.getParcelableExtra<Uri>("imageuri")

                filename = "SR_" + theImageUri?.getOriginalFileName(this).toString()

                bitmap = uriToBitmap(theImageUri!!)!!//catBitmap//
                Log.v("width", bitmap!!.width.toString())

                if (bitmap != null) {
                    // call DL
                    val options = Interpreter.Options()
                    options.setNumThreads(5)
                    options.setUseNNAPI(true)
                    try {
                        tflitemodel = loadModelFile()
                        tflite = Interpreter(tflitemodel, options)
                        val index = tflite.getInputIndex("input_1")
                        // Image tensors are NHWC: [batch, height, width, channels]
                        tflite.resizeInput(
                            index,
                            intArrayOf(1, bitmap!!.height, bitmap!!.width, 3)
                        )
                    } catch (e: IOException) {
                        Log.e(TAG, "Fail to load model", e)
                    }

                    // ResizeOp takes (targetHeight, targetWidth, method)
                    val imgprocessor = ImageProcessor.Builder()
                        .add(
                            ResizeOp(bitmap!!.height,
                                bitmap!!.width,
                                ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)
                        )
                        .add( CastOp( DataType.FLOAT32 ) )
                        .build()

                    val inpIm = TensorImage(DataType.FLOAT32)
                    inpIm.load(bitmap)

                    // Process the image
                    val processedImage = imgprocessor.process(inpIm)

                    // Output array: [1][4 * height][4 * width][3]
                    val output2 = Array(1) { Array(4*bitmap!!.height) { Array(4*bitmap!!.width) { FloatArray(3) } } }

                    tflite.run(processedImage.buffer, output2)

                    // convertArrayToBitmap takes (imageWidth, imageHeight)
                    bitmapResult = convertArrayToBitmap(output2, 4*bitmap!!.width, 4*bitmap!!.height)

                    Log.v("widthHR", bitmapResult!!.height.toString())
                    superImage.setImageBitmap(bitmapResult)

                }
            }
        }


    @Throws(IOException::class)
    private fun loadModelFile(): MappedByteBuffer {
        val fileDescriptor = assets.openFd(MODEL_NAME)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }


    private fun uriToBitmap(selectedFileUri: Uri): Bitmap? {
        try {
            val parcelFileDescriptor = contentResolver.openFileDescriptor(selectedFileUri, "r")
            val fileDescriptor: FileDescriptor = parcelFileDescriptor!!.fileDescriptor
            val image = BitmapFactory.decodeFileDescriptor(fileDescriptor)
            parcelFileDescriptor.close()
            return image
        } catch (e: IOException) {
            e.printStackTrace()
        }
        return null
    }

    private fun getOutputImage(output: ByteBuffer): Bitmap? {
        output.rewind()
        val outputWidth = 124 * 4
        val outputHeight = 118 * 4
        val bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888)
        val pixels = IntArray(outputWidth * outputHeight)
        for (i in 0 until outputWidth * outputHeight) {
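            val a = 0xFF
            // Assumes outputs in [0, 1]; clamp so each channel stays in 0..255
            val r = (output.float * 255.0f).toInt().coerceIn(0, 255)
            val g = (output.float * 255.0f).toInt().coerceIn(0, 255)
            val b = (output.float * 255.0f).toInt().coerceIn(0, 255)
            pixels[i] = (a shl 24) or (r shl 16) or (g shl 8) or b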
        }
        bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight)
        return bitmap
    }

    // save bitmap image to gallery
    private fun saveToGallery(context: Context, bitmap: Bitmap, albumName: String) {
        //val filename = "${System.currentTimeMillis()}.png"
        val write: (OutputStream) -> Boolean = {
            bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
            val contentValues = ContentValues().apply {
                put(MediaStore.MediaColumns.DISPLAY_NAME, filename)
                put(MediaStore.MediaColumns.MIME_TYPE, "image/png")
                put(MediaStore.MediaColumns.RELATIVE_PATH, "${Environment.DIRECTORY_DCIM}/$albumName")
            }

            context.contentResolver.let {
                it.insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, contentValues)?.let { uri ->
                    it.openOutputStream(uri)?.let(write)
                }
            }
        } else {
            val imagesDir = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DCIM).toString() + File.separator + albumName
            val file = File(imagesDir)
            if (!file.exists()) {
                file.mkdir()
            }
            val image = File(imagesDir, filename)
            write(FileOutputStream(image))
        }
    }

    // get the filename from an image uri
    private fun Uri.getOriginalFileName(context: Context): String? {
        return context.contentResolver.query(this,
            null,
            null,
            null,
            null)?.use {
            val nameColumnIndex = it.getColumnIndex(OpenableColumns.DISPLAY_NAME)
            it.moveToFirst()
            it.getString(nameColumnIndex)
        }
    }
    fun convertArrayToBitmap(
        imageArray: Array<Array<Array<FloatArray>>>,
        imageWidth: Int,
        imageHeight: Int
    ): Bitmap {

        val conf = Bitmap.Config.ARGB_8888 // see other conf types
        val bitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)

        // imageArray is [1][rows][cols][3]: x indexes rows, y indexes columns
        for (x in imageArray[0].indices) {
            for (y in imageArray[0][0].indices) {
                // Clamp the model output to 0..255 before packing it into a color
                val color = Color.rgb(
                    imageArray[0][x][y][0].toInt().coerceIn(0, 255),
                    imageArray[0][x][y][1].toInt().coerceIn(0, 255),
                    imageArray[0][x][y][2].toInt().coerceIn(0, 255)
                )

                // setPixel takes (column, row), hence (y, x)
                bitmap.setPixel(y, x, color)
            }
        }
        return bitmap
    }

}

Look at how the code resizes the model's inputs on Android, creates the input buffer and output array, and converts the resulting array to a Bitmap. For these steps, check whether you can use the phone's GPU to get roughly a 3x speedup. There is plenty more to read in the official documentation.
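The array-to-Bitmap step can also be written without Android types, which makes it easy to test. This sketch packs a [1][height][width][3] float array into row-major ARGB_8888 ints. It clamps each channel to 0..255 and assumes the model outputs values in roughly that range, since the input above is only cast and not normalized. The function name `outputToArgb` is hypothetical.

```kotlin
// Hypothetical helper: [1][height][width][3] floats -> row-major ARGB_8888 ints
fun outputToArgb(output: Array<Array<Array<FloatArray>>>): IntArray {
    val rows = output[0]
    val height = rows.size
    val width = rows[0].size
    val pixels = IntArray(width * height)
    for (row in 0 until height) {
        for (col in 0 until width) {
            val rgb = rows[row][col]
            // Clamp so out-of-range values cannot spill into neighbouring color bytes
            val r = rgb[0].toInt().coerceIn(0, 255)
            val g = rgb[1].toInt().coerceIn(0, 255)
            val b = rgb[2].toInt().coerceIn(0, 255)
            // Opaque alpha, then R, G, B (same bit layout as Color.rgb)
            pixels[row * width + col] = (0xFF shl 24) or (r shl 16) or (g shl 8) or b
        }
    }
    return pixels
}
```

On Android, one call writes the result: `bitmap.setPixels(pixels, 0, width, 0, 0, width, height)`. This avoids calling `setPixel` once per pixel.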

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1