-
-
Save DanielHeath/7744196 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package memoize | |
import ( | |
"fmt" | |
"reflect" | |
) | |
// Memoize stores in *fptr a memoizing wrapper around the function impl.
//
// fptr must be a non-nil pointer to a variable that a value of impl's type
// can be assigned to: typically a *func(...) with impl's exact signature, or
// a pointer to a named function type with that signature. impl must be a
// non-nil function taking one or more arguments, every one of whose types is
// comparable (usable as a map key), and returning one or more values.
// Memoize panics if any of these requirements is not met.
//
// The wrapper calls impl once per distinct combination of arguments and
// replays the recorded results on later calls with equal arguments, so impl
// should be deterministic. impl may call the wrapper recursively (for
// example a memoized Fibonacci). Passing the same variable as both fptr and
// impl is fine: impl's value is captured before *fptr is replaced.
//
// Caveats:
//   - The cache is built from plain maps with no locking, so the wrapper
//     must not be called from multiple goroutines concurrently.
//   - Cached results are never evicted; the cache grows with every distinct
//     argument tuple.
//   - Arguments are compared with map-key equality: an interface-typed
//     argument holding an incomparable dynamic value (such as a slice) makes
//     the wrapper panic, and NaN floating-point arguments never hit the cache.
func Memoize(fptr, impl interface{}) {
	implType := reflect.TypeOf(impl)
	// A nil interface has no type; report it like any other non-function
	// rather than calling a method on a nil reflect.Type.
	if implType == nil || implType.Kind() != reflect.Func {
		panic(fmt.Sprintf("Not a function: %v", impl))
	}
	implValue := reflect.ValueOf(impl)
	if implValue.IsNil() {
		panic(fmt.Sprintf("Nil function of type %v", implType))
	}
	if implType.NumIn() == 0 {
		panic(fmt.Sprintf("%v takes no inputs", impl))
	}
	if implType.NumOut() == 0 {
		panic(fmt.Sprintf("%v gives no outputs", impl))
	}
	// Check impl against the pointed-to type instead of comparing *implType
	// with fptr's type. This also admits pointers to named function types.
	// A nil or non-pointer fptr gets this message instead of a panic from
	// deep inside reflect.
	target := reflect.ValueOf(fptr)
	if target.Kind() != reflect.Ptr || target.IsNil() ||
		!implType.AssignableTo(target.Type().Elem()) {
		panic(fmt.Sprintf("Can't assign %v to %v", impl, fptr))
	}

	// The cache is a chain of nested maps, one level per argument:
	// levels[i] is map[In(i)]levels[i+1], and the innermost level maps the
	// last argument to the []reflect.Value that impl returned.
	numIn := implType.NumIn()
	levels := make([]reflect.Type, numIn)
	elem := reflect.TypeOf([]reflect.Value(nil))
	for i := numIn - 1; i >= 0; i-- {
		inType := implType.In(i)
		if !inType.Comparable() {
			panic(fmt.Sprintf("%v argument %d has type %v, which cannot be a map key",
				impl, i, inType))
		}
		elem = reflect.MapOf(inType, elem)
		levels[i] = elem
	}
	root := reflect.MakeMap(levels[0])
	last := numIn - 1

	memoized := func(args []reflect.Value) []reflect.Value {
		// Walk (creating as needed) one map level per leading argument.
		level := root
		for i, arg := range args[:last] {
			next := level.MapIndex(arg)
			if !next.IsValid() {
				next = reflect.MakeMap(levels[i+1])
				level.SetMapIndex(arg, next)
			}
			level = next
		}
		if hit := level.MapIndex(args[last]); hit.IsValid() {
			// Hand back a fresh slice so callers never alias the cached one.
			cached := hit.Interface().([]reflect.Value)
			return append([]reflect.Value(nil), cached...)
		}
		results := implValue.Call(args)
		level.SetMapIndex(args[last], reflect.ValueOf(results))
		return results
	}
	target.Elem().Set(reflect.MakeFunc(implType, memoized))
}
// deVal returns the value held by val as an interface{}.
//
// It does this by passing val through a reflectively-invoked identity
// function and unwrapping the single result. For example, when val is an
// element of a []reflect.Value (so its dynamic type is reflect.Value), the
// result can be type-asserted back to reflect.Value. A nil interface Value
// yields nil.
//
// NOTE(review): for ordinary values this matches val.Interface(); confirm
// before substituting, since routing through Call may treat read-only
// values (obtained via unexported fields) differently.
func deVal(val reflect.Value) interface{} {
	identity := reflect.ValueOf(func(x interface{}) interface{} { return x })
	out := identity.Call([]reflect.Value{val})
	return out[0].Interface()
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package memoize | |
import ( | |
"testing" | |
) | |
// intData drives the single-argument tests. Each row is an input and the
// cumulative number of calls to the underlying function expected after
// the memoized call with that input.
var intData = []struct {
	in    int
	calls int
}{
	{in: 0, calls: 1}, // first time 0 is seen
	{in: 1, calls: 2}, // first time 1 is seen
	{in: 0, calls: 2}, // cached
	{in: 2, calls: 3}, // first time 2 is seen
	{in: 2, calls: 3}, // cached
}
func TestInt1(t *testing.T) { | |
numCalls := 0 | |
var f1 func(int) int | |
f2 := func(i int) int { | |
numCalls++ | |
return i | |
} | |
Memoize(&f1, f2) | |
for _, d := range intData { | |
out := f1(d.in) | |
if out != d.in { | |
t.Errorf("Got %d, want %d", out, d.in) | |
} | |
if numCalls != d.calls { | |
t.Errorf("Num calls = %d, want %d", numCalls, d.calls) | |
} | |
} | |
} | |
func TestReassign(t *testing.T) { | |
numCalls := 0 | |
f := func(i int) int { | |
numCalls++ | |
return i | |
} | |
Memoize(&f, f) | |
for _, d := range intData { | |
out := f(d.in) | |
if out != d.in { | |
t.Errorf("Got %d, want %d", out, d.in) | |
} | |
if numCalls != d.calls { | |
t.Errorf("Num calls = %d, want %d", numCalls, d.calls) | |
} | |
} | |
} | |
// int2Data drives TestInt2. Each row is an argument pair and the cumulative
// number of underlying calls expected after the memoized call.
var int2Data = []struct {
	in1, in2 int
	calls    int
}{
	{in1: 0, in2: 0, calls: 1}, // new pair
	{in1: 0, in2: 1, calls: 2}, // new pair
	{in1: 1, in2: 0, calls: 3}, // new pair
	{in1: 0, in2: 0, calls: 3}, // cached
	{in1: 0, in2: 1, calls: 3}, // cached
	{in1: 1, in2: 0, calls: 3}, // cached
}
func TestInt2(t *testing.T) { | |
numCalls := 0 | |
f := func(a, b int) int { | |
numCalls++ | |
return a + b | |
} | |
Memoize(&f, f) | |
for _, d := range int2Data { | |
out := f(d.in1, d.in2) | |
if out != d.in1+d.in2 { | |
t.Errorf("Got %d, want %d + %d", out, d.in1, d.in2) | |
} | |
if numCalls != d.calls { | |
t.Errorf("Num calls = %d, want %d", numCalls, d.calls) | |
} | |
} | |
} | |
// int22Data drives TestInt22. Each row is an argument pair and the
// cumulative number of underlying calls expected after the memoized call.
var int22Data = []struct {
	in1, in2 int
	calls    int
}{
	{in1: 0, in2: 0, calls: 1}, // new pair
	{in1: 1, in2: 0, calls: 2}, // new pair
	{in1: 0, in2: 1, calls: 3}, // new pair
	{in1: 0, in2: 1, calls: 3}, // cached
	{in1: 1, in2: 0, calls: 3}, // cached
	{in1: 0, in2: 0, calls: 3}, // cached
}
func TestInt22(t *testing.T) { | |
numCalls := 0 | |
f := func(a, b int) (int, int) { | |
numCalls++ | |
return b, a | |
} | |
Memoize(&f, f) | |
for _, d := range int22Data { | |
out1, out2 := f(d.in1, d.in2) | |
if out1 != d.in2 || out2 != d.in1 { | |
t.Errorf("Got (%d, %d) from (%d, %d)", out1, out2, | |
d.in1, d.in2) | |
} | |
if numCalls != d.calls { | |
t.Errorf("Num calls = %d, want %d", numCalls, d.calls) | |
} | |
} | |
} | |
// mixedData drives TestMixed. Each row is an (int, string) argument pair and
// the cumulative number of underlying calls expected after the memoized call.
// The first four rows are new pairs; the last four repeat them.
var mixedData = []struct {
	in1   int
	in2   string
	calls int
}{
	{in1: 0, in2: "zero", calls: 1},
	{in1: 1, in2: "zero", calls: 2},
	{in1: 0, in2: "one", calls: 3},
	{in1: 1, in2: "one", calls: 4},
	{in1: 0, in2: "zero", calls: 4},
	{in1: 1, in2: "zero", calls: 4},
	{in1: 0, in2: "one", calls: 4},
	{in1: 1, in2: "one", calls: 4},
}
func TestMixed(t *testing.T) { | |
numCalls := 0 | |
f := func(a int, b string) (string, int) { | |
numCalls++ | |
return b, a | |
} | |
Memoize(&f, f) | |
for _, d := range mixedData { | |
out1, out2 := f(d.in1, d.in2) | |
if out1 != d.in2 || out2 != d.in1 { | |
t.Errorf("Got (%s, %d) from (%d, %s)", out1, out2, | |
d.in1, d.in2) | |
} | |
if numCalls != d.calls { | |
t.Errorf("Num calls = %d, want %d", numCalls, d.calls) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment