Подтвердить что ты не робот

Numpy объединить отсортированный массив в новый массив?

Есть ли способ сделать что-то вроде merge в mergesort, используя функцию numpy?

некоторая функция, например merge:

a = np.array([1,3,5])
b = np.array([2,4,6])
c = merge(a, b) # c == np.array([1,2,3,4,5,6])

Мне хотелось бы получить высокую производительность для больших данных благодаря numpy

4b9b3361

Ответ 1

Вы можете использовать

c = concatenate((a,b))
c.sort(kind='mergesort')

Я боюсь, что вы не сможете сделать это лучше, если только вы не напишите свою собственную функцию сортировки как расширение python, а cython.

См. этот вопрос для аналогичной проблемы, но сохраняя только уникальные значения в объединенном массиве. Оценки и комментарии там также проницательны.

Ответ 2

Когда один массив значительно больше другого, приличное ускорение (в 5 раз на моем компьютере) можно получить, выполнив команду np.searchorted, скорость которой ограничена, главным образом, путем поиска индексов вставки меньшего массива:

import numpy as np

def classic_merge(a, b):
    c = np.concatenate((a,b))
    c.sort(kind='mergesort')
    return c

def new_merge(a, b):
    if len(a) < len(b):
        b, a = a, b
    c = np.empty(len(a) + len(b), dtype=a.dtype)
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    a_indices = np.ones(len(c), dtype=bool)
    a_indices[b_indices] = False
    c[b_indices] = b
    c[a_indices] = a
    return c

Сроки дает:

from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # size difference of a factor 10 here makes the difference!
    a = np.arange(size // 10, dtype=np.int)
    b = np.arange(size, dtype=np.int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

if True:
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))

log10(size)     Classic         New            
2.00000         0.00009         0.00027
3.00000         0.00021         0.00030
4.00000         0.00233         0.00082
5.00000         0.02827         0.00601
6.00000         0.33322         0.06059
7.00000         4.40571         0.86764

Когда a и b примерно одинаковы, различия в длине меньше:

from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # same size
    a = np.arange(size , dtype=np.int)
    b = np.arange(size, dtype=np.int)
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

if True:
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))

log10(size)     Classic         New            
2.00000         0.00026         0.00087
3.00000         0.00108         0.00182
4.00000         0.01257         0.01243
5.00000         0.16333         0.12692
6.00000         1.05006         0.49186
7.00000         8.35967         5.93732