diff --git a/web/func_equal.go b/web/func_equal.go new file mode 100644 index 0000000..3206b04 --- /dev/null +++ b/web/func_equal.go @@ -0,0 +1,38 @@ +package web + +import ( + "reflect" +) + +func isFunc(fn interface{}) bool { + return reflect.ValueOf(fn).Kind() == reflect.Func +} + +/* +This is more than a little sketchtacular. Go's rules for function pointer +equality are pretty restrictive: nil function pointers always compare equal, and +all other pointer types never do. However, this is pretty limiting: it means +that we can't let people reference the middleware they've given us since we have +no idea which function they're referring to. + +To get better data out of Go, we sketch on the representation of interfaces. We +happen to know that interfaces are pairs of pointers: one to the real data, one +to data about the type. Therefore, two interfaces, including two function +interface{}'s, point to exactly the same objects iff their interface +representations are identical. And it turns out this is sufficient for our +purposes. + +If you're curious, you can read more about the representation of functions here: +http://golang.org/s/go11func +We're in effect comparing the pointers of the indirect layer. +*/ +func funcEqual(a, b interface{}) bool { + if !isFunc(a) || !isFunc(b) { + panic("funcEqual: type error!") + } + + av := reflect.ValueOf(&a).Elem() + bv := reflect.ValueOf(&b).Elem() + + return av.InterfaceData() == bv.InterfaceData() +} diff --git a/web/func_equal_test.go b/web/func_equal_test.go new file mode 100644 index 0000000..daf8d9a --- /dev/null +++ b/web/func_equal_test.go @@ -0,0 +1,84 @@ +package web + +import ( + "testing" +) + +// To tell you the truth, I'm not actually sure how many of these cases are +// needed. Presumably someone with more patience than I could comb through +// http://golang.org/s/go11func and figure out what all the different cases I +// ought to test are, but I think this test includes all the cases I care about +// and is at least reasonably thorough. + +func a() string { + return "A" +} +func b() string { + return "B" +} +func mkFn(s string) func() string { + return func() string { + return s + } +} + +var c = mkFn("C") +var d = mkFn("D") +var e = a +var f = c +var g = mkFn("D") + +type Type string + +func (t *Type) String() string { + return string(*t) +} + +var t1 = Type("hi") +var t2 = Type("bye") +var t1f = t1.String +var t2f = t2.String + +var funcEqualTests = []struct { + a, b func() string + result bool +}{ + {a, a, true}, + {a, b, false}, + {b, b, true}, + {a, c, false}, + {c, c, true}, + {c, d, false}, + {a, e, true}, + {a, f, false}, + {c, f, true}, + {e, f, false}, + {d, g, false}, + {t1f, t1f, true}, + {t1f, t2f, false}, +} + +func TestFuncEqual(t *testing.T) { + t.Parallel() + + for _, test := range funcEqualTests { + r := funcEqual(test.a, test.b) + if r != test.result { + t.Errorf("funcEqual(%v, %v) should have been %v", + test.a, test.b, test.result) + } + } + h := mkFn("H") + i := h + j := mkFn("H") + k := a + if !funcEqual(h, i) { + t.Errorf("h and i should have been equal") + } + if funcEqual(h, j) { + t.Errorf("h and j should not have been equal") + } + if !funcEqual(a, k) { + t.Errorf("a and k should not have been equal") + } +}