From f67c7d59eeac7ddd36dc88b24c0b251c268a70d2 Mon Sep 17 00:00:00 2001 From: Carl Jackson Date: Sat, 1 Mar 2014 15:55:58 -0800 Subject: [PATCH] Parameter parsing Package param implements parameter parsing into a target struct (in much the same way as encoding/json parses JSON into a struct). It targets the common jQuery.param / Ruby on Rails style parameter serialization format. --- param/crazy_test.go | 56 +++++ param/error_helpers.go | 34 +++ param/param.go | 60 +++++ param/param_test.go | 505 +++++++++++++++++++++++++++++++++++++++++ param/parse.go | 214 +++++++++++++++++ param/pebkac_test.go | 58 +++++ param/struct.go | 117 ++++++++++ param/struct_test.go | 106 +++++++++ 8 files changed, 1150 insertions(+) create mode 100644 param/crazy_test.go create mode 100644 param/error_helpers.go create mode 100644 param/param.go create mode 100644 param/param_test.go create mode 100644 param/parse.go create mode 100644 param/pebkac_test.go create mode 100644 param/struct.go create mode 100644 param/struct_test.go diff --git a/param/crazy_test.go b/param/crazy_test.go new file mode 100644 index 0000000..46538cd --- /dev/null +++ b/param/crazy_test.go @@ -0,0 +1,56 @@ +package param + +import ( + "net/url" + "testing" +) + +type Crazy struct { + A *Crazy + B *Crazy + Value int + Slice []int + Map map[string]Crazy +} + +func TestCrazy(t *testing.T) { + t.Parallel() + + c := Crazy{} + err := Parse(url.Values{ + "A[B][B][A][Value]": {"1"}, + "B[A][A][Slice][]": {"3", "1", "4"}, + "B[Map][hello][A][Value]": {"8"}, + "A[Value]": {"2"}, + "A[Slice][]": {"9", "1", "1"}, + "Value": {"42"}, + }, &c) + if err != nil { + t.Error("Error parsing craziness: ", err) + } + + // Exhaustively checking everything here is going to be a huge pain, so + // let's just hope for the best, pretend NPEs don't exist, and hope that + // this test covers enough stuff that it's actually useful. + assertEqual(t, "c.A.B.B.A.Value", 1, c.A.B.B.A.Value) + assertEqual(t, "c.A.Value", 2, c.A.Value) + assertEqual(t, "c.Value", 42, c.Value) + assertEqual(t, `c.B.Map["hello"].A.Value`, 8, c.B.Map["hello"].A.Value) + + assertEqual(t, "c.A.B.B.B", (*Crazy)(nil), c.A.B.B.B) + assertEqual(t, "c.A.B.A", (*Crazy)(nil), c.A.B.A) + assertEqual(t, "c.A.A", (*Crazy)(nil), c.A.A) + + if c.Slice != nil || c.Map != nil { + t.Error("Map and Slice should not be set") + } + + sl := c.B.A.A.Slice + if len(sl) != 3 || sl[0] != 3 || sl[1] != 1 || sl[2] != 4 { + t.Error("Something is wrong with c.B.A.A.Slice") + } + sl = c.A.Slice + if len(sl) != 3 || sl[0] != 9 || sl[1] != 1 || sl[2] != 1 { + t.Error("Something is wrong with c.A.Slice") + } +} diff --git a/param/error_helpers.go b/param/error_helpers.go new file mode 100644 index 0000000..8033c9a --- /dev/null +++ b/param/error_helpers.go @@ -0,0 +1,34 @@ +package param + +import ( + "errors" + "fmt" + "log" +) + +// TODO: someday it might be nice to throw typed errors instead of weird strings + +// Testing log.Fatal in tests is... not a thing. Allow tests to stub it out. +var pebkacTesting bool + +const errPrefix = "param/parse: " +const yourFault = " This is a bug in your use of the param library." + +// Panic with a formatted error. The param library uses panics to quickly unwind +// the call stack and return a user error +func perr(format string, a ...interface{}) { + err := errors.New(errPrefix + fmt.Sprintf(format, a...)) + panic(err) +} + +// Problem exists between keyboard and chair. This function is used in cases of +// programmer error, i.e. an inappripriate use of the param library, to +// immediately force the program to halt with a hopefully helpful error message. +func pebkac(format string, a ...interface{}) { + err := errors.New(errPrefix + fmt.Sprintf(format, a...) + yourFault) + if pebkacTesting { + panic(err) + } else { + log.Fatal(err) + } +} diff --git a/param/param.go b/param/param.go new file mode 100644 index 0000000..0a1d8f5 --- /dev/null +++ b/param/param.go @@ -0,0 +1,60 @@ +/* +Package param deserializes parameter values into the given struct using magical +reflection ponies. Inspired by gorilla/schema, but uses Rails/jQuery style param +encoding instead of their weird dotted syntax. In particular, this package was +written with the intent of parsing the output of jQuery.param. + +This package uses struct tags to guess what names things ought to have. If a +struct value has a "param" tag defined, it will use that. If there is no "param" +tag defined, the name part of the "json" tag will be used. If that is not +defined, the name of the field itself will be used (no case transformation is +performed). + +If the name derived in this way is the string "-", param will refuse to set that +value. + +The parser is extremely strict, and will return an error if it has any +difficulty whatsoever in parsing any parameter, or if there is any kind of type +mismatch. +*/ +package param + +import ( + "net/url" + "reflect" + "strings" +) + +// Parse the given arguments into the the given pointer to a struct object. +func Parse(params url.Values, target interface{}) (err error) { + v := reflect.ValueOf(target) + + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + panic(err) + } + } + }() + + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + pebkac("Target of param.Parse must be a pointer to a struct. "+ + "We instead were passed a %v", v.Type()) + } + + el := v.Elem() + t := el.Type() + cache := cacheStruct(t) + + for key, values := range params { + sk, keytail := key, "" + if i := strings.IndexRune(key, '['); i != -1 { + sk, keytail = sk[:i], sk[i:] + } + parseStructField(cache, key, sk, keytail, values, el) + } + + return nil +} diff --git a/param/param_test.go b/param/param_test.go new file mode 100644 index 0000000..48f5e42 --- /dev/null +++ b/param/param_test.go @@ -0,0 +1,505 @@ +package param + +import ( + "net/url" + "reflect" + "strings" + "testing" + "time" +) + +type Everything struct { + Bool bool + Int int + Uint uint + Float float64 + Map map[string]int + Slice []int + String string + Struct Sub + Time time.Time + + PBool *bool + PInt *int + PUint *uint + PFloat *float64 + PMap *map[string]int + PSlice *[]int + PString *string + PStruct *Sub + PTime *time.Time + + PPInt **int + + ABool MyBool + AInt MyInt + AUint MyUint + AFloat MyFloat + AMap MyMap + APtr MyPtr + ASlice MySlice + AString MyString +} + +type Sub struct { + A int + B int +} + +type MyBool bool +type MyInt int +type MyUint uint +type MyFloat float64 +type MyMap map[MyString]MyInt +type MyPtr *MyInt +type MySlice []MyInt +type MyString string + +var boolAnswers = map[string]bool{ + "true": true, + "false": false, + "0": false, + "1": true, + "on": true, + "": false, +} + +var testTimeString = "1996-12-19T16:39:57-08:00" +var testTime time.Time + +func init() { + testTime, _ = time.Parse(time.RFC3339, testTimeString) +} + +func singletonErrors(t *testing.T, field, valid, invalid string) { + e := Everything{} + + err := Parse(url.Values{field: {invalid}}, &e) + if err == nil { + t.Errorf("Expected error parsing %q as %s", invalid, field) + } + + err = Parse(url.Values{field + "[]": {valid}}, &e) + if err == nil { + t.Errorf("Expected error parsing nested %s", field) + } + + err = Parse(url.Values{field + "[nested]": {valid}}, &e) + if err == nil { + t.Errorf("Expected error parsing nested %s", field) + } + + err = Parse(url.Values{field: {valid, valid}}, &e) + if err == nil { + t.Errorf("Expected error passing %s twice", field) + } +} + +func TestBool(t *testing.T) { + t.Parallel() + + for val, correct := range boolAnswers { + e := Everything{} + e.Bool = !correct + err := Parse(url.Values{"Bool": {val}}, &e) + if err != nil { + t.Error("Parse error on key: ", val) + } + assertEqual(t, "e.Bool", correct, e.Bool) + } +} + +func TestBoolTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"ABool": {"true"}}, &e) + if err != nil { + t.Error("Parse error for typed bool") + } + assertEqual(t, "e.ABool", MyBool(true), e.ABool) +} + +func TestBoolErrors(t *testing.T) { + t.Parallel() + singletonErrors(t, "Bool", "true", "llama") +} + +var intAnswers = map[string]int{ + "0": 0, + "9001": 9001, + "-42": -42, +} + +func TestInt(t *testing.T) { + t.Parallel() + + for val, correct := range intAnswers { + e := Everything{} + e.Int = 1 + err := Parse(url.Values{"Int": {val}}, &e) + if err != nil { + t.Error("Parse error on key: ", val) + } + assertEqual(t, "e.Int", correct, e.Int) + } +} + +func TestIntTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"AInt": {"1"}}, &e) + if err != nil { + t.Error("Parse error for typed int") + } + assertEqual(t, "e.AInt", MyInt(1), e.AInt) +} + +func TestIntErrors(t *testing.T) { + t.Parallel() + singletonErrors(t, "Int", "1", "llama") + + e := Everything{} + err := Parse(url.Values{"Int": {"4.2"}}, &e) + if err == nil { + t.Error("Expected error parsing float as int") + } +} + +var uintAnswers = map[string]uint{ + "0": 0, + "9001": 9001, +} + +func TestUint(t *testing.T) { + t.Parallel() + + for val, correct := range uintAnswers { + e := Everything{} + e.Uint = 1 + err := Parse(url.Values{"Uint": {val}}, &e) + if err != nil { + t.Error("Parse error on key: ", val) + } + assertEqual(t, "e.Uint", correct, e.Uint) + } +} + +func TestUintTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"AUint": {"1"}}, &e) + if err != nil { + t.Error("Parse error for typed uint") + } + assertEqual(t, "e.AUint", MyUint(1), e.AUint) +} + +func TestUintErrors(t *testing.T) { + t.Parallel() + singletonErrors(t, "Uint", "1", "llama") + + e := Everything{} + err := Parse(url.Values{"Uint": {"4.2"}}, &e) + if err == nil { + t.Error("Expected error parsing float as uint") + } + + err = Parse(url.Values{"Uint": {"-42"}}, &e) + if err == nil { + t.Error("Expected error parsing negative number as uint") + } +} + +var floatAnswers = map[string]float64{ + "0": 0, + "9001": 9001, + "-42": -42, + "9001.0": 9001.0, + "4.2": 4.2, + "-9.000001": -9.000001, +} + +func TestFloat(t *testing.T) { + t.Parallel() + + for val, correct := range floatAnswers { + e := Everything{} + e.Float = 1 + err := Parse(url.Values{"Float": {val}}, &e) + if err != nil { + t.Error("Parse error on key: ", val) + } + assertEqual(t, "e.Float", correct, e.Float) + } +} + +func TestFloatTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"AFloat": {"1.0"}}, &e) + if err != nil { + t.Error("Parse error for typed float") + } + assertEqual(t, "e.AFloat", MyFloat(1.0), e.AFloat) +} + +func TestFloatErrors(t *testing.T) { + t.Parallel() + singletonErrors(t, "Float", "1.0", "llama") +} + +func TestMap(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{ + "Map[one]": {"1"}, + "Map[two]": {"2"}, + "Map[three]": {"3"}, + }, &e) + if err != nil { + t.Error("Parse error in map: ", err) + } + + for k, v := range map[string]int{"one": 1, "two": 2, "three": 3} { + if mv, ok := e.Map[k]; !ok { + t.Errorf("Key %q not in map", k) + } else { + assertEqual(t, "Map["+k+"]", v, mv) + } + } +} + +func TestMapTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"AMap[one]": {"1"}}, &e) + if err != nil { + t.Error("Parse error for typed map") + } + assertEqual(t, "e.AMap[one]", MyInt(1), e.AMap[MyString("one")]) +} + +func TestMapErrors(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{"Map[]": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing empty map key") + } + + err = Parse(url.Values{"Map": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing map without key") + } + + err = Parse(url.Values{"Map[": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing map with malformed key") + } +} + +func testPtr(t *testing.T, key, in string, out interface{}) { + e := Everything{} + + err := Parse(url.Values{key: {in}}, &e) + if err != nil { + t.Errorf("Parse error while parsing pointer e.%s: %v", key, err) + } + fieldKey := key + if i := strings.IndexRune(fieldKey, '['); i >= 0 { + fieldKey = fieldKey[:i] + } + v := reflect.ValueOf(e).FieldByName(fieldKey) + if v.IsNil() { + t.Errorf("Expected param to allocate pointer for e.%s", key) + } else { + assertEqual(t, "*e."+key, out, v.Elem().Interface()) + } +} + +func TestPtr(t *testing.T) { + t.Parallel() + testPtr(t, "PBool", "true", true) + testPtr(t, "PInt", "2", 2) + testPtr(t, "PUint", "2", uint(2)) + testPtr(t, "PFloat", "2.0", 2.0) + testPtr(t, "PMap[llama]", "4", map[string]int{"llama": 4}) + testPtr(t, "PSlice[]", "4", []int{4}) + testPtr(t, "PString", "llama", "llama") + testPtr(t, "PStruct[B]", "2", Sub{0, 2}) + testPtr(t, "PTime", testTimeString, testTime) + + foo := 2 + testPtr(t, "PPInt", "2", &foo) +} + +func TestPtrTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"APtr": {"1"}}, &e) + if err != nil { + t.Error("Parse error for typed pointer") + } + assertEqual(t, "e.APtr", MyInt(1), *e.APtr) +} + +func TestSlice(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"Slice[]": {"3", "1", "4"}}, &e) + if err != nil { + t.Error("Parse error for slice") + } + if e.Slice == nil { + t.Fatal("Expected param to allocate a slice") + } + if len(e.Slice) != 3 { + t.Fatal("Expected a slice of length 3") + } + + assertEqual(t, "e.Slice[0]", 3, e.Slice[0]) + assertEqual(t, "e.Slice[1]", 1, e.Slice[1]) + assertEqual(t, "e.Slice[2]", 4, e.Slice[2]) +} + +func TestSliceTyped(t *testing.T) { + t.Parallel() + e := Everything{} + err := Parse(url.Values{"ASlice[]": {"3", "1", "4"}}, &e) + if err != nil { + t.Error("Parse error for typed slice") + } + if e.ASlice == nil { + t.Fatal("Expected param to allocate a slice") + } + if len(e.ASlice) != 3 { + t.Fatal("Expected a slice of length 3") + } + + assertEqual(t, "e.ASlice[0]", MyInt(3), e.ASlice[0]) + assertEqual(t, "e.ASlice[1]", MyInt(1), e.ASlice[1]) + assertEqual(t, "e.ASlice[2]", MyInt(4), e.ASlice[2]) +} + +func TestSliceErrors(t *testing.T) { + t.Parallel() + e := Everything{} + err := Parse(url.Values{"Slice": {"1"}}, &e) + if err == nil { + t.Error("expected error parsing slice without key") + } + + err = Parse(url.Values{"Slice[llama]": {"1"}}, &e) + if err == nil { + t.Error("expected error parsing slice with string key") + } + + err = Parse(url.Values{"Slice[": {"1"}}, &e) + if err == nil { + t.Error("expected error parsing malformed slice key") + } +} + +var stringAnswer = "This is the world's best string" + +func TestString(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{"String": {stringAnswer}}, &e) + if err != nil { + t.Error("Parse error in string: ", err) + } + + assertEqual(t, "e.String", stringAnswer, e.String) +} + +func TestStringTyped(t *testing.T) { + t.Parallel() + + e := Everything{} + err := Parse(url.Values{"AString": {"llama"}}, &e) + if err != nil { + t.Error("Parse error for typed string") + } + assertEqual(t, "e.AString", MyString("llama"), e.AString) +} + +func TestStruct(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{ + "Struct[A]": {"1"}, + }, &e) + if err != nil { + t.Error("Parse error in struct: ", err) + } + assertEqual(t, "e.Struct.A", 1, e.Struct.A) + assertEqual(t, "e.Struct.B", 0, e.Struct.B) + + err = Parse(url.Values{ + "Struct[A]": {"4"}, + "Struct[B]": {"2"}, + }, &e) + if err != nil { + t.Error("Parse error in struct: ", err) + } + assertEqual(t, "e.Struct.A", 4, e.Struct.A) + assertEqual(t, "e.Struct.B", 2, e.Struct.B) +} + +func TestStructErrors(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{"Struct[]": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing empty struct key") + } + + err = Parse(url.Values{"Struct": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing struct without key") + } + + err = Parse(url.Values{"Struct[": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing malformed struct key") + } + + err = Parse(url.Values{"Struct[C]": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing unknown") + } +} + +func TestTextUnmarshaler(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{"Time": {testTimeString}}, &e) + if err != nil { + t.Error("parse error for TextUnmarshaler (Time): ", err) + } + assertEqual(t, "e.Time", testTime, e.Time) +} + +func TestTextUnmarshalerError(t *testing.T) { + t.Parallel() + e := Everything{} + + err := Parse(url.Values{"Time": {"llama"}}, &e) + if err == nil { + t.Error("expected error parsing llama as time") + } +} diff --git a/param/parse.go b/param/parse.go new file mode 100644 index 0000000..e59432a --- /dev/null +++ b/param/parse.go @@ -0,0 +1,214 @@ +package param + +import ( + "encoding" + "fmt" + "reflect" + "strconv" + "strings" +) + +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +// Generic parse dispatcher. This function's signature is the interface of all +// parse functions. `key` is the entire key that is currently being parsed, such +// as "foo[bar][]". `keytail` is the portion of the string that the current +// parser is responsible for, for instance "[bar][]". `values` is the list of +// values assigned to this key, and `target` is where the resulting typed value +// should be Set() to. +func parse(key, keytail string, values []string, target reflect.Value) { + t := target.Type() + if reflect.PtrTo(t).Implements(textUnmarshalerType) { + parseTextUnmarshaler(key, keytail, values, target) + return + } + + switch k := target.Kind(); k { + case reflect.Bool: + parseBool(key, keytail, values, target) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parseInt(key, keytail, values, target) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + parseUint(key, keytail, values, target) + case reflect.Float32, reflect.Float64: + parseFloat(key, keytail, values, target) + case reflect.Map: + parseMap(key, keytail, values, target) + case reflect.Ptr: + parsePtr(key, keytail, values, target) + case reflect.Slice: + parseSlice(key, keytail, values, target) + case reflect.String: + parseString(key, keytail, values, target) + case reflect.Struct: + parseStruct(key, keytail, values, target) + + default: + pebkac("unsupported object of type %v and kind %v.", + target.Type(), k) + } +} + +// We pass down both the full key ("foo[bar][]") and the part the current layer +// is responsible for making sense of ("[bar][]"). This computes the other thing +// you probably want to know, which is the path you took to get here ("foo"). +func kpath(key, keytail string) string { + l, t := len(key), len(keytail) + return key[:l-t] +} + +// Helper for validating that a value has been passed exactly once, and that the +// user is not attempting to nest on the key. +func primitive(tipe, key, keytail string, values []string) { + if keytail != "" { + perr("expected %s for key %q, got nested value", tipe, + kpath(key, keytail)) + } + if len(values) != 1 { + perr("expected %s for key %q, but key passed %v times", tipe, + kpath(key, keytail), len(values)) + } +} + +func keyed(tipe reflect.Type, key, keytail string) (string, string) { + idx := strings.IndexRune(keytail, ']') + // Keys must be at least 1 rune wide: we refuse to use the empty string + // as the key + if len(keytail) < 3 || keytail[0] != '[' || idx < 2 { + perr("expected a square bracket delimited index for %q "+ + "(of type %v)", kpath(key, keytail), tipe) + } + return keytail[1:idx], keytail[idx+1:] +} + +func parseTextUnmarshaler(key, keytail string, values []string, target reflect.Value) { + primitive("encoding.TextUnmarshaler", key, keytail, values) + + tu := target.Addr().Interface().(encoding.TextUnmarshaler) + err := tu.UnmarshalText([]byte(values[0])) + if err != nil { + perr("error while calling UnmarshalText on %v for key %q: %v", + target.Type(), kpath(key, keytail), err) + } +} + +func parseBool(key, keytail string, values []string, target reflect.Value) { + primitive("bool", key, keytail, values) + + switch values[0] { + case "true", "1", "on": + target.SetBool(true) + case "false", "0", "": + target.SetBool(false) + default: + perr("could not parse key %q as bool", kpath(key, keytail)) + } +} + +func parseInt(key, keytail string, values []string, target reflect.Value) { + primitive("int", key, keytail, values) + + t := target.Type() + i, err := strconv.ParseInt(values[0], 10, t.Bits()) + if err != nil { + perr("error parsing key %q as int: %v", kpath(key, keytail), + err) + } + target.SetInt(i) +} + +func parseUint(key, keytail string, values []string, target reflect.Value) { + primitive("uint", key, keytail, values) + + t := target.Type() + i, err := strconv.ParseUint(values[0], 10, t.Bits()) + if err != nil { + perr("error parsing key %q as uint: %v", kpath(key, keytail), + err) + } + target.SetUint(i) +} + +func parseFloat(key, keytail string, values []string, target reflect.Value) { + primitive("float", key, keytail, values) + + t := target.Type() + f, err := strconv.ParseFloat(values[0], t.Bits()) + if err != nil { + perr("error parsing key %q as float: %v", kpath(key, keytail), + err) + } + + target.SetFloat(f) +} + +func parseString(key, keytail string, values []string, target reflect.Value) { + primitive("string", key, keytail, values) + + target.SetString(values[0]) +} + +func parseSlice(key, keytail string, values []string, target reflect.Value) { + // BUG(carl): We currently do not handle slices of nested types. If + // support is needed, the implementation probably could be fleshed out. + if keytail != "[]" { + perr("unexpected array nesting for key %q: %q", + kpath(key, keytail), keytail) + } + t := target.Type() + + slice := reflect.MakeSlice(t, len(values), len(values)) + kp := kpath(key, keytail) + for i, _ := range values { + // We actually cheat a little bit and modify the key so we can + // generate better debugging messages later + key := fmt.Sprintf("%s[%d]", kp, i) + parse(key, "", values[i:i+1], slice.Index(i)) + } + target.Set(slice) +} + +func parseMap(key, keytail string, values []string, target reflect.Value) { + t := target.Type() + mapkey, maptail := keyed(t, key, keytail) + + // BUG(carl): We don't support any map keys except strings, although + // there's no reason we shouldn't be able to throw the value through our + // unparsing stack. + var mk reflect.Value + if t.Key().Kind() == reflect.String { + mk = reflect.ValueOf(mapkey).Convert(t.Key()) + } else { + pebkac("key for map %v isn't a string (it's a %v).", t, t.Key()) + } + + if target.IsNil() { + target.Set(reflect.MakeMap(t)) + } + + val := target.MapIndex(mk) + if !val.IsValid() || !val.CanSet() { + // It's a teensy bit annoying that the value returned by + // MapIndex isn't Set()table if the key exists. + val = reflect.New(t.Elem()).Elem() + } + parse(key, maptail, values, val) + target.SetMapIndex(mk, val) +} + +func parseStruct(key, keytail string, values []string, target reflect.Value) { + t := target.Type() + sk, skt := keyed(t, key, keytail) + cache := cacheStruct(t) + + parseStructField(cache, key, sk, skt, values, target) +} + +func parsePtr(key, keytail string, values []string, target reflect.Value) { + t := target.Type() + + if target.IsNil() { + target.Set(reflect.New(t.Elem())) + } + parse(key, keytail, values, target.Elem()) +} diff --git a/param/pebkac_test.go b/param/pebkac_test.go new file mode 100644 index 0000000..71d64eb --- /dev/null +++ b/param/pebkac_test.go @@ -0,0 +1,58 @@ +package param + +import ( + "net/url" + "strings" + "testing" +) + +type Bad struct { + Unknown interface{} +} + +type Bad2 struct { + Unknown *interface{} +} + +type Bad3 struct { + BadMap map[int]int +} + +// These tests are not parallel so we can frob pebkac behavior in an isolated +// way + +func assertPebkac(t *testing.T, err error) { + if err == nil { + t.Error("Expected PEBKAC error message") + } else if !strings.HasSuffix(err.Error(), yourFault) { + t.Errorf("Expected PEBKAC error, but got: %v", err) + } +} + +func TestBadInputs(t *testing.T) { + pebkacTesting = true + + err := Parse(url.Values{"Unknown": {"4"}}, Bad{}) + assertPebkac(t, err) + + b := &Bad{} + err = Parse(url.Values{"Unknown": {"4"}}, &b) + assertPebkac(t, err) + + pebkacTesting = false +} + +func TestBadTypes(t *testing.T) { + pebkacTesting = true + + err := Parse(url.Values{"Unknown": {"4"}}, &Bad{}) + assertPebkac(t, err) + + err = Parse(url.Values{"Unknown": {"4"}}, &Bad2{}) + assertPebkac(t, err) + + err = Parse(url.Values{"BadMap[llama]": {"4"}}, &Bad3{}) + assertPebkac(t, err) + + pebkacTesting = false +} diff --git a/param/struct.go b/param/struct.go new file mode 100644 index 0000000..314d5b0 --- /dev/null +++ b/param/struct.go @@ -0,0 +1,117 @@ +package param + +import ( + "reflect" + "strings" + "sync" +) + +// We decode a lot of structs (since it's the top-level thing this library +// decodes) and it takes a fair bit of work to reflect upon the struct to figure +// out what we want to do. Instead of doing this on every invocation, we cache +// metadata about each struct the first time we see it. The upshot is that we +// save some work every time. The downside is we are forced to briefly acquire +// a lock to access the cache in a thread-safe way. If this ever becomes a +// bottleneck, both the lock and the cache can be sharded or something. +type structCache map[string]cacheLine +type cacheLine struct { + offset int + parse func(string, string, []string, reflect.Value) +} + +var cacheLock sync.RWMutex +var cache = make(map[reflect.Type]structCache) + +func cacheStruct(t reflect.Type) structCache { + cacheLock.RLock() + sc, ok := cache[t] + cacheLock.RUnlock() + + if ok { + return sc + } + + // It's okay if two people build struct caches simultaneously + sc = make(structCache) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + // Only unexported fields have a PkgPath; we want to only cache + // exported fields. + if sf.PkgPath != "" { + continue + } + name := extractName(sf) + if name != "-" { + sc[name] = cacheLine{i, extractHandler(t, sf)} + } + } + + cacheLock.Lock() + cache[t] = sc + cacheLock.Unlock() + + return sc +} + +// Extract the name of the given struct field, looking at struct tags as +// appropriate. +func extractName(sf reflect.StructField) string { + name := sf.Tag.Get("param") + if name == "" { + name = sf.Tag.Get("json") + idx := strings.IndexRune(name, ',') + if idx >= 0 { + name = name[:idx] + } + } + if name == "" { + name = sf.Name + } + + return name +} + +func extractHandler(s reflect.Type, sf reflect.StructField) func(string, string, []string, reflect.Value) { + if reflect.PtrTo(sf.Type).Implements(textUnmarshalerType) { + return parseTextUnmarshaler + } + + switch sf.Type.Kind() { + case reflect.Bool: + return parseBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return parseInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return parseUint + case reflect.Float32, reflect.Float64: + return parseFloat + case reflect.Map: + return parseMap + case reflect.Ptr: + return parsePtr + case reflect.Slice: + return parseSlice + case reflect.String: + return parseString + case reflect.Struct: + return parseStruct + + default: + pebkac("struct %v has illegal field %q (type %v, kind %v).", + s, sf.Name, sf.Type, sf.Type.Kind()) + return nil + } +} + +// We have to parse two types of structs: ones at the top level, whose keys +// don't have square brackets around them, and nested structs, which do. +func parseStructField(cache structCache, key, sk, keytail string, values []string, target reflect.Value) { + l, ok := cache[sk] + if !ok { + perr("unknown key %q for struct at key %q", sk, + kpath(key, keytail)) + } + f := target.Field(l.offset) + + l.parse(key, keytail, values, f) +} diff --git a/param/struct_test.go b/param/struct_test.go new file mode 100644 index 0000000..ecba3e2 --- /dev/null +++ b/param/struct_test.go @@ -0,0 +1,106 @@ +package param + +import ( + "reflect" + "testing" +) + +type Fruity struct { + A bool + B int `json:"banana"` + C uint `param:"cherry"` + D float64 `json:"durian" param:"dragonfruit"` + E int `json:"elderberry" param:"-"` + F map[string]int `json:"-" param:"fig"` + G *int `json:"grape,omitempty"` + H []int `param:"honeydew" json:"huckleberry"` + I string `foobar:"iyokan"` + J Cheesy `param:"jackfruit" cheese:"jarlsberg"` +} + +type Cheesy struct { + A int `param:"affinois"` + B int `param:"brie"` + C int `param:"camembert"` + D int `param:"delice d'argental"` +} + +type Private struct { + Public, private int +} + +var fruityType = reflect.TypeOf(Fruity{}) +var cheesyType = reflect.TypeOf(Cheesy{}) +var privateType = reflect.TypeOf(Private{}) + +var fruityNames = []string{ + "A", "banana", "cherry", "dragonfruit", "-", "fig", "grape", "honeydew", + "I", "jackfruit", +} + +var fruityCache = map[string]cacheLine{ + "A": {0, parseBool}, + "banana": {1, parseInt}, + "cherry": {2, parseUint}, + "dragonfruit": {3, parseFloat}, + "fig": {5, parseMap}, + "grape": {6, parsePtr}, + "honeydew": {7, parseSlice}, + "I": {8, parseString}, + "jackfruit": {9, parseStruct}, +} + +func assertEqual(t *testing.T, what string, e, a interface{}) { + if !reflect.DeepEqual(e, a) { + t.Errorf("Expected %s to be %v, was actually %v", what, e, a) + } +} + +func TestNames(t *testing.T) { + t.Parallel() + + for i, val := range fruityNames { + name := extractName(fruityType.Field(i)) + assertEqual(t, "tag", val, name) + } +} + +func TestCacheStruct(t *testing.T) { + t.Parallel() + + sc := cacheStruct(fruityType) + + if len(sc) != len(fruityCache) { + t.Errorf("Cache has %d keys, but expected %d", len(sc), + len(fruityCache)) + } + + for k, v := range fruityCache { + sck, ok := sc[k] + if !ok { + t.Errorf("Could not find key %q in cache", k) + continue + } + if sck.offset != v.offset { + t.Errorf("Cache for %q: expected offset %d but got %d", + k, sck.offset, v.offset) + } + // We want to compare function pointer equality, and this + // appears to be the only way + a := reflect.ValueOf(sck.parse) + b := reflect.ValueOf(v.parse) + if a.Pointer() != b.Pointer() { + t.Errorf("Parse mismatch for %q: %v, expected %v", k, a, + b) + } + } +} + +func TestPrivate(t *testing.T) { + t.Parallel() + + sc := cacheStruct(privateType) + if len(sc) != 1 { + t.Error("Expected Private{} to have one cachable field") + } +}