Подтвердить что ты не робот

Каков наиболее эффективный способ найти позицию первого значения np.nan?

рассмотрим массив a

a = np.array([3, 3, np.nan, 3, 3, np.nan])

Я мог бы сделать

np.isnan(a).argmax()

Но для этого нужно найти все np.nan, чтобы найти первое.
Есть ли более эффективный способ?


Я пытался выяснить, могу ли я передать параметр np.argpartition, чтобы np.nan сначала сортировался, а не последним.


EDIT относительно [dup].
Есть несколько причин, по которым этот вопрос отличается.

  • Этот вопрос и ответы касались равенства ценностей. Это относится к isnan.
  • Ответы на эти ответы страдают от той же самой проблемы, с которой я столкнулся. Заметьте, я дал совершенно верный ответ, но подчеркнул его неэффективность. Я ищу, чтобы исправить неэффективность.

EDIT относительно второго [dup].

Все еще обращаясь к равенству, а вопрос/ответы старые и, возможно, устаревшие.

4b9b3361

Ответ 1

Я назначу

a.argmax()

С @fuglede's тестовым массивом:

In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999

In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop

In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop

У меня нет numba, поэтому можно сравнить это. Но мое ускорение относительно short больше, чем @fuglede's 6x.

Я тестирую в Py3, который принимает <np.nan, а Py2 вызывает предупреждение во время выполнения. Но поиск кода предполагает, что это не зависит от этого сравнения.

/numpy/core/src/multiarray/calculation.c PyArray_ArgMax играет с осями (перемещая интерес к концу) и делегирует действие arg_func = PyArray_DESCR(ap)->f->argmax, функции, которая зависит от dtype.

В numpy/core/src/multiarray/arraytypes.c.src он выглядит как BOOL_argmax коротких замыканий, возвращающихся, как только он встречает True.

for (; i < n; i++) {
    if (ip[i]) {
        *max_ind = i;
        return 0;
    }
}

И @[email protected]_argmax также короткие замыкания на максимальном nan. np.nan также является "максимальным" в argmin.

#if @[email protected]
    if (@[email protected](mp)) {
        /* nan encountered; it maximal */
        return 0;
    }
#endif

Комментарии от опытных кодеров c приветствуются, но мне кажется, что по крайней мере для np.nan простая argmax будет такой же быстрой, как вы можете получить.

Воспроизведение с 9999 при генерации a показывает, что время a.argmax зависит от этого значения в соответствии с коротким замыканием.

Ответ 2

Возможно, стоит посмотреть в numba.jit; без него векторизованная версия, скорее всего, пойдет по прямолинейному поиску чистого Python в большинстве сценариев, но после компиляции кода обычный поиск займет лидирующую позицию, по крайней мере, в моем тестировании:

In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

In [70]: %paste
import numba

def naive(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

def short(a):
        return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

@numba.jit
def short_jit(a):
        return np.isnan(a).argmax()
## -- End pasted text --

In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop

In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop

In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop

In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop

Изменить: Как отметил @hpaulj в своем ответе, numpy действительно поставляется с оптимизированным короткозамкнутым поиском, производительность которого сопоставима с поиском JITted выше:

In [26]: %paste
def plain(a):
        return a.argmax()

@numba.jit
def plain_jit(a):
        return a.argmax()
## -- End pasted text --

In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop

In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop

In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop

In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop

Ответ 3

Вот пифонический подход с использованием itertools.takewhile():

from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))

Ориентир с использованием метода generator_expression_within_ next: 1

In [118]: a = np.repeat(a, 10000)

In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop

In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop

Но все же (безусловно) медленнее, чем numpy подход:

In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop

<суб > 1. Проблема с этим подходом заключается в использовании функции enumerate. Возвращает объект enumerate из массива numpy first (который является объектом итератора) и вызывает функцию генератора, а атрибут next итератора займет время. Суб >

Ответ 4

При поиске первого совпадения в различных сценариях мы можем выполнить итерацию и искать первое совпадение и выйти из первого совпадения, а не переходить/обрабатывать весь массив. Итак, у нас был бы подход с использованием Python next function, например:

next((i for i, val in enumerate(a) if np.isnan(val)))

Примеры прогона -

In [192]: a = np.array([3, 3, np.nan, 3, 3, np.nan])

In [193]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[193]: 2

In [194]: a[2] = 10

In [195]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[195]: 5