Выполнение чего-то типа
import numpy as np
a = np.random.rand(10**4, 10**4)
b = np.dot(a, a)
использует несколько ядер, и он работает хорошо.
Элементы в a
, однако, являются 64-битными поплавками (или 32-разрядными в 32-разрядных платформах?), и я хотел бы умножить 8-разрядные целые массивы. Однако попробуйте следующее:
a = np.random.randint(2, size=(n, n)).astype(np.int8)
приводит к точечному продукту, не использующему несколько ядер, и, таким образом, на моем ПК медленнее на 1000 раз.
array: np.random.randint(2, size=shape).astype(dtype)
dtype shape %time (average)
float32 (2000, 2000) 62.5 ms
float32 (3000, 3000) 219 ms
float32 (4000, 4000) 328 ms
float32 (10000, 10000) 4.09 s
int8 (2000, 2000) 13 seconds
int8 (3000, 3000) 3min 26s
int8 (4000, 4000) 12min 20s
int8 (10000, 10000) It didn't finish in 6 hours
float16 (2000, 2000) 2min 25s
float16 (3000, 3000) Not tested
float16 (4000, 4000) Not tested
float16 (10000, 10000) Not tested
Я понимаю, что NumPy использует BLAS, который не поддерживает целые числа, но если я использую обертки SciPy BLAS, т.е.
import scipy.linalg.blas as blas
a = np.random.randint(2, size=(n, n)).astype(np.int8)
b = blas.sgemm(alpha=1.0, a=a, b=a)
вычисление многопоточное. Теперь blas.sgemm
работает с точно таким же временем, как np.dot
для float32, но для не-float он преобразует все в float32
и выводит поплавки, чего нет np.dot
. (Кроме того, b
теперь находится в F_CONTIGUOUS
порядке, что является меньшей проблемой).
Итак, если я хочу сделать умножение целых матриц, я должен выполнить одно из следующих действий:
- Используйте NumPy больно медленно
np.dot
и радуйся, что я могу сохранить 8-битные целые числа. - Используйте SciPy
sgemm
и используйте 4-кратную память. - Используйте Numpy
np.float16
и используйте только память 2x, с предостережением, чтоnp.dot
намного медленнее на массивах float16, чем на массивах float32, больше, чем int8. - Найти оптимизированную библиотеку для многопоточного умножения целых матриц (на самом деле, Mathematica делает это, но я бы предпочел решение Python), идеально поддерживая 1-битные массивы, хотя 8-битные массивы тоже отлично... (я на самом деле намерен сделать умножение матриц над конечным полем Z/2Z, и я знаю, что могу сделать это с помощью Sage, что довольно Pythonic, но, опять же, есть ли что-то строго Python?)
Могу ли я выполнить опцию 4? Существует ли такая библиотека?
Отказ от ответственности: я фактически запускаю NumPy + MKL, но я пробовал аналогичный тест на vanilly NumPy с аналогичными результатами.