Optimize NumPy extraction from multiband array
I'm relatively new to Python, so please correct me if "extraction" isn't the right terminology.
My problem: I'm working on image processing/analysis with NumPy. My code works fine, but it is very slow at high resolutions.
This is a simplified case of what I want to do:
import numpy as np
import matplotlib.pyplot as plt
# Create an example 6x6 image with 4 different bands.
band = np.arange(1, 37).reshape(6, 6)
img = np.copy(band)
for i in range(1, 4):
    img = np.dstack((img, band + i * 100))
# Now create the image segments, from which I'd like to extract pixel values.
segments = np.kron((np.arange(9) + 1).reshape(3, 3), np.ones((2, 2)))
NOTE: In this example there are 9 square segments, but in practice there would be several thousand, with various shapes and sizes (generated with the Quickshift method; a sketch of that step is shown below).
plt.imshow(segments)
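For reference, a rough sketch of how the segments could be generated in practice, assuming scikit-image's quickshift implementation (the random RGB image below is only a placeholder and is not part of this example):

import numpy as np
from skimage.segmentation import quickshift

# Placeholder RGB image; quickshift expects a color image.
rgb = np.random.rand(60, 60, 3)
real_segments = quickshift(rgb, kernel_size=3, max_dist=6, ratio=0.5)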
# Now my goal is, for each segment, to extract the pixel values (all bands of img) under that segment.
plt.imshow(segments==1)
# Displaying only the first two bands for clarity.
print(img[:, :, 0], img[:, :, 1], sep="\n\n")
# This is what it looks like for the first segment:
pixels = img[segments == 1]
pixels.T
array([[  1,   2,   7,   8],
       [101, 102, 107, 108],
       [201, 202, 207, 208],
       [301, 302, 307, 308]])
Now, this is the part I want to optimize. Here I'm using a loop to get the pixels "under" each segment:
segment_ids = np.unique(segments)
segment_pixels = []
for i in segment_ids:
    pixels = img[segments == i]
    segment_pixels.append(pixels)
But for very large images (> 2 GB) with lots of segments, this operation takes forever. Is there a way to speed this up? I've read a bit about NumPy vectorization, but I couldn't figure out how to apply it here. Does anyone know how I could improve performance?
Thank you !
Solution 1
Rewriting answer after clarification
You should have given us a minimal reproducible example, more like this:
import numpy as np
# 4M pixels with 1k segments
img = np.random.randint(0, 1000, (2000, 2000))
segments = np.random.randint(0, 1000, (2000, 2000))
Then get the pixels for each segment:
%%time
segment_pixels = []
for i in np.unique(segments):
    segment_pixels.append(img[segments == i])
It takes about 5 seconds: each pass of segments == i scans the full image, so the cost grows with the number of segments.
If you don't mind using pandas:
import pandas as pd
It will take more memory, but it provides a groupby function that is much faster than your approach when you have many groups and many records:
segment_pixels = pd.DataFrame({
    'segments': segments.reshape(-1),
    'pixels': img.reshape(-1)
}).groupby('segments').groups
This runs in about 330 ms, where the previous approach took ~5 s.
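Note that .groups maps each segment id to the flat indices of its pixels, not to the pixel values themselves. A sketch of how you could recover the values, plus a pure-NumPy alternative in the same spirit (one stable sort, then a split at the label boundaries; not benchmarked here):

import numpy as np

flat_seg = segments.reshape(-1)
flat_img = img.reshape(-1)

# Recover pixel values from the pandas groups (segment id -> flat indices)
segment_values = {k: flat_img[idx] for k, idx in segment_pixels.items()}

# Pure-NumPy alternative: sort the labels once, then split the pixel
# array wherever the label changes, avoiding one full scan per segment
order = np.argsort(flat_seg, kind='stable')
boundaries = np.flatnonzero(np.diff(flat_seg[order])) + 1
segment_pixels_np = np.split(flat_img[order], boundaries)

For the multiband case, flatten to (pixels, bands) with img.reshape(-1, img.shape[-1]) and index its rows the same way.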
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow

