NumPy: change max in each row to 1, all other numbers to 0
I'm trying to implement a numpy function that replaces the max in each row of a 2D array with 1, and all other numbers with zero:
>>> a = np.array([[0, 1],
...               [2, 3],
...               [4, 5],
...               [6, 7],
...               [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
[1. 0.]]
What I've tried so far
def some_function(x):
    a = np.zeros(x.shape)
    a[:,np.argmax(x, axis=1)] = 1
    return a
>>> b = some_function(a)
>>> b
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
Solution 1:[1]
I prefer using numpy.where like so:
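# compare each row with its own max; keepdims=True keeps shape (n, 1) so it broadcasts
a = np.where(a == np.max(a, axis=1, keepdims=True), 1, 0)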
Solution 2:[2]
I know this question is pretty old, but I think I have a decent solution that's a bit different from the other solutions.
a==np.max(a) will raise an error in the future, so here's a tweaked version that will continue to broadcast correctly.
# get max by row and convert from (n, ) -> (n, 1) which will broadcast
row_maxes = a.max(axis=1).reshape(-1, 1)
# either of these returns the 0/1 array
np.where(a == row_maxes, 1, 0)
(a == row_maxes).astype(int)
If the update needs to be in place, you can do:
a[:] = np.where(a == row_maxes, 1, 0)
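The broadcasting approach above can be wrapped into a small function and checked against the array from the question. This is only a sketch: the helper name `row_max_mask` is made up for illustration, and it uses `keepdims=True` in place of the `reshape(-1, 1)` step, which should give the same `(n, 1)` shape.

```python
import numpy as np

def row_max_mask(x):
    # hypothetical helper, not part of NumPy
    # per-row maxima with shape (n, 1), so the comparison broadcasts across each row
    row_maxes = x.max(axis=1, keepdims=True)
    # True where an element equals its row's max, converted to 1.0 / 0.0
    return (x == row_maxes).astype(float)

a = np.array([[0, 1],
              [2, 3],
              [4, 5],
              [6, 7],
              [9, 8]])
b = row_max_mask(a)
print(b)
```

Note that if a row contains its maximum more than once, every tied position is set to 1 with this approach.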
Solution 3:[3]
b = (a == np.max(a, axis=1, keepdims=True)).astype(int)
That worked for me.
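For comparison, the attempt in the question can also be made to work while keeping `argmax`. The original line `a[:, np.argmax(x, axis=1)] = 1` selects all rows for each listed column index, so whole columns are set to 1. Pairing each row index with its own column index avoids that. The following is a minimal sketch; the function name `one_hot_row_argmax` is an assumption for illustration.

```python
import numpy as np

def one_hot_row_argmax(x):
    # hypothetical helper, not part of NumPy
    out = np.zeros(x.shape)
    rows = np.arange(x.shape[0])      # 0, 1, ..., n-1
    cols = np.argmax(x, axis=1)       # column of the first max in each row
    out[rows, cols] = 1               # pairs (rows[i], cols[i]), one element per row
    return out

a = np.array([[0, 1],
              [2, 3],
              [4, 5],
              [6, 7],
              [9, 8]])
b = one_hot_row_argmax(a)
print(b)
```

Unlike the comparison-based solutions, this marks exactly one element per row: when a row has ties, `argmax` picks the first occurrence.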
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Cyclone |
| Solution 2 | Alex Riina |
| Solution 3 | Mushfirat Mohaimin |
