Строчная индексация в Numpy - программирование

Строчная индексация в Numpy

У меня две матрицы, A и B:

A = array([[2., 13., 25., 1.], [ 18., 5., 1., 25.]])
B = array([[2, 1], [0, 3]])

Я хочу индексировать каждую строку A с каждой строкой B, создавая срез:

array([[25., 13.], [18., 25.]])

То есть, я действительно хочу что-то вроде:

array([A[i,b] for i,b in enumerate(B)])

Есть ли способ представить это прямо? Лучшее, что я могу сделать, это "плоский взлом":

A.flat[B + arange(0,A.size,A.shape[1])[:,None]]
4b9b3361

Ответ 1

@Ответ на Ophion велик и заслуживает внимания, но я хотел добавить некоторые объяснения и предложить более интуитивную конструкцию.

Вместо того, чтобы вращать B и затем возвращать результат назад, лучше просто повернуть arange. Я думаю, что это дает наиболее интуитивное решение, даже если оно принимает больше символов:

A[((0,),(1,)), B]

или эквивалентно

A[np.arange(2)[:, None], B]

Это работает, потому что вы действительно делаете массив i и массив j, каждый из которых имеет ту же форму, что и ваш желаемый результат.

i = np.array([[0, 0],
              [1, 1]])
j = B

Но вы можете использовать только

i = np.array([[0],
              [1]])

Потому что он будет транслироваться в соответствии с B (это то, что дает np.arange(2)[:,None]).

Наконец, чтобы сделать его более общим (не зная 2 как размер arange), вы также можете сгенерировать i из B с помощью

i = np.indices(B.shape)[0]

однако вы создаете i и j, вы просто называете это как

>>> A[i, j]
array([[ 25.,  13.],
       [ 18.,  25.]])

Ответ 2

Не очень, но:

A[np.arange(2),B.T].T
array([[ 25.,  13.],
       [ 18.,  25.]])