Подтвердить что ты не робот

Каков правильный способ выполнения вложенных циклов с постоянным пространством в Haskell?

Существует два очевидных "идиоматических" способа выполнения вложенных циклов в Haskell: использование монады списка или использование forM_ для замены традиционного fors. Я установил контрольный показатель, чтобы определить, скомпилированы ли они для жестких циклов:

import Control.Monad.Loop
import Control.Monad.Primitive
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector.Unboxed as V

times = 100000
side  = 100

-- Using `forM_` to replace traditional fors
test_a mvec = 
    forM_ [0..times-1] $ \ n -> do
        forM_ [0..side-1] $ \ y -> do
            forM_ [0..side-1] $ \ x -> do
                MV.write mvec (y*side+x) 1

-- Using the list monad to replace traditional forms
test_b mvec = sequence_ $ do
    n <- [0..times-1]
    y <- [0..side-1]
    x <- [0..side-1]
    return $ MV.write mvec (y*side+x) 1

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    -- test_a mvec
    -- test_b mvec
    vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int)
    print $ V.sum vec'

Этот тест создает вектор 100x100, записывает 1 в каждый индекс с помощью вложенного цикла и повторяет это 100k раз. Компилируя те, которые имеют только ghc -O2 test.hs -o test (версия ghc 7.8.4), результаты: 3.853s для версии forM_ и 10.460s для list monad. Чтобы предоставить ссылку, я также запрограммировал этот тест на JavaScript:

var side  = 100;
var times = 100000;
var vec   = [];

for (var i=0; i<side*side; ++i)
    vec.push(0);

for (var n=0; n<times; ++n)
    for (var y=0; y<side; ++y)
        for (var x=0; x<side; ++x)
            vec[x+y*side] = 1;

var s = 0;
for (var i=0; i<side*side; ++i)
    s += vec[i];

console.log(s);

Эта эквивалентная программа JavaScript принимает 1s, что приводит к избиению нераспознанных векторов Haskell, что является необычным, предполагая, что Haskell не запускает цикл в постоянном пространстве, а вместо этого выполняет выделение. Затем я нашел библиотеку, которая утверждает, что предоставляет жесткие циклы, гарантированные типом Control.Monad.Loop:

-- Using `for` from Control.Monad.Loop
test_c mvec = exec_ $ do
    n <- for 0 (< times) (+ 1)
    x <- for 0 (< side) (+ 1)
    y <- for 0 (< side) (+ 1)
    liftIO (MV.write mvec (y*side+x) 1)

Что работает в 1s. Эта библиотека не очень используется и далека от идиоматической, хотя, так что каков идиоматический способ быстрого вычисления двумерных вычислений с постоянным пространством? (Обратите внимание, что это не относится к REPA, поскольку я хочу для выполнения произвольных операций ввода-вывода в сетке.)

4b9b3361

Ответ 1

Написание жесткого мутирующего кода с GHC может быть сложным иногда. Я собираюсь написать о нескольких разных вещах, возможно, в более бессвязном и т.д., Чем я бы предпочел.

Во-первых, мы должны использовать GHC 7.10 в любом случае, поскольку иначе forM_ и список решений монады никогда не сливаются.

Кроме того, я заменил MV.write на MV.unsafeWrite, отчасти потому, что он быстрее, но, что более важно, он уменьшает часть беспорядка в результирующем ядре. С этого момента статистика времени выполнения относится к коду с unsafeWrite.

Страшный пусть плавающий

Даже с GHC 7.10 мы должны сначала заметить все те выражения [0..times-1] и [0..side-1], потому что они будут разрушать производительность каждый раз, если мы не предпримем необходимые шаги. Проблема в том, что они являются постоянными диапазонами, а -ffull-laziness (который включен по умолчанию на -O) выгружает их на верхний уровень. Это предотвращает слияние списков, а итерация по диапазону Int# дешевле, чем итерация по списку вложенных пакетов Int -s, так что это очень плохая оптимизация.

Посмотрите некоторые промежутки времени в секундах для неизмененных (кроме использования unsafeWrite) кода. ghc -O2 -fllvm, и я использую +RTS -s для синхронизации.

test_a: 1.6
test_b: 6.2
test_c: 0.6

Для просмотра GHC Core я использовал ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures.

В случае test_a диапазоны [0..99] сняты:

main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.

хотя самый внешний цикл [0..9999] слит в хвостовой рекурсивный помощник:

letrec {
          a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a3_s7xL =
            \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
              case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
              case x_X5zl of wild_X1S {
                __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                99999 -> (# ipv2_a4NA, () #)
              }
              }; }

В случае test_b снова снимается только [0..99]. Тем не менее, test_b работает намного медленнее, потому что ему необходимо собрать и упорядочить фактические списки [IO ()]. По крайней мере, GHC достаточно разумен, чтобы создать только один [IO ()] для двух внутренних циклов, а затем выполнить последовательность 10000 раз.

 let {
          lvl7_s4M5 :: [IO ()]
          lvl7_s4M5 = -- omitted
        letrec {
          a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Av =
            \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Au
                  :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Au =
                  \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
                    case ds_a4Nu of _ {
                      [] ->
                        case x_a5xi of wild1_X1y {
                          __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
                          99999 -> (# eta1_X1c, () #)
                        };
                      : y_a4Nz ys_a4NA ->
                        case (y_a4Nz `cast` ...) eta1_X1c
                        of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
                        a3_s7Au ys_a4NA ipv2_a4Nf
                        }
                    }; } in
              a3_s7Au lvl7_s4M5 eta_B1; } in
-- omitted

Как мы можем исправить это? Мы можем устранить проблему с помощью {-# OPTIONS_GHC -fno-full-laziness #-}. Это действительно помогает в нашем случае:

test_a: 0.5
test_b: 0.48
test_c: 0.5

В качестве альтернативы мы могли бы заниматься с помощью INLINE прагм. Очевидно, что функции включения после того, как разрешено плавание, сохраняет хорошую производительность. Я обнаружил, что GHC строит наши тестовые функции даже без прагмы, но явная прагма заставляет его встраиваться только после того, как плавающие. Например, это приводит к хорошей производительности без -fno-full-laziness:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE test_a #-}

Но вложение слишком рано приводит к низкой производительности:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"

Проблема с этим решением INLINE заключается в том, что он довольно хрупкий перед лицом плавающего натиска GHC. Например, ручная вставка не сохраняет производительность. Следующий код медленный, так как аналогично INLINE [~2] он дает GHC возможность выплыть:

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1    

Итак, что нам делать?

Во-первых, я думаю, что использование -fno-full-laziness является вполне жизнеспособным и даже предпочтительным вариантом для тех, кто хотел бы написать высокопроизводительный код и иметь хорошую идею, что они делают. Например, он используется в unordered-containers. Благодаря этому мы имеем более точный контроль над совместным использованием, и мы всегда можем просто выплыть или встроить вручную.

Для более регулярного кода я считаю, что нет ничего плохого в использовании Control.Monad.Loop или любого другого пакета, который предоставляет функциональные возможности. Многие пользователи Haskell не скрупулезны в зависимости от небольших библиотек "бахромы". Мы также можем просто переопределить for в желаемой общности. Например, следующее выполняется так же хорошо, как и другие решения:

for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
for init while step body = go init where
  go !i | while i = body i >> go (step i)
  go i = return ()
{-# INLINE for #-}

Цикл в действительно постоянном пространстве

Сначала я был озадачен данными +RTS -s о распределении кучи. test_a выделено нетривиально с помощью -fno-full-laziness, а также test_c без полной лень, и эти распределения масштабируются линейно с числом итераций times, но test_b с полной леностью, выделенной только для вектора:

-- with -fno-full-laziness, no INLINE pragmas
test_a: 242,521,008 bytes
test_b: 121,008 bytes
test_c: 121,008 bytes -- but 240,120,984 with full laziness!

Кроме того, INLINE прагмы для test_c вообще не помогли.

Я потратил некоторое время, пытаясь найти признаки выделения кучи в ядре для соответствующих программ без успеха, пока не исполнилось меня: кадры стека GHC находятся в куче, включая кадры основного потока, и функции которые занимались распределением кучи, в основном выполняли трижды вложенные циклы в большинстве трех фреймов стека. Выделение кучи, зарегистрированное +RTS -s, представляет собой просто постоянное выскакивание и толкание кадров стека.

Это довольно очевидно из Core для следующего кода:

{-# OPTIONS_GHC -fno-full-laziness #-}

-- ...

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_a mvec

Который я включаю здесь в его славу. Не стесняйтесь пропустить.

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5HK :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vr) of _ {
      False ->
        case newByteArray# 80000 (s_a5HK `cast` ...)
        of _ { (# ipv_a5fv, ipv1_a5fw #) ->
        letrec {
          $s$wa_s8jS
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8jS =
            \ (sc_s8jO :: Int#)
              (sc1_s8jP :: Int#)
              (sc2_s8jR :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jP 10000) of _ {
                False -> (# sc2_s8jR, I# sc_s8jO #);
                True ->
                  case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                  of s'#_a5Gn { __DEFAULT ->
                  $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                  }
              }; } in
        case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
        -- end of vector creation -------------------

        of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
        letrec {
          a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7MJ =
            \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7ME =
                  \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                    case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                    case writeIntArray#
                           (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                    of s'#_a5Gn { __DEFAULT ->
                    letrec {
                      a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7Mz =
                        \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5fw `cast` ...)
                                 (+# (*# x1_X5Id 100) x2_X5J8)
                                 1
                                 (eta2_X1U `cast` ...)
                          of s'#1_X5Hf { __DEFAULT ->
                          case x2_X5J8 of wild_X2o {
                            __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                            99 -> (# s'#1_X5Hf `cast` ..., () #)
                          }
                          }; } in
                    case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                    of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                    case x1_X5Id of wild_X1e {
                      __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                      99 -> (# ipv2_a4QH, () #)
                    }
                    }
                    }
                    }; } in
              case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
              case x_a5Ho of wild_X1a {
                __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                99999 -> (# ipv2_a4QH, () #)
              }
              }; } in
        a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wm, ww6_a5wn #) ->
                   : ww5_a5wm ww6_a5wn
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

Мы также можем продемонстрировать распределение кадров следующим образом. Пусть изменение test_a:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-50] $ \ x -> -- change here
                MV.unsafeWrite mvec (y*side+x) 1

Теперь распределение кучи остается точно таким же, потому что самый внутренний цикл является хвостовым рекурсивным и использует один кадр. При следующем изменении, распределение кучи половину (до 124,921,008 байт), потому что мы нажимаем и набираем пополам столько кадров:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1

test_b и test_c (без полной лени) вместо этого компилируются в код, который использует вложенную конструкцию case внутри одного стекового фрейма и просматривает индексы, чтобы увидеть, какой из них нужно увеличивать. Смотрите Core для следующего main:

{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}

main = do
    let vec = V.generate (side*side) (const 0)
    !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_c mvec

Вуаля:

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5Iw :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vT) of _ {
      False ->
        case newByteArray# 80000 (s_a5Iw `cast` ...)
        of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
        letrec {
          $s$wa_s8ji
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8ji =
            \ (sc_s8je :: Int#)
              (sc1_s8jf :: Int#)
              (sc2_s8jh :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jf 10000) of _ {
                False -> (# sc2_s8jh, I# sc_s8je #);
                True ->
                  case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                  of s'#_a5GP { __DEFAULT ->
                  $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                  }
              }; } in
        case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
        of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
        case ipv7_a4MY of _ { I# dt4_a5xy ->
        -- end of vector creation

        letrec {
          a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Q6 =
            \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Q5 =
                  \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                    letrec {
                      a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7MZ =
                        \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5g4 `cast` ...)
                                 (+# (*# x1_X5J9 100) x2_X5Jl)
                                 1
                                 (s1_X4Xb `cast` ...)
                          of s'#_a5GP { __DEFAULT ->

                          -- the interesting part! ------------------
                          case x2_X5Jl of wild_X1y {
                            __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                            99 ->
                              case x1_X5J9 of wild1_X1o {
                                __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x_a5HT of wild2_X1c {
                                    __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                    99999 -> (# s'#_a5GP `cast` ..., () #)
                                  }
                              }
                          }
                          }; } in
                    a4_s7MZ 0 eta1_XP; } in
              a3_s7Q5 0 eta_B1; } in
        a2_s7Q6 0 (ipv6_a4MX `cast` ...)
        }
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wO, ww6_a5wP #) ->
                   : ww5_a5wO ww6_a5wP
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

Я должен признать, что я в основном не знаю, почему какой-то код избегает создания фреймов стека, а некоторые нет. Я подозреваю, что поддержка от "внутри" помогает, и быстрый осмотр сообщил мне, что Control.Monad.Loop использует CPS-кодировку, которая может быть уместна здесь, хотя решение Monad.Loop чувствительно, чтобы позволить плавать, и я не мог определить в короткое время из Core, почему test_c, если плавающее не работает в одном стеке стека.

Теперь преимущество производительности при работе в одном стеке меньше. Мы видели, что test_b только немного быстрее, чем test_a. Я включаю этот обход в ответ, потому что нашел его поучительным.

Хранение состояния и строгие привязки

Так называемый государственный хак делает GHC агрессивным в вложении в действия IO и ST. Я думаю, что я должен упомянуть об этом здесь, потому что, помимо того, что плавание - это еще одна вещь, которая может полностью испортить производительность.

Взлом состояния разрешен с оптимизацией -O и может замедлять программы асимптотически. Простой пример из Рид Бартон:

import Control.Monad
import Debug.Trace

expensive :: String -> String
expensive x = trace "$$$" x

main :: IO ()
main = do
  str <- fmap expensive getLine
  replicateM_ 3 $ print str

С GHC-7.10.2 это печатает "$$$" один раз без оптимизации, но три раза с -O2. И похоже, что с GHC-7.10 мы не можем избавиться от этого поведения с помощью -fno-state-hack (который является предметом связанного билета Рида Бартона).

Строгие монадические привязки надежно избавляются от этой проблемы:

main :: IO ()
main = do
  !str <- fmap expensive getLine
  replicateM_ 3 $ print str

Я считаю хорошей привычкой делать строгие привязки в IO и ST. И у меня есть некоторый опыт (не окончательный, хотя я далек от того, чтобы быть экспертом GHC), что строгие привязки особенно необходимы, если мы используем -fno-full-laziness. По-видимому, полная лень может помочь избавиться от части дублирования работы, вызванного вложением, вызванным государственным взломом; с test_b и без полной лень, опустив строгую привязку на !mvec <- V.unsafeThaw vec, вызвало небольшое замедление и крайне уродливый вывод Core.

Ответ 2

По моему опыту forM_ [0..n-1] может работать хорошо, но, к сожалению, он не надежный. Просто добавьте прагму INLINE к test_a, а использование -O2 запустит ее намного быстрее (от 4s до 1s для меня), но ручная установка ее (копирование) замедляет ее снова.

Более надежной функцией является for из statistics, который реализован как

-- | Simple for loop.  Counts from /start/ to /end/-1.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    loop i | i == n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}

Использование его похоже на forM_ со списками:

test_d :: MV.IOVector Int -> IO ()
test_d mv =
  for 0 times $ \_ ->
    for 0 side $ \i ->
      for 0 side $ \j ->
        MV.unsafeWrite mv (i*side + j) 1

но надежно работает (0.85s для меня) без риска выделения списка.