Подтвердить что ты не робот

Argmax каждой строки или столбца в scipy разреженной матрице

scipy.sparse.coo_matrix.max возвращает максимальное значение каждой строки или столбца с учетом оси. Я хотел бы знать не значение, а индекс максимального значения каждой строки или столбца. Я еще не нашел способ сделать это эффективным образом, поэтому я с радостью принимаю любую помощь.

4b9b3361

Ответ 1

Из scipy версии 0.19, как csr_matrix, так и csc_matrix поддерживают методы argmax() и argmin().

Ответ 2

Я бы предложил изучить код для

moo._min_or_max_axis

где moo - coo_matrix.

mat = mat.tocsc()  # for axis=0
mat.sum_duplicates()

major_index, value = mat._minor_reduce(min_or_max)
not_full = np.diff(mat.indptr)[major_index] < N
value[not_full] = min_or_max(value[not_full], 0)

mask = value != 0
major_index = np.compress(mask, major_index)
value = np.compress(mask, value)
return coo_matrix((value, (np.zeros(len(value)), major_index)),
                      dtype=self.dtype, shape=(1, M))

В зависимости от оси он предпочитает работать с csc over csr. У меня не было времени анализировать это, но я предполагаю, что в расчет должно быть включено argmax.


Это предложение может не работать. Ключ - это метод mat._minor_reduce, который делает с некоторой доработкой:

ufunc.reduceat(mat.data, mat.indptr[:-1])

То есть применяется ufunc к блокам массива data, используя indptr для определения блоков. np.sum, np.maxiumum ufunc, где это работает. Я не знаю эквивалента argmax ufunc.

В общем случае, если вы хотите делать что-то "row" для матрицы csr (или col csc), вам нужно либо перебирать строки, что относительно дорого, либо использовать этот ufunc.reduceat для того же вещь над плоским mat.data вектором.

группа argmax/argmin по индексам разбиения на numpy пытается выполнить a argmax.reduceat. Решение там может быть адаптировано к разреженной матрице.

Ответ 3

Если A является вашим scipy.sparse.coo_matrix, вы получите строку и столбец максимального значения следующим образом:

I=A.data.argmax()
maxrow = A.row[I]
maxcol=A.col[I]

Чтобы получить индекс максимального значения в каждой строке, см. EDIT ниже:

from scipy.sparse import coo_matrix
import numpy as np
row  = np.array([0, 3, 1, 0])
col  = np.array([0, 2, 3, 2])
data = np.array([-3, 4, 11, -7])
A= coo_matrix((data, (row, col)), shape=(4, 4))
print A.toarray()

nrRows=A.shape[0]
maxrowind=[]
for i in range(nrRows):
    r = A.getrow(i)# r is 1xA.shape[1] matrix
    maxrowind.append( r.indices[r.data.argmax()] if r.nnz else 0)
print maxrowind 

r.nnz - это счетчик явно сохраненных значений (то есть ненулевых значений)

Ответ 4

Последняя версия пакета numpy_indexed (отказ от ответственности: я его автор) может эффективно и элегантно решить эту проблему:

import numpy_indexed as npi
col, argmax = group_by(coo.col).argmax(coo.data)
row = coo.row[argmax]

Здесь мы группируем col, поэтому его argmax над столбцами; swapping row и col даст вам argmax над строками.

Ответ 5

Расширяя ответы от @hpaulj и @joeln и используя код из группы argmax/argmin по индексам секционирования в numpy, как предлагается, эта функция будет вычислять argmax по столбцам для CSR или argmax над строками для CSC:

import numpy as np
import scipy.sparse as sp

def csr_csc_argmax(X, axis=None):
    is_csr = isinstance(X, sp.csr_matrix)
    is_csc = isinstance(X, sp.csc_matrix)
    assert( is_csr or is_csc )
    assert( not axis or (is_csr and axis==1) or (is_csc and axis==0) )

    major_size = X.shape[0 if is_csr else 1]
    major_lengths = np.diff(X.indptr) # group_lengths
    major_not_empty = (major_lengths > 0)

    result = -np.ones(shape=(major_size,), dtype=X.indices.dtype)
    split_at = X.indptr[:-1][major_not_empty]
    maxima = np.zeros((major_size,), dtype=X.dtype)
    maxima[major_not_empty] = np.maximum.reduceat(X.data, split_at)
    all_argmax = np.flatnonzero(np.repeat(maxima, major_lengths) == X.data)
    result[major_not_empty] = X.indices[all_argmax[np.searchsorted(all_argmax, split_at)]]
    return result

Он возвращает -1 для argmax любых строк (CSR) или столбцов (CSC), которые являются полностью разреженными (т.е. полностью равными нулю после X.eliminate_zeros()).