Understanding the performance of Haskell functions

r1wp621o, posted 2022-12-13 in Other

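I wrote a Haskell program that needs to zip two lists with a function, but instead of stopping at the end of the shorter list (as the standard zipWith does), I want it to continue to the end of the longer list, using a default value once the shorter list runs out. My first implementation looked like this: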

zipWithAll :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAll f x y = go
    where go []     []     = []
          go (a:as) []     = f a y : go as []
          go []     (b:bs) = f x b : go [] bs
          go (a:as) (b:bs) = f a b : go as bs

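However, I usually prefer to write functions with the standard library's higher-order functions rather than explicit recursion, since that often improves performance and makes the code look nicer. So I tried the following: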

zipWithAll' :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAll' f x y xs ys = zipWith f xs' ys'
    where n   = max (length xs) (length ys)
          xs' = take n $ xs ++ repeat x
          ys' = take n $ ys ++ repeat y

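This looked much worse to me in terms of performance, since the two calls to length mean two extra list traversals. Surprisingly, when I compared average times, the first version was about 20% slower than the second.
So I figured I should use the second one and add parameters for the lengths of the lists, since in some cases they are known in advance. So I wrote this: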

zipWithAll'' :: (a -> b -> c) -> Int -> Int -> a -> b -> [a] -> [b] -> [c]
zipWithAll'' f n m x y xs ys = zipWith f xs' ys'
    where k   = max n m
          xs' = take k $ xs ++ repeat x
          ys' = take k $ ys ++ repeat y

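Even more surprisingly, the third version improved performance only slightly. For two randomly generated lists of Int, xs and ys, with length xs = length ys = n = 1000000, I got the following results: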

| function                       | average time, 30 evaluations |
+--------------------------------+------------------------------+
| zipWithAll   (+) 0 0 xs ys     |                        1.20s |
| zipWithAll'  (+) 0 0 xs ys     |                        0.95s |
| zipWithAll'' (+) n n 0 0 xs ys |                        0.94s |
+--------------------------------+------------------------------+

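I know this is not the most thorough benchmark, but it still contradicts my intuition about how fast Haskell programs run, and it makes me think I am missing something important for understanding the performance of Haskell functions.
So basically, what I want to know is:
1. Why is the simple recursive approach the slowest? Is it because the standard zipWith function has been optimized? If so, is there something I can do to make it perform similarly to zipWith?
2. Is my assumption wrong that the second version performs 3n operations while the third performs only n? And why does this not have a bigger impact on performance? I can imagine it mattering less if the zipping function were very expensive, but I am only using (+) here.
3. Finally, is there a way to implement a faster zipWithAll that takes advantage of the optimizations in standard library functions, without knowing the lengths of the lists in advance?
(Edit) Here is the relevant part of the benchmark code I used.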

{-# LANGUAGE BangPatterns #-}

import Control.Monad (replicateM, forM_)
import Data.Foldable (foldl')
import Data.Time (diffUTCTime, getCurrentTime, NominalDiffTime)
import Numeric (showEFloat, showFFloat)
import Test.QuickCheck

main = do
    let n  = 1000000
        fs = [ ("zipWithAll", uncurry4 $ zipWithAll (+))
             , ("zipWithAll'", uncurry4 $ zipWithAll' (+))
             , ("zipWithAll''", uncurry4 $ zipWithAll'' (+) n n)]

    xs <- generate (vectorOf n arbitrary :: Gen [Int])
    ys <- generate (vectorOf n arbitrary :: Gen [Int])
    benchmark fs (0, 0, xs, ys) 30

uncurry4 :: (a -> b -> c -> d -> e) -> (a,b,c,d) -> e
uncurry4 f (a,b,c,d) = f a b c d

-- | Measure and print the average time it takes for each function in the list to return.
benchmark :: (Show a, Show b) => [(String, (a -> b))] -> a -> Int -> IO ()
benchmark fs x rep = do
    force x
    forM_ fs $ \(name, f) -> do
        ts <- replicateM rep (measureTime f x)
        putStrLn $ "function: " ++ name ++ ", time = " ++ (showSignificant 2 $ average ts)

-- | Get the time measurement for a function applied to an argument
measureTime :: Show b => (a -> b) -> a -> IO NominalDiffTime
measureTime f x = do
    t1 <- getCurrentTime
    force (f x)
    t2 <- getCurrentTime
    return $ diffUTCTime t2 t1

-- | Force the computation of a value
force :: Show a => a -> IO ()
force a = maximum (show a) `seq` return ()

-- | Show a time difference using @n@ significant figures
showSignificant :: Int -> NominalDiffTime -> String
showSignificant n a = showFFloat Nothing b "s"
  where
    ae = showEFloat (Just (n-1)) (fromRational (toRational a)) ""
    b  = read ae :: Double

-- | Take the average of the elements in a foldable data structure
average :: (Foldable t, Fractional a) => t a -> a
average = uncurry (/) . foldl' f (0,0)
    where f (s,l) x = (s', l')
            where !s' = x + s
                  !l' = 1 + l
Answer #1, by 2ul0zpep

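Criterion is the gold standard for benchmarking in Haskell. I don't really trust benchmarks that come from anywhere else, so I ported your suite to Criterion. I have included my source file at the bottom of this answer so that someone can easily fix it if I did something wrong. One important difference: I made the lists xs and ys different sizes so that the interesting part of the function is actually exercised: ys is twice as large as xs. Here are the results I see: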

benchmarking standalone/zipWithAll
time                 18.53 ms   (18.08 ms .. 18.96 ms)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 19.77 ms   (19.28 ms .. 20.40 ms)
std dev              1.345 ms   (1.021 ms .. 1.711 ms)
variance introduced by outliers: 30% (moderately inflated)

benchmarking standalone/zipWithAll'
time                 43.61 ms   (43.25 ms .. 44.00 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 43.37 ms   (43.27 ms .. 43.57 ms)
std dev              256.9 μs   (140.4 μs .. 437.6 μs)

benchmarking standalone/zipWithAll''
time                 27.65 ms   (27.32 ms .. 28.20 ms)
                     0.999 R²   (0.997 R² .. 1.000 R²)
mean                 27.58 ms   (27.40 ms .. 27.92 ms)
std dev              513.2 μs   (339.7 μs .. 770.0 μs)

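The simple recursive approach is the fastest, by roughly a factor of two over the approach that traverses the lists more than once. No surprise there! If you pass the sizes in ahead of time you save some of that overhead, but you still need the concatenation and the take, neither of which is free, so it still lags noticeably behind.
Why is the simple version the fastest? You mention that zipWith has an optimized implementation in the standard library, but if you look at it, you will see that its implementation is exactly what you or I would write. The one interesting thing is the comment about fusion, which I think mostly means that if you write map succ (zipWith f (filter even xs) ys) or something similar, it can fuse filter, zipWith, and map into a single loop without materializing the intermediate lists. So I lied a bit above when I claimed that the only interesting change I made to your suite was changing the list sizes. I also added benchmarks that use the functions in this way, which we can see here: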
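A related difference between the two styles is laziness, which the following minimal sketch tries to illustrate. The recursive zipWithAll from the question yields each result cell as soon as it has inspected one cell of each input, whereas zipWithAll' must run length over both inputs completely before zipWith can produce anything. The infinite input list here is an illustrative assumption and is not part of the original benchmark.

```haskell
-- Recursive version from the question: output is produced incrementally.
zipWithAll :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAll f x y = go
  where
    go []     []     = []
    go (a:as) []     = f a y : go as []
    go []     (b:bs) = f x b : go [] bs
    go (a:as) (b:bs) = f a b : go as bs

main :: IO ()
main = do
  -- Only the first five cells of the infinite list are ever demanded.
  print (take 5 (zipWithAll (+) 0 0 [1 ..] [10, 20, 30 :: Int]))
  -- zipWithAll' from the question would not terminate on this input,
  -- because `length [1 ..]` never returns.
```

On finite inputs the same effect shows up as extra work and memory residency rather than non-termination: the whole spine of each list has to be walked by length and kept alive until zipWith consumes it.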

benchmarking fused/zipWithAll
time                 43.80 ms   (43.43 ms .. 44.29 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 43.65 ms   (43.41 ms .. 43.94 ms)
std dev              522.9 μs   (374.5 μs .. 651.1 μs)

benchmarking fused/zipWithAll'
time                 132.3 ms   (128.3 ms .. 138.7 ms)
                     0.998 R²   (0.994 R² .. 1.000 R²)
mean                 131.4 ms   (127.6 ms .. 133.7 ms)
std dev              4.495 ms   (2.430 ms .. 6.794 ms)
variance introduced by outliers: 11% (moderately inflated)

benchmarking fused/zipWithAll''
time                 52.83 ms   (52.36 ms .. 53.36 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 52.51 ms   (52.34 ms .. 52.72 ms)
std dev              366.1 μs   (246.8 μs .. 486.0 μs)

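Version 1 is still the winner, but version 2 gets much worse while version 3 closes the gap. Is this evidence of anything in particular? Maybe it shows that some of the redundant operations really do get fused away when using zipWithAll''. Or maybe not: I bet the f' lambda makes it hard for GHC to inline all the way through. I don't have time to dig into this right now. If you like, you can give -ddump-simpl a try.
As promised, here is the code for my benchmarks: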
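For anyone who wants to try the Core dump, one way to request it is a per-module OPTIONS_GHC pragma, sketched below. The flags -O2, -ddump-simpl, -dsuppress-all, and -ddump-to-file are standard GHC flags; the tiny main and its inputs are only an illustrative assumption, chosen to mirror the shape of the "fused" benchmark, and are not taken from the benchmark suite itself.

```haskell
{-# OPTIONS_GHC -O2 -ddump-simpl -dsuppress-all -ddump-to-file #-}
-- -ddump-simpl    : emit the optimized Core for this module
-- -dsuppress-all  : hide type and coercion noise so the loops are readable
-- -ddump-to-file  : write the dump to a .dump-simpl file instead of stdout
module Main (main) where

zipWithAll'' :: (a -> b -> c) -> Int -> Int -> a -> b -> [a] -> [b] -> [c]
zipWithAll'' f n m x y xs ys = zipWith f xs' ys'
  where
    k   = max n m
    xs' = take k $ xs ++ repeat x
    ys' = take k $ ys ++ repeat y

main :: IO ()
main =
  -- Same pipeline shape as the "fused" benchmark: filter, zip, then map.
  print (sum (map succ (zipWithAll'' (+) 3 5 0 0 (filter even [1, 2, 3]) [1 .. 5 :: Int])))
```

In the resulting dump, an intermediate list that survived optimization shows up as explicit cons-cell allocation between the stages, while a fused pipeline appears as a single recursive worker.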

module Main (main) where

import Criterion.Main
import Test.QuickCheck

zipWithAll :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAll f x y = go
    where go []     []     = []
          go (a:as) []     = f a y : go as []
          go []     (b:bs) = f x b : go [] bs
          go (a:as) (b:bs) = f a b : go as bs

zipWithAll' :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAll' f x y xs ys = zipWith f xs' ys'
    where n   = max (length xs) (length ys)
          xs' = take n $ xs ++ repeat x
          ys' = take n $ ys ++ repeat y

zipWithAll'' :: (a -> b -> c) -> Int -> Int -> a -> b -> [a] -> [b] -> [c]
zipWithAll'' f n m x y xs ys = zipWith f xs' ys'
    where k   = max n m
          xs' = take k $ xs ++ repeat x
          ys' = take k $ ys ++ repeat y

main :: IO ()
main = do
  let xSize = 1000000
      ySize = xSize * 2
  xs <- generate (vectorOf xSize arbitrary :: Gen [Int])
  ys <- generate (vectorOf ySize arbitrary :: Gen [Int])
  let impls = [ ("zipWithAll", zipWithAll (+) 0 0)
              , ("zipWithAll'", zipWithAll' (+) 0 0)
              , ("zipWithAll''", zipWithAll'' (+) xSize ySize 0 0)
              ]
  defaultMain [ bgroup "standalone" $ do
                  (name, f) <- impls
                  let f' (xs, ys) = f xs ys
                  pure . bench name $ nf f' (xs, ys)
              , bgroup "fused" $ do
                  (name, f) <- impls
                  let f' (xs, ys) = map succ (f (filter even xs) ys)
                  pure . bench name $ nf f' (xs, ys)
              ]
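On the question's third point (building zipWithAll from library functions without knowing the lengths), one possible sketch is to pad both lists with Nothing and stop once both sides are exhausted. The name zipWithAllPadded and its helpers are hypothetical and not part of the original post, and this variant has not been benchmarked here; the extra Maybe wrapping may well make it slower than the explicit recursion.

```haskell
import Data.Maybe (fromMaybe, isJust)

-- Hypothetical library-function variant: a single lazy pass, no `length`.
zipWithAllPadded :: (a -> b -> c) -> a -> b -> [a] -> [b] -> [c]
zipWithAllPadded f x y xs ys =
  map combine (takeWhile eitherPresent (zip (pad xs) (pad ys)))
  where
    -- Wrap real elements in Just and follow them with endless Nothings.
    pad :: [e] -> [Maybe e]
    pad zs = map Just zs ++ repeat Nothing
    -- Keep going while at least one side still has a real element.
    eitherPresent (ma, mb) = isJust ma || isJust mb
    -- Substitute the defaults for whichever side has run out.
    combine (ma, mb) = f (fromMaybe x ma) (fromMaybe y mb)

main :: IO ()
main = print (zipWithAllPadded (+) 0 0 [1, 2, 3] [10 :: Int])
```

Like the recursive version, this consumes its inputs incrementally, so it also works when one of the lists is infinite. Whether GHC fuses the map, takeWhile, and zip stages here would need to be checked with a benchmark or a Core dump.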
