Tuesday, September 26, 2017

Type-driven strictness

I was recently trying to optimize Dhall's performance because the interpreter was performing poorly on some simple examples.

For example, consider this tiny expression that prints 3000 exclamation marks:

Natural/fold +3000 Text (λ(x : Text) → x ++ "!") ""

The above Dhall expression takes over 14 seconds to evaluate to normal form, which is not acceptable:

$ bench 'dhall <<< ./exclaim'
benchmarking dhall <<< ./exclaim
time                 14.42 s    (14.23 s .. 14.57 s)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 14.62 s    (14.54 s .. 14.66 s)
std dev              70.22 ms   (0.0 s .. 77.84 ms)
variance introduced by outliers: 19% (moderately inflated)

Strict

The performance suffers because Dhall is lazy in the accumulator, meaning that the accumulator builds up a gigantic expression like this:

(λ(x : Text) → x ++ "!")
( (λ(x : Text) → x ++ "!")
  ( (λ(x : Text) → x ++ "!")
    ( (λ(x : Text) → x ++ "!")
      ( (λ(x : Text) → x ++ "!")
        ...
        {Repeat this nesting 3000 times}
        ...
        ""
      )
    )
  )
)

... and then attempts to normalize the entire expression, which takes a long time and wastes a lot of memory.

The reason for this lazy behavior is the following code for evaluating Natural/fold in the Dhall interpreter:

normalize (App (App (App (App NaturalFold (NaturalLit n0)) _) succ') zero)
    = normalize (go n0)
  where
    go !0 = zero
    go !n = App succ' (go (n - 1))

You can read that as saying:

  • Given an expression of the form Natural/fold n t succ zero
  • Wrap the value zero in n calls of the function succ
    • i.e. succ (succ (succ (... {n times} ... (succ zero) ...)))
  • Then normalize that

A smarter approach would be to keep the accumulator strict, which means that we evaluate as we go instead of deferring all evaluation to the end. For example, the accumulator starts off as just the empty string:

""

... then after one iteration of the loop we get the following accumulator:

(λ(x : Text) → x ++ "!") ""

... and if we evaluate that accumulator immediately we get:

"!"

Then the next iteration of the loop produces the following accumulator:

(λ(x : Text) → x ++ "!") "!"

... which we can again immediately evaluate to get:

"!!"

This is significantly more efficient than leaving the expression unevaluated.

We can easily implement such a strict loop by making the following change to the interpreter:

normalize (App (App (App (App NaturalFold (NaturalLit n0)) _) succ') zero)
    = go n0
  where
    go !0 = normalize zero
    go !n = normalize (App succ' (go (n - 1)))

The difference here is that we still build up a chain of n calls to succ but now we normalize our expression in between each call to the succ function instead of waiting until the end to normalize.
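To see the two loops side by side, here is a minimal sketch that uses a toy expression type. Everything in it (the Expr constructors, subst, foldLazy, foldStrict, exclaim) is a hypothetical stand-in for illustration and not Dhall's real Expr or normalize; it only supports a single, non-nested lambda. It shows that both loops reach the same normal form and differ only in when normalize runs:

```haskell
module Main where

-- A toy expression language for illustration; this is NOT Dhall's real Expr.
data Expr
    = TextLit String
    | Append Expr Expr   -- l ++ r
    | Var                -- the variable bound by the enclosing Lam
    | Lam Expr           -- λ(x : Text) → body
    | App Expr Expr
    deriving (Eq, Show)

-- Replace Var with the argument. Nested lambdas are out of scope for this toy.
subst :: Expr -> Expr -> Expr
subst arg Var          = arg
subst arg (Append l r) = Append (subst arg l) (subst arg r)
subst arg (App f x)    = App (subst arg f) (subst arg x)
subst _   e            = e

normalize :: Expr -> Expr
normalize (App f x) =
    case normalize f of
        Lam body -> normalize (subst (normalize x) body)
        f'       -> App f' (normalize x)
normalize (Append l r) =
    case (normalize l, normalize r) of
        (TextLit a, TextLit b) -> TextLit (a ++ b)
        (l', r')               -> Append l' r'
normalize (Lam body) = Lam (normalize body)
normalize e          = e

-- Lazy loop: build the whole chain of applications, then normalize once.
foldLazy :: Int -> Expr -> Expr -> Expr
foldLazy n0 succ' zero = normalize (go n0)
  where
    go 0 = zero
    go n = App succ' (go (n - 1))

-- Strict loop: normalize the accumulator after every application.
foldStrict :: Int -> Expr -> Expr -> Expr
foldStrict n0 succ' zero = go n0
  where
    go 0 = normalize zero
    go n = normalize (App succ' (go (n - 1)))

-- λ(x : Text) → x ++ "!"
exclaim :: Expr
exclaim = Lam (Append Var (TextLit "!"))

main :: IO ()
main = do
    print (foldStrict 3 exclaim (TextLit ""))
    -- prints: TextLit "!!!"
    print (foldLazy 3 exclaim (TextLit "") == foldStrict 3 exclaim (TextLit ""))
    -- prints: True
```

The toy normalizer is too simple to reproduce the slowdown measured above, so treat it as a picture of the control flow rather than a benchmark.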

Once we do this, runtime improves dramatically, going down from over 14 seconds to about 90 milliseconds:

$ bench 'dhall <<< ./example'
benchmarking dhall <<< ./example
time                 88.92 ms   (87.14 ms .. 90.74 ms)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 86.06 ms   (84.98 ms .. 87.15 ms)
std dev              1.734 ms   (1.086 ms .. 2.504 ms)

... or in other words about 30 microseconds per element (roughly 89 ms spread over 3000 iterations). We could still do more to optimize this, but at least we're now in the right ballpark for an interpreter. For reference, Python is about 4x faster on my machine for the following equivalent program:

print ("!" * 3000)

$ bench 'python exclaim.py'
benchmarking python exclaim.py
time                 24.55 ms   (24.09 ms .. 25.02 ms)
                     0.998 R²   (0.996 R² .. 1.000 R²)
mean                 24.53 ms   (24.16 ms .. 24.88 ms)
std dev              798.4 μs   (559.8 μs .. 1.087 ms)

However, these results don't necessarily imply that a strict accumulator is always better.

Lazy

Sometimes laziness is more efficient, though. Consider this program:

List/build
Integer
(   λ(list : Type)
→   λ(cons : Integer → list → list)
→   Natural/fold +6000 list (cons 1)
)

The above example uses Natural/fold to build a list of 6000 1s.

In this case the accumulator of the fold is a list that grows by one element after each step of the fold. We don't want to normalize the list on each iteration because that would lead to quadratic time complexity. Instead we prefer to defer normalization to the end of the loop so that we get linear time complexity.
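The quadratic-versus-linear claim can be made concrete with a rough cost model. The sketch below assumes (purely for illustration) that normalizing a list of k elements costs k steps; strictCost and lazyCost are hypothetical names, and the numbers are step counts under that assumption, not measurements:

```haskell
module Main where

-- Assumed cost model: normalizing a list of k elements costs k steps.

-- Strict loop: the growing list is re-normalized after each of the n
-- iterations, so the total is 1 + 2 + ... + n = n * (n + 1) / 2 steps.
strictCost :: Integer -> Integer
strictCost n = sum [1 .. n]

-- Lazy loop: the finished list of n elements is normalized once.
lazyCost :: Integer -> Integer
lazyCost n = n

main :: IO ()
main = do
    print (strictCost 6000)  -- prints: 18003000
    print (lazyCost 6000)    -- prints: 6000
```

The real interpreter does other work on each iteration too, so the measured gap is much smaller than this ratio suggests, but the growth rates are what matter as the list gets longer.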

We can measure the difference pretty easily. A strict loop takes over 6 seconds to complete:

$ bench 'dhall <<< ./ones'
benchmarking dhall <<< ./ones
time                 6.625 s    (6.175 s .. 7.067 s)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 6.656 s    (6.551 s .. 6.719 s)
std dev              95.98 ms   (0.0 s .. 108.3 ms)
variance introduced by outliers: 19% (moderately inflated)

... whereas a lazy loop completes in about 180 milliseconds:

$ bench 'dhall <<< ./g'
benchmarking dhall <<< ./g
time                 182.5 ms   (175.1 ms .. 191.3 ms)
                     0.998 R²   (0.995 R² .. 1.000 R²)
mean                 177.5 ms   (172.1 ms .. 180.8 ms)
std dev              5.589 ms   (2.674 ms .. 8.273 ms)
variance introduced by outliers: 12% (moderately inflated)

Moreover, the difference in performance will only worsen with larger list sizes due to the difference in time complexity.

Also, in case you were wondering, Python is about 7x faster:

print ([1] * 6000)

$ bench 'python ones.py'
benchmarking python ones.py
time                 25.36 ms   (24.75 ms .. 25.92 ms)
                     0.998 R²   (0.996 R² .. 0.999 R²)
mean                 25.64 ms   (25.16 ms .. 26.03 ms)
std dev              917.8 μs   (685.7 μs .. 1.348 ms)

Why not both?

This poses a conundrum because we'd like to efficiently support both of these use cases. How can we know when to be lazy or strict?

We can use Dhall's type system to guide whether or not we keep the accumulator strict. We already have access to the type of the accumulator for our loop, so we can define a function that tells us if our accumulator type is compact or not:

compact :: Expr s a -> Bool
compact Bool             = True
compact Natural          = True
compact Integer          = True
compact Double           = True
compact Text             = True
compact (App List _)     = False
compact (App Optional t) = compact t
compact (Record kvs)     = all compact kvs
compact (Union kvs)      = all compact kvs
compact _                = False

You can read this function as saying:

  • primitive types are compact
  • lists are not compact
  • optional types are compact if the corresponding non-optional type is compact
  • a record type is compact if all the field types are compact
  • a union type is compact if all the alternative types are compact
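As a quick check of these rules, here is a self-contained sketch that mirrors compact on a toy Type data type. The Type constructors, the association-list representation of fields, and fizzBuzzState are assumptions made for illustration; Dhall's real Expr uses different constructors:

```haskell
module Main where

-- A toy model of Dhall types; NOT the real Expr from the dhall library.
data Type
    = TBool
    | TNatural
    | TInteger
    | TDouble
    | TText
    | TList Type
    | TOptional Type
    | TRecord [(String, Type)]  -- field name and field type
    | TUnion  [(String, Type)]  -- alternative name and alternative type

compact :: Type -> Bool
compact TBool         = True
compact TNatural      = True
compact TInteger      = True
compact TDouble       = True
compact TText         = True
compact (TList _)     = False
compact (TOptional t) = compact t
compact (TRecord kts) = all (compact . snd) kts
compact (TUnion  kts) = all (compact . snd) kts

-- The accumulator type of the FizzBuzz fold later in this post.
fizzBuzzState :: Type
fizzBuzzState =
    TRecord
        [ ("buzz" , TNatural)
        , ("fizz" , TNatural)
        , ("index", TNatural)
        , ("text" , TText)
        ]

main :: IO ()
main = do
    print (compact TText)                         -- prints: True
    print (compact (TList TInteger))              -- prints: False
    print (compact (TOptional (TList TInteger)))  -- prints: False
    print (compact fizzBuzzState)                 -- prints: True
```

Under these rules the Text accumulator from the first example and the FizzBuzz record both count as compact, while the list accumulator from the second example does not.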

Now, all we need to do is modify our fold logic to use this compact function to decide whether to use a strict or a lazy loop:

normalize (App (App (App (App NaturalFold (NaturalLit n0)) t) succ') zero) =
    if compact (normalize t) then strict else lazy
  where
    strict =            strictLoop n0
    lazy   = normalize (  lazyLoop n0)

    strictLoop !0 = normalize zero
    strictLoop !n = normalize (App succ' (strictLoop (n - 1)))

    lazyLoop !0 = zero
    lazyLoop !n = App succ' (lazyLoop (n - 1))

Now we get the best of both worlds and our interpreter gives excellent performance in both of the above examples.

Fizzbuzz

Here's a bonus example for people who got this far!

The original Dhall expression that motivated this post was an attempt to implement FizzBuzz in Dhall. The program I ended up writing was:

    let pred =
            λ(n : Natural)
            →   let start = { next = +0, prev = +0 }
            
            in  let step =
                        λ(state : { next : Natural, prev : Natural })
                    →   { next = state.next + +1, prev = state.next }
            
            in  let result =
                      Natural/fold
                      n
                      { next : Natural, prev : Natural }
                      step
                      start
            
            in  result.prev

in  let not = λ(b : Bool) → b == False

in  let succ =
            λ ( state
              : { buzz : Natural, fizz : Natural, index : Natural, text : Text }
              )
            →   let fizzy = Natural/isZero state.fizz
            
            in  let buzzy = Natural/isZero state.buzz
            
            in  let line =
                            if fizzy && buzzy then "FizzBuzz"
                      
                      else  if fizzy && not buzzy then "Fizz"
                      
                      else  if not fizzy && buzzy then "Buzz"
                      
                      else  Integer/show (Natural/toInteger state.index)
            
            in  { buzz  = pred (if buzzy then +5 else state.buzz)
                , fizz  = pred (if fizzy then +3 else state.fizz)
                , index = state.index + +1
                , text  = state.text ++ line ++ "\n"
                }

in  let zero = { buzz = +5, fizz = +3, index = +0, text = "" }

in  let fizzBuzz =
            λ(n : Natural)
            →   let result =
                      Natural/fold
                      n
                      { buzz  : Natural
                      , fizz  : Natural
                      , index : Natural
                      , text  : Text
                      }
                      succ
                      zero
            
            in  result.text

in  fizzBuzz

However, this program runs incredibly slowly, taking over 7 seconds just to compute 20 elements:

$ bench 'dhall <<< "./fizzbuzz +20"'
benchmarking dhall <<< "./fizzbuzz +20"
time                 7.450 s    (7.194 s .. 7.962 s)
                     0.999 R²   (NaN R² .. 1.000 R²)
mean                 7.643 s    (7.512 s .. 7.739 s)
std dev              145.0 ms   (0.0 s .. 165.6 ms)
variance introduced by outliers: 19% (moderately inflated)

With a strict fold, though, the same program takes about 0.6 seconds to go through 10,000 elements:

$ bench 'dhall <<< "./fizzbuzz +10000"'
benchmarking dhall <<< "./fizzbuzz +10000"
time                 591.5 ms   (567.3 ms .. NaN s)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 583.4 ms   (574.0 ms .. 588.8 ms)
std dev              8.418 ms   (0.0 s .. 9.301 ms)
variance introduced by outliers: 19% (moderately inflated)

Conclusion

Many people associate dynamic languages with interpreters, but Dhall is an example of a statically typed interpreter. Dhall's evaluator is not sophisticated at all but can still take advantage of static type information to achieve comparable performance with Python (which is a significantly more mature interpreter). This makes me wonder if the next generation of interpreters will be statically typed in order to enable better optimizations.
