Friday, January 30, 2009

Tips for Memoization

Introduction
We’ve been talking about functional programming quite a bit already. One of the things used frequently in functional programming is recursion, in place of imperative loop constructs. Both have their advantages, but recursive techniques can often cause significant performance degradation. The prototypical sample is the computation of the Fibonacci sequence (a typical interview question, too). In mathematical terms, Fibonacci is expressed as:
fib : N -> N
fib 1 = 1
fib 2 = 1
fib n = fib (n - 1) + fib (n - 2), n > 2
Translating this directly into functional-style code yields the following (C#):

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);

The reason we need to spread this across two lines is interesting by itself. If we don’t do this, the following error appears:

memoize.cs(17,48): error CS0165: Use of unassigned local variable 'fib'

referring to the use of fib inside the lambda body in this code:

Func<uint, uint> fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);

The error pops up because we’re defining a function in terms of itself, something that’s invalid in a variable declaration/assignment in C#, just like the following is invalid:

int a = a + b;

F# addresses this through the use of the rec keyword, but that’s a separate discussion. But what are we really doing when declaring the following?

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);

Here’s the answer:


Notice the <>c__DisplayClass1, a closure. When assigning the lambda on the second line to fib, we’re capturing the fib variable itself as it appears in the lambda’s body. In more detail, this happens:


On lines IL_0007 to IL_0009 we store null as the value for fib, immediately replacing it on lines IL_000e to IL_001b with a new function retrieved from <Main>b__0. This is where the code becomes self-referential:



As you can see on lines IL_0005 and IL_0013, we’re loading the field that Main assigned to (although this code by itself doesn’t know that) in order to call it four lines further on, through the delegate. The rest of this code is a trivial translation of the ternary operator. Why is this interesting at all? It turns out to be fairly important further on in this article, as we’ll want to tweak this function.
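The transformation described above can be approximated in plain C#. What follows is a hand-written sketch, not the actual compiler output: the names FibClosure, Body and ClosureDemo are illustrative stand-ins, since the real generated names (<>c__DisplayClass1 and the like) are not valid C# identifiers.

```csharp
using System;

// Hand-written stand-in for the compiler-generated closure class.
sealed class FibClosure
{
    // The captured local variable 'fib' becomes a public field.
    public Func<uint, uint> fib;

    // The lambda body becomes an instance method; note that the
    // recursive calls go through the field, not through this method.
    public uint Body(uint n)
    {
        return n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
    }
}

static class ClosureDemo
{
    public static uint Run(uint n)
    {
        var closure = new FibClosure();
        closure.fib = null;          // Func<uint, uint> fib = null;
        closure.fib = closure.Body;  // fib = n => ...;
        return closure.fib(n);
    }
}
```

Calling ClosureDemo.Run(10) evaluates the tenth Fibonacci number through the field, just as the generated code does. The point to remember is that whoever overwrites the field also changes what the recursive calls invoke.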

What’s memoization?
Looking at our Fibonacci sequence sample again, try to imagine the call tree that results from a call like fib(10). Or, to keep it simple, consider Fib(5). Here’s the call tree:
Fib(5)
    Fib(4)
        Fib(3)
            Fib(2)
            Fib(1)
        Fib(2)
    Fib(3)
        Fib(2)
        Fib(1)
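The size of such a call tree can also be measured rather than drawn. The sketch below counts how often the naive recursive lambda gets invoked; the names CallCounter and CountNaiveCalls are made up for this illustration and are not part of the original code.

```csharp
using System;

static class CallCounter
{
    // Returns how many times the naive recursive definition is invoked
    // while computing fib(n), the root call included.
    public static long CountNaiveCalls(uint n)
    {
        long calls = 0;
        Func<uint, uint> fib = null;
        fib = k =>
        {
            calls++; // one more node in the call tree
            return k <= 2 ? 1 : fib(k - 1) + fib(k - 2);
        };
        fib(n);
        return calls;
    }
}
```

CountNaiveCalls(5) returns 9, one per node in the tree above, and CountNaiveCalls(10) already returns 109. The count grows as fast as the Fibonacci numbers themselves.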
We’re calculating things over and over again. So how can we solve this? First of all, by embracing an imperative style, at the cost of the more declarative natural mapping of the original recursive definition:
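uint Fib(uint n)
{
    if (n <= 2)
        return 1;
    else
    {
        uint prev = 1;
        uint curr = 1;
        for (uint i = 0; i < n - 2; i++)
        {
            uint t = curr;
            curr += prev;
            prev = t;
        }
        return curr;
    }
}

The alternative, which keeps the recursive definition intact, is memoization: caching the results of earlier calls so that each value is computed only once.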

Measuring success
Before we claim things like “10 times better”, we should establish a baseline for comparison and create a mechanism to measure our success. As usual, we’ll rely on the System.Diagnostics.Stopwatch class to do so:
static void Test(Func<uint, uint> fib)
{
    Stopwatch sw = new Stopwatch();
    sw.Start();

    var res = from x in Range<uint>(1, 40, i => i + 1)
              select new { N = x, Value = fib(x) };
    foreach (var i in res)
        Console.WriteLine("fib(" + i.N + ") = " + i.Value);

    sw.Stop();
    Console.WriteLine(sw.Elapsed);
}
In here I’m using a generalization of Enumerable.Range that I find useful (although there’s no real need for a range of uint as the input here; our function could just as well be Func<int, uint>):
static IEnumerable<T> Range<T>(T from, T to, Func<T, T> inc) where T : IComparable<T>
{
    for (T t = from; t.CompareTo(to) <= 0; t = inc(t))
    {
        yield return t;
    }
}

Call Range “For” instead and it becomes very apparent what it’s all about, doesn’t it? Let’s take a look at how our current implementation does:

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
Test(fib);

Yes, it’s fine to say it: in one word, terrible…


Injecting the memoizer
As mentioned before, our strategy to tackle this inefficiency will be to trade instructions for memory, essentially keeping a cache of calculated values. The built-in collection type that’s ideal for this purpose is the generic Dictionary<TKey, TValue> in System.Collections.Generic. But how do we get it into our function definition seamlessly? In other words, given any function of arity 1 (meaning it takes one argument; we’ll look at extending that further on), how can we sprinkle a little bit of memoization on top without changing the outer contract of the function? Here’s the code that allows us to preserve the signature but slip the memoizer in between the original function and the memoized one:

static Func<T, R> Memoize<T, R>(this Func<T, R> f)
{
    Dictionary<T, R> cache = new Dictionary<T, R>();
    return t => {
        if (cache.ContainsKey(t))
            return cache[t];
        else
            return (cache[t] = f(t));
    };
}

Actually, this code can be optimized a little further using the TryGetValue method on the Dictionary class, and if you have more taste than me, the else-block statement can be written in a nicer way (if I were in a really evil mood, I’d have put it in a ternary conditional). I’ll leave such rewrites to the reader as an additional opportunity to express personal style :-).
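For readers who would rather see one possible answer, here is a minimal sketch of the TryGetValue variant. It is only one of many ways to write it, and the containing class name MemoizeExtensions is an assumption made for this example (an extension method has to live in some static class).

```csharp
using System;
using System.Collections.Generic;

static class MemoizeExtensions
{
    // Same contract as before: a Func<T, R> goes in, a Func<T, R> comes out.
    public static Func<T, R> Memoize<T, R>(this Func<T, R> f)
    {
        var cache = new Dictionary<T, R>();
        return t =>
        {
            R result;
            // A single lookup instead of ContainsKey followed by the indexer.
            if (!cache.TryGetValue(t, out result))
            {
                result = f(t);
                cache[t] = result;
            }
            return result;
        };
    }
}
```

On a cache hit this performs one dictionary lookup where the ContainsKey version performs two; the observable behavior is otherwise the same.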

Notice how the signature of the returned function is the same as the original one: that’s what makes our implementation seamless and transparent. I’m writing this as an extension method on Func<T, R>, but there’s no need to do it that way. What’s more important, though, is how it works internally. Again you can see closures at work, because what we’ve really created here is something that looks like this:

class Memoizer<T, R>
{
    private Dictionary<T, R> _cache = new Dictionary<T, R>();
    private Func<T, R> _f;

    internal Memoizer(Func<T, R> f)
    {
        _f = f;
    }

    internal R Invoke(T t)
    {
        if (_cache.ContainsKey(t))
            return _cache[t];
        else
            return (_cache[t] = _f(t));
    }
}

You can look at it as lifting an existing function into the memoizer (one per function, as we need a unique cache on a function-per-function basis). Obviously you’ll need similar implementations for other function arities (including the zero-argument function, typically used for delayed computation scenarios). Here another issue pops up: the lack of Tuple types (with proper implementations for Equals and GetHashCode) that would be useful in such a case to express the dictionary’s key type. Even more, the debate on how many generic overloads to provide (Action, Func, Tuple, etc) enters the picture again. Unfortunately the type system isn’t rich enough to express a “Tuple” with a variable number of type parameters. At runtime there are ways to get around this, but then we enter dynamic meta-grounds again, so let’s not deviate from our path this time and keep that discussion for another time.
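To give an idea of what a two-argument overload might look like, here is a sketch that assumes a newer C# version (7.0 or later) than the one this post was written against. There, value tuples such as (T1, T2) come with structural Equals and GetHashCode and can serve as the dictionary key. The class name MemoizeExtensions2 is made up for the example.

```csharp
using System;
using System.Collections.Generic;

static class MemoizeExtensions2
{
    // Arity-2 variant: both arguments together form the cache key.
    public static Func<T1, T2, R> Memoize<T1, T2, R>(this Func<T1, T2, R> f)
    {
        // Assumption: C# 7.0+ value tuples, which compare element by element.
        var cache = new Dictionary<(T1, T2), R>();
        return (a, b) =>
        {
            var key = (a, b);
            if (!cache.TryGetValue(key, out R result))
            {
                result = f(a, b);
                cache[key] = result;
            }
            return result;
        };
    }
}
```

Each further arity still needs its own overload, so this takes care of the key type but not of the overload debate.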

Putting it to the test
Back to our original code:

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
Test(fib);

Easy as it seems, you might think the following will do the trick:

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
Test(fib.Memoize());

but unfortunately you won’t see any noticeable effect by doing so. Why? Take a closer look at what’s happening. The code above is equivalent to:

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
Memoizer<uint, uint> m = new Memoizer<uint, uint>(fib);
Func<uint, uint> fibm = new Func<uint, uint>(m.Invoke);
Test(fibm);

Now we’re calling through fibm, which results in invoking the Invoke method on the (simplified) memoizer’s closure class. But look what we’re passing in to the constructor: the original fib instance, which really is a public field on another closure as explained in the introduction paragraph. So ultimately we’re just memoizing the “outer” fib function, not the “inner” recursive calls. How can we make the thing work as we expect it to? Remember from the introduction paragraph why we needed the following trick?

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);

The generated code stores fib in a closure class field <>c__DisplayClass1::fib. In fact, there’s no such thing as a local variable fib in the IL code; instead, all occurrences of fib have been replaced by ldfld and stfld operations on the closure class’s field. What’s more, the closure class’s <Main>b__0 method uses the same field for the recursive calls to fib (see the last figure in the introduction). That’s precisely what we need to know in order to make the memoizer work: if we assign the result of fib.Memoize() to fib again, we’re replacing the field value that’s also used in the recursive calls:

Func<uint, uint> fib = null;
fib = n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2);
fib = fib.Memoize();
Test(fib);

As a little quiz question: why can’t we write the following instead?

Func<uint, uint> fib = null;
fib = (n => n <= 2 ? 1 : fib(n - 1) + fib(n - 2)).Memoize();
Test(fib);

And here’s the result:


Much better, actually much more than “10 times better”.
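The difference between the two ways of wiring things up can also be shown without a stopwatch, by counting how often the lambda body actually runs. The following is a sketch under stated assumptions: RewiringDemo and CountBodyEvaluations are hypothetical names, and the memoizer is included as a plain static helper (not an extension method) so the example stands on its own.

```csharp
using System;
using System.Collections.Generic;

static class RewiringDemo
{
    // Same caching idea as Memoize in this post, specialized to uint.
    static Func<uint, uint> Memoize(Func<uint, uint> f)
    {
        var cache = new Dictionary<uint, uint>();
        return n =>
        {
            uint value;
            if (!cache.TryGetValue(n, out value))
                cache[n] = value = f(n);
            return value;
        };
    }

    // Counts evaluations of the lambda body for a single top-level call.
    public static long CountBodyEvaluations(uint n, bool reassign)
    {
        long evaluations = 0;
        Func<uint, uint> fib = null;
        fib = k =>
        {
            evaluations++;
            return k <= 2 ? 1 : fib(k - 1) + fib(k - 2);
        };

        Func<uint, uint> memoized = Memoize(fib);
        if (reassign)
            fib = memoized; // recursive calls now go through the cache too

        memoized(n);
        return evaluations;
    }
}
```

CountBodyEvaluations(10, false) returns 109: only the outermost call goes through the cache, so the full call tree is still evaluated. CountBodyEvaluations(10, true) returns 10, one evaluation per distinct argument.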

Additional quiz question

Would you be able to do all of this with expression trees, i.e. with the following hypothetical piece of code:

Expression<Func<uint, uint>> fibEx = null;
fibEx = n => n <= 2 ? 1 : fibEx(n - 1) + fibEx(n - 2);
// what now?
Test(fib);

Why (not)?
