Подтвердить что ты не робот

Быстрая, бездисковая беззнаковая абсолютная разность

У меня есть программа, которая проводит большую часть своего времени, вычисляя евклидово расстояние между значениями RGB (3 кортежа беззнакового 8-битного Word8). Мне нужна быстрая, безветвленная функция без абсолютной разности без знака так, чтобы

unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b

особенно,

unsigned_difference ab == unsigned_difference ba

Я придумал следующее, используя новые примулы из GHC 7.8:

-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))]

который ghc -O2 -S компилируется в

.Lc42U:
    movq 7(%rbx),%rax
    movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
    movq 8(%rbp),%rbx
    movq %rbx,%rcx
    subq %rax,%rcx
    cmpq %rax,%rbx
    setg %dl
    movzbl %dl,%edx
    imulq %rcx,%rdx
    movq %rax,%rcx
    subq %rbx,%rcx
    cmpq %rax,%rbx
    setl %al
    movzbl %al,%eax
    imulq %rcx,%rax
    addq %rdx,%rax
    movq %rax,(%r12)
    leaq -7(%r12),%rbx
    addq $16,%rbp
    jmp *(%rbp)

компиляция с ghc -O2 -fllvm -optlo -O3 -S производит следующий asm:

.LBB6_1:
    movq    7(%rbx), %rsi
    movq    $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
    movq    8(%rbp), %rcx
    movq    %rsi, %rdx
    subq    %rcx, %rdx
    xorl    %edi, %edi
    subq    %rsi, %rcx
    cmovleq %rdi, %rcx
    cmovgeq %rdi, %rdx
    addq    %rcx, %rdx
    movq    %rdx, 16(%rax)
    movq    16(%rbp), %rax
    addq    $16, %rbp
    leaq    -7(%r12), %rbx
    jmpq    *%rax  # TAILCALL

Таким образом, LLVM удается заменить сравнения (более эффективными?) Инструкциями условного перемещения. К сожалению, компиляция с -fllvm мало влияет на время выполнения моей программы.

Однако есть две проблемы с этой функцией.

  • Я хочу сравнить Word8, но сравнение primops требует использования Int. Это вызывает ненужное распределение, так как я вынужден хранить 64-битный Int а не Word8.

Я профилировал и подтвердил, что использование fromIntegral :: Word8 → Int отвечает за 42,4 процента от общего объема ассигнований программы.

  • Моя версия использует 2 сравнения, 2 умножения и 2 вычитания. Интересно, есть ли более эффективный метод, использующий побитовые операции или SIMD-инструкции и использующий тот факт, что я сравниваю Word8?

Ранее я пометил вопрос C/C++ чтобы привлечь внимание тех, кто более склонен к битовой манипуляции. В моем вопросе используется Haskell, но я бы принял ответ, реализующий правильный метод на любом языке.

Заключение:

Я решил использовать

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
    where diff = fromIntegral a - fromIntegral b
          mask = unsafeShiftR diff 15

так как он быстрее, чем моя оригинальная функция unsigned_difference, и прост в реализации. Принадлежности SIMD в Haskell еще не достигли зрелости. Поэтому, пока SIMD-версии работают быстрее, я решил использовать скалярную версию.

4b9b3361

Ответ 1

Хорошо, я попытался немного сравнить. Я использую Criterion для тестов, потому что он делает правильные тесты значимости. Я также использую QuickCheck здесь, чтобы гарантировать, что все методы возвращают те же результаты.

Я скомпилировал с GHC 7.6.3 (поэтому, к сожалению, я не мог включить вашу функцию primops) и -O3:

ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff

В первую очередь мы можем видеть разницу между наивной реализацией и небольшим количеством fiddel:

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
  where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        mask = unsafeShiftR v 63

Вывод:

benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....

benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...

Я использую тэг абсолютное целочисленное значение от "Бит Twiddling Hacks здесь". К сожалению, нам нужны касты, я не думаю, что можно хорошо решить проблему в домене Word8, но разумно использовать нативный целочисленный тип в любом случае (определенно нет необходимости создавать объект кучи, хотя).

На самом деле это не очень большая разница, но моя тестовая установка также не идеальна: я сопоставляю функцию над большим списком случайных значений, чтобы исключить предсказание ветки, что делает ветвящуюся версию более эффективной, чем она, Это заставляет громко наращивать память, что может сильно повлиять на тайминги. Когда мы вычитаем постоянные накладные расходы для поддержания списка, мы могли бы увидеть намного больше, чем 20% -ное ускорение.

Сгенерированная сборка на самом деле довольно хороша (это встроенная версия функции):

.Lc4BB:
    leaq 7(%rbx),%rax
    movq 8(%rbp),%rbx
    subq (%rax),%rbx
    movq %rbx,%rax
    sarq $63,%rax
    movq $base_GHCziInt_I64zh_con_info,-8(%r12)
    addq %rax,%rbx
    xorq %rax,%rbx
    movq %rbx,0(%r12)
    leaq -7(%r12),%rbx
    movq $s4z0_info,8(%rbp)

1 вычитание, 1 сложение, 1 сдвиг вправо, 1 xor и отсутствие ветки, как и ожидалось. Использование LLVM-бэкэнд не улучшает время выполнения.

Надеюсь, это полезно, если вы хотите попробовать больше вещей.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce

import Test.QuickCheck hiding ((.&.))
import Criterion.Main

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b

absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b

absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b

absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where v = a - b
        mask = unsafeShiftR v 15

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
  where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        !mask = unsafeShiftR v 63

absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a

{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
    {-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}

e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum

prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
    where x' = e2e x
          y' = e2e y

check = quickCheck prop_same1
     >> quickCheck prop_same2

instance (Random x, Random y) => Random (x, y) where
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x,y),gen3)

main =
    do check
       !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
       let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
       defaultMain
         [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                  , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                  , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                  ]
         , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                  , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                  ]
         {-, bgroup "absdiff_Int"   [ bench "1" $ whnf (absdiff1_int 13) 14-}
                                  {-, bench "2" $ whnf (absdiff3_int 13) 14-}
                                  {-]-}
         ]

Ответ 2

Если вы ориентируетесь на систему с инструкциями SSE, вы можете использовать ее для повышения производительности. Я тестировал это против других опубликованных методов, и это, по-видимому, самый быстрый подход.

Примеры результатов для различного большого количества значений:

diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms  // branchless mul add
diff3: 54.495269 ms  // branchless signed
diff4: 31.159628 ms  // sse
diff5: 30.855885 ms  // sse v2

Мой полный тестовый код ниже. Я использовал инструкции SSE2, которые в наши дни широко доступны в x86ish-процессорах, через встроенные функции SSE, которые должны быть довольно портативными (MSVC, GCC, Clang, компиляторы Intel и т.д.).

Примечания:

  • Эффективно это вычисляет max, затем min, а затем вычитает, но делает 16 значений сразу с каждой инструкцией.
  • Развернуть его в diff5 кажется малоэффективным, но, возможно, можно настроить.
  • Откат для последних 15 или менее значений в настоящее время использует подписанный метод трюка в цикле, но его можно было бы ускорить с помощью разворачивания и/или SSE.
  • Функции сами по себе довольно просты, поэтому они должны быть легко переносимы на все, что связано с SSE intrinsics или asm.
  • Я использовал специальные функции синхронизации Windows, потому что std::chrono::high_resolution_clock имеет низкую точность в реализации MSVC, извините за это, и за грязное сочетание тестового кода C/С++.
  • После определения времени выполнения результаты проверяются против реализации репликации ссылок, поэтому они должны быть правильными.

Пожалуйста, оставьте комментарий, если у вас есть какие-либо вопросы/предложения относительно кода или этого подхода в целом.

#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>

#include <emmintrin.h> // sse2

// branching
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
    }
}

// max min
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
    }
}

// branchless mul add
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
    }
}

// branchless signed
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    for (std::size_t j = n / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse v2
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    const std::size_t UNROLL = 2;
    for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
        __m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        __m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        _mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
        _mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
    }
    for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

int main() {
    const std::size_t ALIGN = 16; // sse requires 16 bit align
    const std::size_t N = 10 * 1024 * 1024 * 3;

    auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
    auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));

    { // fill with random values
        std::mt19937 engine(std::random_device{}());
        std::uniform_int<std::uint8_t> distribution(0, 255);
        for (std::size_t i = 0; i < N; i++) {
            a[i] = distribution(engine);
            b[i] = distribution(engine);
        }
    }

    auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
    auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results

    LARGE_INTEGER f, t0, t1;
    QueryPerformanceFrequency(&f);

    QueryPerformanceCounter(&t0);
    diff0(a, b, res0, N);
    QueryPerformanceCounter(&t1);
    printf("diff0: %.6f ms\n",
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);

#define TEST(diffX)\
    QueryPerformanceCounter(&t0);\
    diffX(a, b, resX, N);\
    QueryPerformanceCounter(&t1);\
    printf("%s: %.6f ms\n", #diffX,\
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
    for (std::size_t i = 0; i < N; i++) {\
        if (resX[i] != res0[i]) {\
            printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
                a[i], b[i], resX[i], res0[i]);\
            break;\
        }\
    }

    TEST(diff1);
    TEST(diff2);
    TEST(diff3);
    TEST(diff4);
    TEST(diff5);

    _mm_free(a);
    _mm_free(b);
    _mm_free(res0);
    _mm_free(resX);

    getc(stdin);
    return 0;
}

Ответ 3

Изменить: меняя мой ответ, у меня были неверные настройки для оптимизации.

Я установил быструю тестовую кровать для этого в C, и я обнаружил, что

a - b + (a < b) * ((b - a) << 1);

волосы лучше, по крайней мере, в моей настройке. Преимущество моего подхода - исключить сравнение. Ваша версия неявно обрабатывает a - b == 0 как отдельный случай, если это не требуется.

Мой тест с твоей принимает

  • Ваша реализация: 371ms
  • Эта реализация: 324ms
  • Ускорение: 14%

Я пробовал подход с не ветвящимся абсолютным значением, и результаты были лучше. Обратите внимание, что входы или выходы считаются подписанными или нет компилятором, не имеет значения. Он перемещается вокруг больших значений без знака, но поскольку он должен работать только с небольшими значениями (как указано в вопросе), он должен быть достаточным.

s32 diff = a - b;
u32 mask = diff >> 31;
return (diff + mask) ^ mask;
  • Ваша реализация: 371ms
  • Эта реализация: 241ms
  • Ускорение: 53%!