diff --git a/decoder.go b/decoder.go index 2df2810..bdba43b 100644 --- a/decoder.go +++ b/decoder.go @@ -2,10 +2,13 @@ package maxminddb import ( "encoding/binary" + "fmt" "math" "math/big" "reflect" "sync" + + "github.com/maxmind/mmdbwriter/mmdbtype" ) type decoder struct { @@ -40,6 +43,8 @@ const ( maximumDataStructureDepth = 512 ) +var mmdbtypeDataType = reflect.TypeOf([]mmdbtype.DataType{}).Elem() + func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, error) { if depth > maximumDataStructureDepth { return 0, newInvalidDatabaseError("exceeded maximum data structure depth; database is likely corrupt") @@ -168,7 +173,11 @@ func (d *decoder) unmarshalBool(size, offset uint, result reflect.Value) (uint, result.SetBool(value) return newOffset, nil case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.Bool(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -216,7 +225,11 @@ func (d *decoder) unmarshalBytes(size, offset uint, result reflect.Value) (uint, return newOffset, nil } case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.Bytes(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -235,7 +248,11 @@ func (d *decoder) unmarshalFloat32(size, offset uint, result reflect.Value) (uin result.SetFloat(float64(value)) return newOffset, nil case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.Float32(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -257,7 +274,11 @@ func (d *decoder) unmarshalFloat64(size, offset uint, result reflect.Value) (uin result.SetFloat(value) return newOffset, nil case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.Float64(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -285,7 +306,11 @@ func (d *decoder) unmarshalInt32(size, offset uint, result reflect.Value) (uint, return newOffset, nil } case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.Int32(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -308,13 +333,20 @@ func (d *decoder) unmarshalMap( case reflect.Map: return d.decodeMap(size, offset, result, depth) case reflect.Interface: - if result.NumMethod() == 0 { - rv := reflect.ValueOf(make(map[string]interface{}, size)) - newOffset, err := d.decodeMap(size, offset, rv, depth) - result.Set(rv) - return newOffset, err + var v interface{} + switch { + case result.Type() == mmdbtypeDataType: + v = make(mmdbtype.Map, size) + case result.NumMethod() == 0: + v = make(map[string]interface{}, size) + default: + return 0, newUnmarshalTypeError("map", result.Type()) } - return 0, newUnmarshalTypeError("map", result.Type()) + + rv := reflect.ValueOf(v) + newOffset, err := d.decodeMap(size, offset, rv, depth) + result.Set(rv) + return newOffset, err } } @@ -337,13 +369,20 @@ func (d *decoder) unmarshalSlice( case reflect.Slice: return d.decodeSlice(size, offset, result, depth) case reflect.Interface: - if result.NumMethod() == 0 { + var rv reflect.Value + switch { + case result.Type() == mmdbtypeDataType: + a := mmdbtype.Slice{} + rv = reflect.ValueOf(&a).Elem() + case result.NumMethod() == 0: a := []interface{}{} - rv := reflect.ValueOf(&a).Elem() - newOffset, err := d.decodeSlice(size, offset, rv, depth) - result.Set(rv) - return newOffset, err + rv = reflect.ValueOf(&a).Elem() + default: + return 0, newUnmarshalTypeError("map", result.Type()) } + newOffset, err := d.decodeSlice(size, offset, rv, depth) + result.Set(rv) + return newOffset, err } return 0, newUnmarshalTypeError("array", result.Type()) } @@ -356,7 +395,11 @@ func (d *decoder) unmarshalString(size, offset uint, result reflect.Value) (uint result.SetString(value) return newOffset, nil case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + result.Set(reflect.ValueOf(mmdbtype.String(value))) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -384,7 +427,20 @@ func (d *decoder) unmarshalUint(size, offset uint, result reflect.Value, uintTyp return newOffset, nil } case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + switch uintType { + case 16: + result.Set(reflect.ValueOf(mmdbtype.Uint16(value))) + case 32: + result.Set(reflect.ValueOf(mmdbtype.Uint32(value))) + case 64: + result.Set(reflect.ValueOf(mmdbtype.Uint64(value))) + default: + return 0, fmt.Errorf("unknown uint type: %d", uintType) + } + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } @@ -407,7 +463,12 @@ func (d *decoder) unmarshalUint128(size, offset uint, result reflect.Value) (uin return newOffset, nil } case reflect.Interface: - if result.NumMethod() == 0 { + switch { + case result.Type() == mmdbtypeDataType: + v := mmdbtype.Uint128(*value) + result.Set(reflect.ValueOf(&v)) + return newOffset, nil + case result.NumMethod() == 0: result.Set(reflect.ValueOf(value)) return newOffset, nil } diff --git a/go.mod b/go.mod index 4918a8c..ec84fe6 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/oschwald/maxminddb-golang go 1.9 require ( + github.com/maxmind/mmdbwriter v0.0.0-20200813154840-0d62f9dbc5b4 github.com/stretchr/testify v1.6.1 golang.org/x/sys v0.0.0-20191224085550-c709ea063b76 ) diff --git a/go.sum b/go.sum index 45bf757..6b546e3 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,19 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/maxmind/mmdbwriter v0.0.0-20200813154840-0d62f9dbc5b4 h1:vXDTX6JzwqQnTSrsICy37iz3/Wq+ShfJzBfupqdS0R4= +github.com/maxmind/mmdbwriter v0.0.0-20200813154840-0d62f9dbc5b4/go.mod h1:CC0q++Jsqiets4nVeAx6KaQ5J5fkmeYsECtHF4v0nFo= +github.com/oschwald/maxminddb-golang v1.7.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw= -github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/sys v0.0.0-20191224085550-c709ea063b76 h1:Dho5nD6R3PcW2SH1or8vS0dszDaXRxIw55lBX7XiE5g= golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/reader_test.go b/reader_test.go index 863462d..762cd7b 100644 --- a/reader_test.go +++ b/reader_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/maxmind/mmdbwriter/mmdbtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -296,6 +297,48 @@ func TestDecoder(t *testing.T) { assert.NoError(t, reader.Close()) } +func TestDecoderToDataType(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + + var result mmdbtype.DataType + require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + + bigInt := big.Int{} + bigInt.SetString("1329227995784915872903807060280344576", 10) + mmdbBigInt := mmdbtype.Uint128(bigInt) + assert.Equal( + t, + mmdbtype.Map{ + "array": mmdbtype.Slice{ + mmdbtype.Uint32(1), + mmdbtype.Uint32(2), + mmdbtype.Uint32(3), + }, + "boolean": mmdbtype.Bool(true), + "bytes": mmdbtype.Bytes{0x0, 0x0, 0x0, 0x2a}, + "double": mmdbtype.Float64(42.123456), + "float": mmdbtype.Float32(1.1), + "int32": mmdbtype.Int32(-268435456), + "map": mmdbtype.Map{"mapX": mmdbtype.Map{ + "arrayX": mmdbtype.Slice{ + mmdbtype.Uint32(7), + mmdbtype.Uint32(8), + mmdbtype.Uint32(9), + }, + "utf8_stringX": mmdbtype.String("hello"), + }}, + "uint128": &mmdbBigInt, + "uint16": mmdbtype.Uint16(0x64), + "uint32": mmdbtype.Uint32(0x10000000), + "uint64": mmdbtype.Uint64(0x1000000000000000), + "utf8_string": mmdbtype.String("unicode! ☯ - ♫"), + }, + + result, + ) +} + type TestInterface interface { method() bool } diff --git a/test-data b/test-data index c46c33c..c4c2805 160000 --- a/test-data +++ b/test-data @@ -1 +1 @@ -Subproject commit c46c33c3c598c648013e2aa7458f8492f4ecfcce +Subproject commit c4c280500277981b944154acab929c667beff67d diff --git a/traverse.go b/traverse.go index 21f9c67..03a831a 100644 --- a/traverse.go +++ b/traverse.go @@ -22,6 +22,8 @@ type Networks struct { var ( allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)} allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)} + + ipv4Subtree = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(96, 128)} ) // Networks returns an iterator that can be used to traverse all networks in @@ -81,6 +83,20 @@ func (n *Networks) Next() bool { n.nodes = n.nodes[:len(n.nodes)-1] for node.pointer != n.reader.Metadata.NodeCount { + // XXX - this is just a proof-of concept hack. This should probably + // be made an option so that we don't break people's code + // + // The intent is to not traverse IPv4 aliases without hardcoding + // the networks that the writer currently aliases. + // + // Also, if we do this, we should adjust the IPNets for the IPv4 + // subtree so that they are less surprising (e.g., make them proper + // IPv4 network) + if n.reader.ipv4Start != 0 && node.pointer == n.reader.ipv4Start && + !ipv4Subtree.Contains(node.ip) { + break + } + if node.pointer > n.reader.Metadata.NodeCount { n.lastNode = node return true @@ -113,10 +129,12 @@ func (n *Networks) Next() bool { // Network returns the current network or an error if there is a problem // decoding the data for the network. It takes a pointer to a result value to -// decode the network's data into. +// decode the network's data into. If result is nil, decoding will be skipped. func (n *Networks) Network(result interface{}) (*net.IPNet, error) { - if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { - return nil, err + if result != nil { + if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { + return nil, err + } } return &net.IPNet{ @@ -125,6 +143,12 @@ func (n *Networks) Network(result interface{}) (*net.IPNet, error) { }, nil } +// Offset returns the offset of the record for the current network or an error +// if there is a problem. +func (n *Networks) Offset() (uintptr, error) { + return n.reader.resolveDataPointer(n.lastNode.pointer) +} + // Err returns an error, if any, that was encountered during iteration. func (n *Networks) Err() error { return n.err