Haskell memoization
Today we’ll go through the process of writing a higher-order function designed to memoize its argument.
While that probably sounds a bit too abstract, it will become clearer later. Hopefully :D
Our motivating example will be writing a function that calculates the n-th Fibonacci number.
Firstly, let’s write a fast enough implementation that we’ll use to test correctness
of our other implementations. This one will have a linear time complexity and will
be faster than our final memoized version.
Now, you might ask, why bother with doing all this if we’re going to introduce a FASTER
version right from the start?
There are several benefits. One of them is that our memoized version will match the recursive definition almost completely, meaning that our code won’t really have to adapt to accommodate optimizations. Another benefit is that our memo combinator will apply to a whole class of functions.
So, here’s the fast implementation.
fib = (fibs !!)
  where fibs = 1 : 1 : zipWith (+) fibs (tail fibs)
I won’t go into detail about how this works. If you’re familiar with Haskell and its lazy semantics you can probably figure it out. In fact, I’m pretty sure most Haskell programmers have even seen this specific implementation.
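If you do want to convince yourself, here’s the same definition as a self-contained snippet (the type signatures and `main` are my additions for demonstration). The whole infinite list is one shared value, so each element is computed at most once, and forcing a prefix only evaluates that prefix:

```haskell
-- The infinite list of Fibonacci numbers, defined in terms of itself.
fibs :: [Integer]
fibs = 1 : 1 : zipWith (+) fibs (tail fibs)

fib :: Int -> Integer
fib = (fibs !!)

main :: IO ()
main = print (take 10 fibs)  -- forces only the first 10 elements
```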
Now we can start building up. Let’s write the naive version.
nfib 0 = 1
nfib 1 = 1
nfib n = nfib (n - 1) + nfib (n - 2)
This is a pretty great definition. We get to use pattern matching so our code resembles the mathematical definition. Unfortunately, every call to nfib spawns two more, so the time complexity for this version is exponential. No bueno.
Let’s imagine we already have our magical memo combinator available. How would we use it?
The first instinct might be something like this.
mfib = memo nfib
While that certainly does something, it doesn’t do what we want it to do. Every call to mfib with a number previously used will be fast, but how many calls to mfib will we actually do? The problem is that our nfib function isn’t aware that there exists a memoized version of it, so even if we cache its return value, it doesn’t access that cache. It just recursively calls itself.
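One way to see the problem concretely is to count how many invocations the naive recursion performs. The `calls` function below is my own illustration, not part of the original code; it mirrors nfib’s call tree, counting one call for the current node plus everything both recursive branches do:

```haskell
-- How many invocations nfib n triggers in total.
calls :: Int -> Integer
calls 0 = 1
calls 1 = 1
calls n = 1 + calls (n - 1) + calls (n - 2)
-- Even if memo caches the top-level result of nfib n, every one of
-- these inner calls still goes straight to the uncached nfib.
```

The count grows exponentially just like nfib itself, which is exactly why caching only the outermost call buys us nothing.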
After thinking a bit about this, we might come up with this version, which will be correct.
mfib = memo fib'
  where fib' 0 = 1
        fib' 1 = 1
        fib' n = mfib (n - 1) + mfib (n - 2)
Here our memoized version actually wraps around a normal, recursive definition. The only difference is that the recursive definition doesn’t make calls to fib' but to mfib. The cached version.
A thing to note here is that this means we won’t be able to memoize functions that we can’t redefine. If there’s some function f that we know recurses, we can’t make it faster. That’s too bad, but it would require some serious hackery to fix.
So our objective will be to define the function memo. To see how we might do that, we’ll look at yet another implementation of the fib function. This one will be memoized.
lfib = (map fib' [0..] !!)
  where fib' 0 = 1
        fib' 1 = 1
        fib' n = lfib (n - 1) + lfib (n - 2)
Now this is pretty good. It resembles our memo version and it doesn’t perform half bad. For me, it can do lfib 50000 in GHCi in about 15 seconds. Not to mention we could extract our memo function directly from here. It would look something like this.
memo f = (map f [0..] !!)
Pretty sweet. Using Haskell’s niftiness we managed to capture a pretty general concept in a higher-order function. Unfortunately, there’s always a ‘but’.
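Before getting to the ‘but’, it’s worth checking that this list-based memo really does tie the knot with the earlier mfib skeleton. Here’s a self-contained sketch (signatures added by me); because mfib is a top-level value, the list built by memo is shared between all calls, and the recursion goes through the cache:

```haskell
-- List-based memo: cache results of f for every non-negative Int.
memo :: (Int -> a) -> Int -> a
memo f = (map f [0..] !!)

-- The recursive calls go through mfib, i.e. through the cache.
mfib :: Int -> Integer
mfib = memo fib'
  where fib' 0 = 1
        fib' 1 = 1
        fib' n = mfib (n - 1) + mfib (n - 2)
```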
There are two things we could do better. Firstly, the type of our memo function is (Num a1, Enum a1) => (a1 -> a) -> Int -> a. For all intents and purposes, we can look at it as (Int -> a) -> Int -> a. This means it’s generic in the function’s return value, which is nice, but it can only memoize functions that operate on Ints.
Secondly, this memo implementation is still too slow. It requires a traversal of the cache list every time we want to get a value out, because lists have linear time indexing. The complexity of our lfib function is quadratic or so.
Still, this will serve as a motivation for our actual, real implementation. Notice how it works. It maps our function over all possible inputs and stores the results. However, because of non-strict evaluation it doesn’t really do anything until it absolutely needs to. This is what allows us to use an infinite structure to cache our results. We’ll never evaluate the 1000th element of the list if we only calculate the first 999.
The same idea can be applied to any data structure in Haskell. We can define them recursively and Haskell will just say “Yeah, sure. Whatevs’”, treating our definition like instructions on how to generate any specific part of that structure.
To avoid the linear access time in our structure, we’ll use a binary tree instead of a list. Trees have the nice, logarithmic access time we want.
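As a tiny illustration of such a lazily generated tree (my own toy example, separate from the Memo tree defined below), here each node is labelled with the path taken to reach it, and we can descend into the infinite structure on demand:

```haskell
-- An infinite binary tree; each node carries the path used to reach it.
data Tree = Node Tree String Tree

paths :: String -> Tree
paths p = Node (paths (p ++ "L")) p (paths (p ++ "R"))

-- Walk down the infinite tree: False means left, True means right.
-- Only the nodes we actually visit are ever constructed.
label :: [Bool] -> String
label = go (paths "")
  where go (Node _ x _) []           = x
        go (Node l _ _) (False : bs) = go l bs
        go (Node _ _ r) (True  : bs) = go r bs
```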
The general idea is to somehow establish a mapping from the function’s argument type to the set of nodes of our tree. Each unique argument should have a unique node assigned to keep the result value.
The way I chose to do this, and I’m sure there are other, better options, is to introduce a new typeclass BitsBijective that represents types that support serialization into and from a list of boolean values.
Here’s my tree definition, along with the typeclass definition.
data Memo a b = Fork (Memo a b) b (Memo a b)
  deriving Show

type Bit = Bool
newtype Bits = Bits [Bit]

instance Show Bits where
  show (Bits b) = "[" ++ map (\x -> if x then '1' else '0') b ++ "]"

{-
Laws:
toBits . fromBits = fromBits . toBits = id
∀n zeros := replicate n False
∀x fromBits x = fromBits (x ++ zeros)
-}
class BitsBijective a where
  toBits :: a -> Bits
  fromBits :: Bits -> a
Some constraints that should be satisfied by all instances are written down in the comment.
We can make Ints an instance of our class relatively easily.
-- testBit, setBit and finiteBitSize come from Data.Bits
instance BitsBijective Int where
  toBits n = Bits $ map (testBit n) [0 .. finiteBitSize n - 1]
  fromBits (Bits b) = foldl setBit 0 $ map fst $ filter snd $ zip [0..] b
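A quick sanity check on the round trip. This is my own test harness, with the Bits newtype unwrapped to a bare [Bool] for brevity; the logic is the same as in the instance. Note the bit order: least significant bit first, and since we serialize all finiteBitSize bits, two’s-complement negatives round-trip as well:

```haskell
import Data.Bits (testBit, setBit, finiteBitSize)

-- Least-significant bit first, one entry per bit of the Int.
toBits' :: Int -> [Bool]
toBits' n = map (testBit n) [0 .. finiteBitSize n - 1]

-- Set every bit whose position is marked True.
fromBits' :: [Bool] -> Int
fromBits' b = foldl setBit 0 $ map fst $ filter snd $ zip [0..] b
```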
And using two helper functions we can make any pair of BitsBijective values also BitsBijective.
interleave :: Bits -> Bits -> Bits
interleave (Bits l) (Bits r) = Bits $ interleave' l r
  where interleave' [] [] = []
        interleave' [] ys = False : interleave' ys []
        interleave' (x : xs) ys = x : interleave' ys xs

uninterleave :: Bits -> (Bits, Bits)
uninterleave (Bits b) = (Bits l, Bits r)
  where (l, r) = uninterleave' b
        uninterleave' [] = ([], [])
        uninterleave' (x : xs) = let (rs, ls) = uninterleave' xs in (x : ls, rs)
-- the view patterns here need {-# LANGUAGE ViewPatterns #-}
instance (BitsBijective a, BitsBijective b) => BitsBijective (a, b) where
  toBits (toBits -> l, toBits -> r) = interleave l r
  fromBits (uninterleave -> (l, r)) = (fromBits l, fromBits r)
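To see what the helpers actually do, here are bare-list versions (again my own harness, with the newtype stripped) and a round-trip check:

```haskell
-- Alternate bits from the two lists; pad the shorter side with False.
interleave' :: [Bool] -> [Bool] -> [Bool]
interleave' [] []       = []
interleave' [] ys       = False : interleave' ys []
interleave' (x : xs) ys = x : interleave' ys xs

-- Deal the bits back out into the two original lists.
uninterleave' :: [Bool] -> ([Bool], [Bool])
uninterleave' []       = ([], [])
uninterleave' (x : xs) = let (rs, ls) = uninterleave' xs in (x : ls, rs)
```

For equal-length inputs the round trip is exact; for unequal lengths the shorter side comes back padded with trailing False bits, which is precisely why the law about ignoring trailing zeros matters.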
So, what do we get from this? Well, we’ll require that the functions we memoize have arguments of a BitsBijective type. While this is kind of restrictive, it’s already better than our first attempt since Ints are BitsBijective AND we can write instances for other types. But why impose that restriction?
It’s got everything to do with the
way we’ll store the data in our tree. We’ll turn the argument into ones and zeros,
and then, starting at the root of the tree, we’ll take a left turn for each zero,
and take a right turn for each one. This will lead us to the node that will have
the result of our function stored.
One caveat to watch out for comes from one of the laws: we ignore trailing zeros. This enables us to serialize pairs of values that don’t necessarily have the same number of bits, and then still be able to reconstruct both of them. Here’s the final memo implementation. Notice that as we accumulate the path to each node while constructing the tree, we have to reverse it before getting the actual value, because we keep PREpending each turn we take. That’s not ideal, but it’s good enough.
memo :: BitsBijective a => (a -> b) -> a -> b
memo f = readMemo
  where memo' l = Fork (memo' $ False : l)
                       (f $ fromBits $ Bits $ reverse l)
                       (memo' $ True : l)
        tree = memo' []
        readMemo x = followPath b tree
          where (Bits (ker -> b)) = toBits x
        followPath [] (Fork _ y _) = y
        followPath (False : xs) (Fork y _ _) = followPath xs y
        followPath (True : xs) (Fork _ _ y) = followPath xs y
        ker xs | null ones = []
               | otherwise = take (1 + fst (last ones)) xs
          where ones = filter snd $ zip [0..] xs
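Putting the pieces together, here’s a minimal self-contained version of the whole thing (Show instances and laws comment omitted, the needed pragma and import added) driving the original mfib definition:

```haskell
{-# LANGUAGE ViewPatterns #-}
import Data.Bits (testBit, setBit, finiteBitSize)

-- Infinite binary tree: one node per possible bit-path.
data Memo a b = Fork (Memo a b) b (Memo a b)

newtype Bits = Bits [Bool]

class BitsBijective a where
  toBits   :: a -> Bits
  fromBits :: Bits -> a

instance BitsBijective Int where
  toBits n = Bits $ map (testBit n) [0 .. finiteBitSize n - 1]
  fromBits (Bits b) = foldl setBit 0 $ map fst $ filter snd $ zip [0..] b

memo :: BitsBijective a => (a -> b) -> a -> b
memo f = readMemo
  where memo' l = Fork (memo' $ False : l)
                       (f $ fromBits $ Bits $ reverse l)
                       (memo' $ True : l)
        tree = memo' []
        readMemo x = followPath b tree
          where (Bits (ker -> b)) = toBits x
        followPath [] (Fork _ y _) = y
        followPath (False : xs) (Fork y _ _) = followPath xs y
        followPath (True : xs) (Fork _ _ y) = followPath xs y
        ker xs | null ones = []
               | otherwise = take (1 + fst (last ones)) xs
          where ones = filter snd $ zip [0..] xs

-- Exactly the mfib from the start of the post.
mfib :: Int -> Integer
mfib = memo fib'
  where fib' 0 = 1
        fib' 1 = 1
        fib' n = mfib (n - 1) + mfib (n - 2)
```

Because mfib is a top-level value, the tree inside memo is built once and shared, so the recursive calls hit the cache.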
And that’s it! Now our original mfib definition works, and it works great! For me it can calculate the 50000th number in just under 2 seconds. And that’s in GHCi. The final complexity should be somewhere around n * m, where n is the index and m is the number of bits in our argument. We can also test correctness using our original implementation. The first 1000 numbers match, which is proof enough for me.
Here’s a link to the complete version.
Well, I hope this wasn’t too dry of a read. I’ve certainly had fun writing it. If you’ve got comments, I’d be glad to hear them.