Подтвердить что ты не робот

Numpy: Разница между точками (a, b) и (a * b).sum()

Для 1-D массивов numpy эти два выражения должны дать тот же результат (теоретически):

(a*b).sum()/a.sum()
dot(a, b)/a.sum()

Последний использует dot() и быстрее. Но какой из них более точен? Почему?

Ниже приведен контекст.

Я хотел вычислить взвешенную дисперсию образца с помощью numpy. Я нашел выражение dot() в другом ответе с комментарием о том, что он должен быть более точным. Однако никаких объяснений здесь нет.

4b9b3361

Ответ 1

Точка Numpy - это одна из подпрограмм, которая вызывает библиотеку BLAS, которую вы связываете при компиляции (или создает свои собственные). Важность этого заключается в том, что библиотека BLAS может использовать операции Multiply-accumulate (обычно Fused-Multiply Add), которые ограничивают количество округлений, выполняемых вычислением.

Возьмите следующее:

>>> a=np.ones(1000,dtype=np.float128)+1E-14 
>>> (a*a).sum()  
1000.0000000000199948
>>> np.dot(a,a)
1000.0000000000199948

Неточно, но достаточно близко.

>>> a=np.ones(1000,dtype=np.float64)+1E-14
>>> np.dot(a,a)
1000.0000000000176  #off by 2.3948e-12
>>> (a*a).sum()
1000.0000000000059  #off by 1.40948e-11

np.dot(a, a) будет более точным из двух, поскольку он использует приблизительно половину числа округлений с плавающей запятой, что наивный (a*a).sum() делает.

Книга Nvidia имеет следующий пример для 4 цифр точности. rn обозначает 4 раунда с точностью до 4 цифр:

x = 1.0008
x2 = 1.00160064                    #    true value
rn(x2 − 1) = 1.6006 × 10−4         #    fused multiply-add
rn(rn(x2) − 1) = 1.6000 × 10−4     #    multiply, then add

Конечно, числа с плавающей запятой не округлены до 16-го десятичного знака в базе 10, но вы получаете идею.

Размещение np.dot(a,a) в приведенных выше обозначениях с помощью некоторого дополнительного псевдокода:

out=0
for x in a:
    out=rn(x*x+out)   #Fused multiply add

Пока (a*a).sum():

arr=np.zeros(a.shape[0])   
for x in range(len(arr)):
    arr[x]=rn(a[x]*a[x])

out=0
for x in arr:
    out=rn(x+out)

Из этого легко видеть, что число округлено в два раза больше, используя (a*a).sum() по сравнению с np.dot(a,a). Эти небольшие различия, которые могут быть суммированы, могут изменить ответ. Дополнительные примеры можно найти здесь.