Vectorizing numpy.random.multinomial

I am trying to vectorize the following code:

import numpy as np
a = np.empty(s.shape[0], dtype=int)
for i in range(s.shape[0]):
    a[i] = np.argmax(np.random.multinomial(1, s[i, :]))

s.shape = (400, 100) [given].

a.shape = (400,) [expected].

s is a 2D matrix containing the probabilities of pairs, where each row is a probability vector. The goal is to draw one multinomial sample (a single trial) from each row of s and store the index of the selected entry in the vector a.
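To run the loop end to end, here is a minimal setup sketch. The matrix below is hypothetical random test data of the stated shape (400 x 100), with each row normalized so it is a valid probability vector; the seed and the use of `np.random.default_rng` are assumptions for reproducibility only.

```python
import numpy as np

# Hypothetical test data: 400 rows, each a probability vector over 100 entries.
rng = np.random.default_rng(42)
s = rng.random((400, 100))
s /= s.sum(axis=1, keepdims=True)  # normalize each row so it sums to 1

# The per-row loop: one single-trial multinomial draw per row,
# then argmax turns the one-hot outcome into the index of the selected entry.
a = np.empty(s.shape[0], dtype=int)
for i in range(s.shape[0]):
    a[i] = np.argmax(np.random.multinomial(1, s[i, :]))

print(a.shape)  # (400,)
```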



Solution 1:[1]

In the comments it is said that there is an attempt at vectorizing this here. However, it is not only an attempt: it is a complete solution to this question too.

The goal of the question is to obtain the index of the position containing the 1 in the outcome of a single multinomial trial. That is, the realization [0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0] yields 14. Thus, it is equivalent to executing:

np.random.choice(np.arange(len(p)),p=p) # here, p is s[i,:]
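The equivalence between `argmax(multinomial(1, p))` and `choice(..., p=p)` can be checked on a degenerate distribution, sketched here with a hypothetical `p` that puts all its mass on index 14 of 20 entries. Both calls then become deterministic, and the one-hot draw reproduces the example realization above.

```python
import numpy as np

# Hypothetical degenerate distribution: all probability mass on index 14.
p = np.zeros(20)
p[14] = 1.0

draw = np.random.multinomial(1, p)       # one-hot vector of length 20
idx_multinomial = int(np.argmax(draw))   # position of the single 1
idx_choice = int(np.random.choice(np.arange(len(p)), p=p))

print(draw)                          # one-hot: a single 1 at index 14
print(idx_multinomial, idx_choice)   # 14 14
```

For a non-degenerate `p` the two calls return random indices, but both follow the same distribution over `range(len(p))`.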

Therefore, Warren Weckesser's solution to Fast random weighted selection across all rows of a stochastic matrix is also a solution to this question. The only difference is whether the probability vectors are stored in rows or in columns. This is easily handled either by transposing s and passing it as prob_matrix, or by defining a custom version of vectorized that works with the row layout of s:

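import numpy as np

def vectorized(prob_matrix, items):
    # Row-wise CDF; r[:, None] has shape (nrows, 1) so each row gets its own draw
    cdf = prob_matrix.cumsum(axis=1)
    r = np.random.rand(prob_matrix.shape[0])[:, None]
    # Count CDF entries below r; clip in case rounding leaves cdf[:, -1] < r
    k = np.minimum((cdf < r).sum(axis=1), prob_matrix.shape[1] - 1)
    return items[k]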
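The counting trick is inverse-CDF sampling: the number of CDF entries below a uniform draw r is the index of the sampled entry. A small worked example makes this concrete. The helper names `sample_rows` and `sample_cols` are hypothetical, and `r` is passed in explicitly (instead of drawn inside the function) only so the output is deterministic. `sample_cols` is a column-oriented sketch assumed to match the stochastic-matrix layout of the linked question, not a copy of that answer; it shows that transposing the matrix gives the same result. Note that in the row version r must be reshaped to (nrows, 1): a 1-D r broadcasts against the trailing (column) axis instead.

```python
import numpy as np

def sample_rows(prob_matrix, r):
    """Rows are distributions; r holds one uniform in [0, 1) per row."""
    cdf = prob_matrix.cumsum(axis=1)
    # r[:, None] has shape (nrows, 1), so each row is compared with its own r
    k = (cdf < r[:, None]).sum(axis=1)
    # Guard against the last CDF entry rounding to slightly below r
    return np.minimum(k, prob_matrix.shape[1] - 1)

def sample_cols(prob_matrix, r):
    """Columns are distributions; r holds one uniform per column."""
    cdf = prob_matrix.cumsum(axis=0)
    # r has shape (ncols,) and broadcasts along the trailing (column) axis
    k = (cdf < r).sum(axis=0)
    return np.minimum(k, prob_matrix.shape[0] - 1)

P = np.array([[0.2, 0.5, 0.3],
              [0.6, 0.1, 0.3]])
r = np.array([0.65, 0.05])   # fixed "uniform" draws, one per row of P

print(P.cumsum(axis=1))      # row CDFs, approximately [0.2, 0.7, 1.0] and [0.6, 0.7, 1.0]
print(sample_rows(P, r))     # [1 0]
print(sample_cols(P.T, r))   # [1 0]
```

In row 0, only 0.2 is below 0.65, so index 1 is selected; in row 1, no CDF entry is below 0.05, so index 0 is selected.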

In these timings, with s of size 400x400, the speedup is around a factor of 10:

%%timeit
a = np.empty(400)
for i in range(s.shape[0]):
    a[i] = np.argmax(np.random.multinomial(1,s[i,:]))
# 5.96 ms ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit 
vals = np.arange(400,dtype=int)
vectorized(s,vals)
# 544 µs ± 5.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
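Timings say nothing about correctness, so a quick sanity check is to compare the empirical frequencies of the vectorized sampler against the target probabilities and against the per-row multinomial loop. This is a sketch under stated assumptions: `vectorized_rows` is a hypothetical variant that takes an explicit `np.random.Generator`, `p` is an arbitrary example distribution repeated on every row, and the sample sizes and seed are chosen only to keep the check fast and reproducible.

```python
import numpy as np

def vectorized_rows(prob_matrix, rng):
    # Same row-wise inverse-CDF idea, with an explicit Generator for reproducibility.
    cdf = prob_matrix.cumsum(axis=1)
    r = rng.random(prob_matrix.shape[0])[:, None]   # one uniform per row
    k = (cdf < r).sum(axis=1)
    return np.minimum(k, prob_matrix.shape[1] - 1)

rng = np.random.default_rng(0)
p = np.array([0.1, 0.2, 0.3, 0.4])      # hypothetical example distribution
P = np.tile(p, (100_000, 1))             # every row is the same distribution p

k_vec = vectorized_rows(P, rng)
# Reference: per-row single-trial multinomial + argmax, on a smaller subset.
k_loop = np.array([np.argmax(rng.multinomial(1, row)) for row in P[:10_000]])

freq_vec = np.bincount(k_vec, minlength=len(p)) / len(k_vec)
freq_loop = np.bincount(k_loop, minlength=len(p)) / len(k_loop)
print(np.round(freq_vec, 3))
print(np.round(freq_loop, 3))
```

Both printed frequency vectors should be close to `p`; a large mismatch usually points to a broadcasting mistake in the comparison step.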

Solution 2:[2]

How about

[np.argmax(np.random.multinomial(1, s[i, :])) for i in range(s.shape[0])]

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution sources:
[1] Solution 1
[2] Solution 2: Antimony