Optimize NumPy extraction from multiband array

I'm relatively new to Python, so please correct me if "extraction" isn't the right terminology.

My problem: I'm working on image processing/analysis with NumPy. My code works fine, but it is very slow at high resolutions.

This is a simplified case of what I want to do:

import numpy as np
import matplotlib.pyplot as plt  

# Create an example 6x6 image with 4 different bands.
band = np.arange(1, 37).reshape(6, 6)
img = np.copy(band)
for i in range(1, 4):
    img = np.dstack((img, band + i * 100))
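
As an aside (my addition, not from the original post), the same 4-band cube can be built in a single call, which avoids the repeated reallocation np.dstack performs inside the loop:

# Equivalent construction: stack all four bands at once along a new last axis.
img = np.stack([band + i * 100 for i in range(4)], axis=-1)  # shape (6, 6, 4)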
    
# Now create the image segments from which I'd like to extract the pixel values.
segments = np.kron((np.arange(9) + 1).reshape(3, 3), np.ones((2, 2)))

NOTE: In this example there are 9 square segments, but in practice there would be several thousand, with various shapes and sizes (generated with the Quickshift method).
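
For context (a sketch I'm adding, not part of the question): with scikit-image, a segment map like this would typically be produced along the following lines; the parameter values are illustrative only.

import numpy as np
from skimage.segmentation import quickshift

# Stand-in for three image bands rescaled to [0, 1]; quickshift expects
# an RGB-like float image.
rgb = np.random.rand(100, 100, 3)
segments = quickshift(rgb, kernel_size=5, max_dist=10, ratio=0.5)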

plt.imshow(segments)

# My goal is, for each segment, to extract the pixel values (all bands of img) under that segment.
plt.imshow(segments == 1)
# Displaying only the first two bands for clarity.
print(img[:, :, 0], img[:, :, 1], sep="\n\n")

This is what I mean by extracting pixel values:

# This is what it looks like for the first segment:
pixels = img[segments == 1]
pixels.T  # one row per band, one column per pixel

array([[  1,   2,   7,   8],
       [101, 102, 107, 108],
       [201, 202, 207, 208],
       [301, 302, 307, 308]])

Now, this is the part I want to optimize. Here I'm using a loop to get the pixels "under" each segment:

segment_ids = np.unique(segments)
segment_pixels = []

for i in segment_ids:
    pixels = img[segments == i]
    segment_pixels.append(pixels)

But for very large images (> 2 GB) with lots of segments, this operation takes forever. Is there a way to speed it up? I've read a bit about NumPy vectorization, but I couldn't figure out how to apply it here. Does anyone know how I could improve performance?

Thank you !



Solution 1:[1]

Rewriting answer after clarification

You should have given us a minimal reproducible example, more like this:

import numpy as np
# 4M pixels with 1k segments
img = np.random.randint(0, 1000, (2000, 2000))
segments = np.random.randint(0, 1000, (2000, 2000))

Then collect the pixels for each segment:

%%time
segment_pixels = []
for i in np.unique(segments):
    segment_pixels.append(img[segments == i])

This takes about 5 seconds.

If you don't mind using pandas:

import pandas as pd

It will take more memory, but it provides a groupby function that is much faster than your approach when you have many groups and many records:

segment_pixels = pd.DataFrame({
    'segments': segments.reshape(-1),
    'pixels': img.reshape(-1)
}).groupby('segments').groups
# .groups maps each segment id to the flat (row) positions of its pixels,
# not the pixel values themselves.

This runs in about 330 ms, whereas the previous approach takes about 5 s.
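
If you need the actual pixel values per segment and want to avoid the pandas dependency, here is a pure-NumPy sketch of the same sort-then-slice idea (my addition, not part of the original answer; it assumes the img and segments arrays defined above):

import numpy as np

# Sort the flat segment labels once, carrying the pixel values along.
order = np.argsort(segments.reshape(-1), kind='stable')
sorted_seg = segments.reshape(-1)[order]
sorted_pix = img.reshape(-1)[order]

# The first occurrence of each label marks a segment boundary.
ids, starts = np.unique(sorted_seg, return_index=True)
segment_pixels = np.split(sorted_pix, starts[1:])  # one 1-D array per segment

This replaces the per-segment boolean mask (one full pass over the image per segment) with a single global sort, which scales much better with thousands of segments. For the multiband image in the question, flatten with img.reshape(-1, img.shape[-1]) instead, so each row keeps all bands.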

Sources

[1] Stack Overflow, licensed under CC BY-SA 3.0, in accordance with Stack Overflow's attribution requirements.
