The Haskell standard library provides multiple ways to fold lists into a single value. For example, you can count the number of elements in a list:
import Data.List (genericLength)
genericLength :: (Num i) => [a] -> i
... or you can add them up:
import Data.List (foldl')
import Prelude hiding (sum)
-- I'm deviating from the Prelude's sum, which leaks space
sum :: (Num a) => [a] -> a
sum = foldl' (+) 0
Individually, these two folds run in constant memory when given a lazy list as an argument, never bringing more than one element into memory at a time:
>>> genericLength [1..100000000]
100000000
>>> sum [1..100000000]
5000000050000000
However, we get an immediate space leak if we try to combine these two folds to compute an average:
>>> let average xs = sum xs / genericLength xs
>>> average [1..100000000]
<Huge space leak>
The original isolated folds streamed in constant memory because Haskell is lazy and does not compute each element of the list until the fold actually requests the element. After the fold traverses each element the garbage collector detects the element will no longer be used and collects it immediately, preventing any build-up of elements.
However, when we combine these two folds naively like we did with average, our program leaks space while we compute sum and before we get a chance to compute genericLength. As sum traverses the list, the garbage collector cannot collect any of the elements because we have to hold on to the entire list for the subsequent genericLength fold.
Unfortunately, the conventional solution to this is not pretty:
mean :: [Double] -> Double
mean = go 0 0
  where
    go s l []     = s / fromIntegral l
    go s l (x:xs) = s `seq` l `seq`
                      go (s+x) (l+1) xs
Here we've sliced open the guts of each fold and combined their individual step functions into a new step function so we can pass over the list just once. We also had to pay a lot of attention to detail regarding strictness. This is what newcomers to Haskell complain about when they say you need to be an expert at Haskell to produce highly efficient code.
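As a point of comparison, here is a minimal sketch of the same single-pass idea written with foldl' and a strict accumulator instead of explicit seq calls. The names Acc and meanStrict are hypothetical and not part of this post's code. Like mean above, it returns NaN for an empty list.

```haskell
import Data.List (foldl')

-- Hypothetical accumulator: the strictness annotations (!) force both
-- fields whenever an Acc value is evaluated, so no thunks pile up.
data Acc = Acc !Double !Int

-- One pass over the list, tracking the running sum and count together.
meanStrict :: [Double] -> Double
meanStrict xs = s / fromIntegral n
  where
    Acc s n = foldl' step (Acc 0 0) xs
    step (Acc s' n') x = Acc (s' + x) (n' + 1)
```

This is tidier than the seq version, but it is still a fused, hand-written accumulator: adding a third statistic means editing Acc and step by hand.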
The Fold type
Let's fix this by reformulating our original folds to preserve more information so that we can transparently combine multiple folds into a single pass over the list:
{-# LANGUAGE ExistentialQuantification #-}
import Data.List (foldl')
import Data.Monoid
data Fold a b = forall w. (Monoid w) => Fold
    { tally     :: a -> w
    , summarize :: w -> b
    }
fold :: Fold a b -> [a] -> b
fold (Fold t c) xs =
    c (foldl' mappend mempty (map t xs))
Here I've taken a fold and split it into two parts:
- tally: The function that maps each element of the list into the accumulator type, so that fold can combine the results with mappend
- summarize: The final function we call at the end of the fold to convert our accumulator into the desired result
The w type variable represents the internal accumulator that our Fold will use as it traverses the list. The Fold can use any accumulator of its choice as long as the accumulator is a Monoid of some sort. We specify that in the types by existentially quantifying the accumulator using the ExistentialQuantification extension.
The end user doesn't care what the internal accumulator is either, because the user only interacts with Folds using the fold function. The type system enforces that fold (or any other function) cannot use any specific details about a Fold's accumulator other than the fact that the accumulator is a Monoid.
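To make the "any accumulator of its choice" point concrete, here is a small sketch with two folds that use different internal Monoids yet expose the same result type. The names allOf and anyOf are hypothetical illustrations and not part of this post's code. All and Any come from Data.Monoid.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')
import Data.Monoid (All(..), Any(..))

data Fold a b = forall w. (Monoid w) => Fold
    { tally     :: a -> w
    , summarize :: w -> b
    }

fold :: Fold a b -> [a] -> b
fold (Fold t c) xs = c (foldl' mappend mempty (map t xs))

-- Hypothetical fold: internally accumulates with the All monoid (logical AND).
allOf :: (a -> Bool) -> Fold a Bool
allOf p = Fold (All . p) getAll

-- Hypothetical fold: internally accumulates with the Any monoid (logical OR).
anyOf :: (a -> Bool) -> Fold a Bool
anyOf p = Fold (Any . p) getAny
```

Both folds have the type Fold a Bool, so a caller cannot tell (and does not need to know) that one uses All and the other uses Any.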
We'll test out this type by rewriting our original folds using the new Fold type:
genericLength :: (Num i) => Fold a i
genericLength =
    Fold (\_ -> Sum (1::Int)) (fromIntegral . getSum)
sum :: (Num a) => Fold a a
sum = Fold Sum getSum
Notice how the Monoid we choose implicitly encodes how to accumulate the result. genericLength counts the number of elements simply by mapping them all to Sum 1, and then the Monoid instance for Sum just adds up all these ones to get the list length. sum is even simpler: just wrap each element in Sum and the Monoid instance for Sum adds up every element of the list. When we're done, we unwrap the final result using getSum.
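The role of the Sum monoid can be seen in isolation, without the Fold type at all. The sketch below uses only Data.Monoid. The names countViaSum and totalViaSum are hypothetical, and mconcat is used purely for illustration: it is not the strict left fold that fold uses, so this version is not meant for very long lists.

```haskell
import Data.Monoid (Sum(..))

-- Map every element to Sum 1, combine with the Monoid instance, unwrap.
countViaSum :: [a] -> Int
countViaSum xs = getSum (mconcat (map (\_ -> Sum 1) xs))

-- Wrap every element in Sum, combine with the Monoid instance, unwrap.
totalViaSum :: Num a => [a] -> a
totalViaSum xs = getSum (mconcat (map Sum xs))
```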
We can now apply these folds to any list using the fold function, which handles all the details of accumulating each element of the list and summarizing the result:
>>> fold genericLength [(1::Int)..100000000]
100000000
>>> fold sum [(1::Int)..100000000]
5000000050000000
So far, so good, but how do we combine them into an average?
Combining Folds
Fold has the nice property that it is an Applicative, given by the following definition:
import Control.Applicative
import Data.Strict.Tuple
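-- Note: newer versions of base also require a Semigroup instance here,
-- and recent releases of the strict package may already provide this
-- instance, in which case this declaration can be dropped.
instance (Monoid a, Monoid b) => Monoid (Pair a b) where
    mempty = (mempty :!: mempty)
    mappend (aL :!: aR) (bL :!: bR) =
        (mappend aL bL :!: mappend aR bR)
instance Functor (Fold a) where
    fmap f (Fold t k) = Fold t (f . k)
instance Applicative (Fold a) where
    pure a = Fold (\_ -> ()) (\_ -> a)
    (Fold tL cL) <*> (Fold tR cR) =
        let t x = (tL x :!: tR x)
            c (wL :!: wR) = (cL wL) (cR wR)
        in Fold t c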
Note that this uses strict Pairs from Data.Strict.Tuple to ensure that the combined Fold still automatically runs in constant space. You only need to remember that (x :!: y) is the strict analog of (x, y).
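To see what "strict analog" means in practice, here is a small sketch that avoids the strict package by defining a hypothetical stand-in type P with strict fields. Evaluating an ordinary tuple to weak head normal form leaves its components untouched, while evaluating a P forces both of them. The names P, lazyPairSurvives, and strictPairSurvives are assumptions made for this illustration.

```haskell
import Control.Exception (ErrorCall, evaluate, try)

-- Hypothetical stand-in for Pair from Data.Strict.Tuple:
-- both fields are forced when the constructor is evaluated.
data P a b = P !a !b

-- Evaluating a lazy tuple does not touch its components,
-- so the erroring first component is never forced.
lazyPairSurvives :: IO Bool
lazyPairSurvives = do
  r <- try (evaluate (error "boom" :: Int, 1 :: Int))
         :: IO (Either ErrorCall (Int, Int))
  pure (either (const False) (const True) r)

-- Evaluating the strict pair forces both fields,
-- so the erroring first component is hit immediately.
strictPairSurvives :: IO Bool
strictPairSurvives = do
  r <- try (evaluate (P (error "boom" :: Int) (1 :: Int)))
         :: IO (Either ErrorCall (P Int Int))
  pure (either (const False) (const True) r)
```

Under these assumptions lazyPairSurvives yields True and strictPairSurvives yields False. That same forcing is what keeps the two accumulators inside a combined Fold from building up unevaluated thunks.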
With this Applicative instance in hand, we can very easily combine our sum and genericLength folds into an average fold:
average :: (Fractional a) => Fold a a
average = (/) <$> sum <*> genericLength
This combines the two folds transparently into a single fold that traverses the list just once in constant memory, computing the average of all elements within the list:
>>> fold average [1..1000000]
500000.5
Now we're programming at a high altitude instead of hand-writing our own accumulators and left folds and praying to the strictness gods.
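For intuition, here is roughly what the Applicative instance assembles for us in the average case, written out by hand as a single Fold. This is a sketch under a few assumptions: it targets a recent GHC (where Monoid needs a Semigroup instance), it replaces Pair from Data.Strict.Tuple with a hypothetical local type P so the block is self-contained, and the name averageExpanded is made up for this illustration.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')
import Data.Monoid (Sum(..))

data Fold a b = forall w. (Monoid w) => Fold
    { tally     :: a -> w
    , summarize :: w -> b
    }

fold :: Fold a b -> [a] -> b
fold (Fold t c) xs = c (foldl' mappend mempty (map t xs))

-- Hypothetical local strict pair, standing in for Data.Strict.Tuple's Pair.
data P a b = P !a !b

instance (Semigroup a, Semigroup b) => Semigroup (P a b) where
    P a b <> P c d = P (a <> c) (b <> d)

instance (Monoid a, Monoid b) => Monoid (P a b) where
    mempty = P mempty mempty

-- The accumulator pairs up the running sum and the running count;
-- the final step divides one by the other.
averageExpanded :: Fold Double Double
averageExpanded = Fold t c
  where
    t x = P (Sum x) (Sum (1 :: Int))
    c (P (Sum s) (Sum n)) = s / fromIntegral n
```

The Applicative version gives the same single pass without us ever naming the paired accumulator.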
What if we wanted to compute the standard deviation of a list? All we need is one extra primitive fold that computes the sum of squares:
sumSq :: (Num a) => Fold a a
sumSq = Fold (\x -> Sum (x ^ 2)) getSum
Now we can write a derived fold using Applicative style:
std :: (Floating a) => Fold a a
std = (\ss s len -> sqrt (ss / len - (s / len)^2))
    <$> sumSq
    <*> sum
    <*> genericLength
... which still traverses the list just once:
>>> fold std [1..10000000]
2886751.345954732
In fact, this is the exact same principle that the BIRCH data clustering algorithm uses for clustering features. You keep a tally of the length, sum, and sum of squares, and you can compute most useful statistics in O(1) time from those three tallies.
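The three-tally idea can be sketched as a standalone Monoid. This is only an illustration of the principle described above, not BIRCH itself, and every name in it (Stats, observe, tallyAll, meanOf, varianceOf) is hypothetical. It assumes a recent GHC where Semigroup is available from the Prelude.

```haskell
import Data.List (foldl')

-- Hypothetical tally: count, sum, and sum of squares, all strict.
data Stats = Stats !Double !Double !Double

-- Tallies from two chunks of data combine by adding componentwise.
instance Semigroup Stats where
    Stats n1 s1 q1 <> Stats n2 s2 q2 = Stats (n1 + n2) (s1 + s2) (q1 + q2)

instance Monoid Stats where
    mempty = Stats 0 0 0

-- The tally contributed by a single observation.
observe :: Double -> Stats
observe x = Stats 1 x (x * x)

-- One strict pass over the data.
tallyAll :: [Double] -> Stats
tallyAll = foldl' (\acc x -> acc <> observe x) mempty

-- O(1) statistics derived from the three tallies.
meanOf :: Stats -> Double
meanOf (Stats n s _) = s / n

-- Population variance: E[x^2] - (E[x])^2.
varianceOf :: Stats -> Double
varianceOf (Stats n s q) = q / n - (s / n) ^ (2 :: Int)
```

Because the tallies form a Monoid, tallyAll xs <> tallyAll ys summarizes the concatenated data without revisiting either list, which is what makes this kind of summary cheap to merge.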
Similarly, what if we wanted to compute both the sum and product of a list in a single pass?
product :: (Num a) => Fold a a
product = Fold Product getProduct
Once again, we can just use Applicative style:
>>> fold ((,) <$> sum <*> product) [1..100]
(5050,9332621544394415268169923885626670049071596826438162146859
2963895217599993229915608941463976156518286253697920827223758251
185210916864000000000000000000000000)
Conclusion
Contrary to conventional wisdom, you can program in Haskell at a high level without leaking space. Haskell gives you the tools to abstract away efficient idioms behind a convenient and composable interface, so use them!
Appendix
I've included the full code so that people can play with this themselves:
{-# LANGUAGE ExistentialQuantification #-}
import Control.Applicative
import Data.List (foldl')
import Data.Monoid
import Data.Strict.Tuple
import Prelude hiding (sum, product)
data Fold a b = forall w. (Monoid w) => Fold
    { tally     :: a -> w
    , summarize :: w -> b
    }
fold :: Fold a b -> [a] -> b
fold (Fold t c) xs =
    c (foldl' mappend mempty (map t xs))
instance (Monoid a, Monoid b) => Monoid (Pair a b) where
    mempty = (mempty :!: mempty)
    mappend (aL :!: aR) (bL :!: bR) =
        (mappend aL bL :!: mappend aR bR)
instance Functor (Fold a) where
    fmap f (Fold t k) = Fold t (f . k)
instance Applicative (Fold a) where
    pure a = Fold (\_ -> ()) (\_ -> a)
    (Fold tL cL) <*> (Fold tR cR) =
        let t x = (tL x :!: tR x)
            c (wL :!: wR) = (cL wL) (cR wR)
        in Fold t c
genericLength :: (Num b) => Fold a b
genericLength =
    Fold (\_ -> Sum (1::Int)) (fromIntegral . getSum)
sum :: (Num a) => Fold a a
sum = Fold Sum getSum
sumSq :: (Num a) => Fold a a
sumSq = Fold (\x -> Sum (x ^ 2)) getSum
average :: (Fractional a) => Fold a a
average = (\s c -> s / c) <$> sum <*> genericLength
product :: (Num a) => Fold a a
product = Fold Product getProduct
std :: (Floating a) => Fold a a
std = (\ss s len -> sqrt (ss / len - (s / len)^2))
    <$> sumSq
    <*> sum
    <*> genericLength