Написание эффективного итерационного цикла для ST monad - программирование
Подтвердить что ты не робот

Написание эффективного итерационного цикла для ST monad

go рабочий хвост-рекурсивный шаблон цикла, кажется, очень хорошо работает для написания чистого кода. Каким будет эквивалентный способ написать такой цикл для монады ST? В частности, я хочу избежать выделения новых кучи в итерациях цикла. Я предполагаю, что он включает в себя CPS transformation или fixST, чтобы переписать код таким образом, чтобы все значения, которые менялись в цикле, передавались через каждую итерацию, тем самым делая места регистрации (или стек в случае разлива) доступными для эти значения на итерации. У меня есть упрощенный пример ниже (не пытайтесь запустить его - он, скорее всего, сбой с ошибкой сегментации!) С функцией findSnakes, которая имеет рабочий шаблон go, но переменные значения состояния не передаются через аргументы аккумулятора:

{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s  = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
              y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
              y1 <- MU.unsafeRead fp (k+offset+1)
              if y0 > y1 then return y0
              else return y1
{-#INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s ->  Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
     where
           offset=1+U.length a
           go x k'
            | x < ct = do
                yp <- findYP fp k' offset
                MU.unsafeWrite fp (k'+offset) (yp + k')
                go (x+1) (op k' 1)
            | otherwise = return ()
{-#INLINE findSnakes #-}

Глядя на вывод cmm в ghc 7.6.1 (с моим ограниченным знанием cmm - пожалуйста, исправьте меня, если я ошибаюсь), я вижу этот поток вызовов, с циклом в s1tb_info (что вызывает распределение кучи и кучи на каждой итерации):

findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out 
what that register points to)

-- I am guessing this one below is for go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1
then s1td_info else R1

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info
(a loop) else s1tb_info (after executing a different block of code)

Я предполагаю, что проверка кода формы arg >= 1 в cmm заключается в том, чтобы определить, завершен ли цикл go или нет. Если это правильно, кажется, что если цикл go не перезаписан, чтобы пройти yp через цикл, распределение кучи произойдет через цикл для новых значений (я предполагаю, что yp вызывает распределение кучи). Что было бы эффективным способом написать цикл go в приведенном выше примере? Я думаю, yp должен быть передан как аргумент в цикле go или эквивалентным способом преобразования fixST или CPS. Я не могу придумать хороший способ переписать цикл go выше, чтобы удалить выделение кучи и позаботится о помощи.

4b9b3361

Ответ 1

Я переписал ваши функции, чтобы избежать какой-либо явной рекурсии и удалил некоторые избыточные операции, вычисляющие смещения. Это скомпоновано гораздо лучше, чем ваши оригинальные функции.

Core, кстати, вероятно, лучший способ проанализировать ваш скомпилированный код для такого профилирования. Используйте ghc -ddump-simpl, чтобы увидеть сгенерированный вывод ядра, или такие инструменты, как ghc-core

import Control.Monad.Primitive                                                                               
import Control.Monad.ST                                                                                      
import Data.Int                                                                                              
import qualified Data.Vector.Unboxed.Mutable as M                                                            
import qualified Data.Vector.Unboxed as U                                                                    

type MVI1 s  = M.MVector (PrimState (ST s)) Int                                                              

findYP :: MVI1 s -> Int -> ST s Int                                                                                                                                                     
findYP fp offset = do                                                                                      
    y0 <- M.unsafeRead fp (offset+0)                                                                       
    y1 <- M.unsafeRead fp (offset+2)                                                                       
    return $ max (y0 + 1) y1                                                                                  

findSnakes :: U.Vector Int32 -> MVI1 s ->  Int -> Int -> (Int -> Int -> Int) -> ST s ()                                                                                                         
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0                                       
    where writeAt k = do    
              let offset = U.length a + k                                                                 
              yp <- findYP fp offset                                                                        
              M.unsafeWrite fp (offset + 1) (yp + k)

          -- or inline findYP manually
          writeAt k = do
             let offset = U.length a + k
             y0 <- M.unsafeRead fp (offset + 0)
             y1 <- M.unsafeRead fp (offset + 2)
             M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)

Кроме того, вы передаете U.Vector Int32 в findSnakes, только для вычисления его длины и никогда не используйте a снова. Почему бы не пройти по длине напрямую?