Skip to content

Commit 52f6238

Browse files
authored
Merge pull request #74 from oschwald/greg/custom-deserializer
Support custom deserializer
2 parents a1069d8 + 1f1e288 commit 52f6238

6 files changed

+294
-8
lines changed

decoder.go

+135
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,26 @@ func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, er
5656
return d.decodeFromType(typeNum, size, newOffset, result, depth+1)
5757
}
5858

59+
func (d *decoder) decodeToDeserializer(offset uint, dser deserializer, depth int) (uint, error) {
60+
if depth > maximumDataStructureDepth {
61+
return 0, newInvalidDatabaseError("exceeded maximum data structure depth; database is likely corrupt")
62+
}
63+
typeNum, size, newOffset, err := d.decodeCtrlData(offset)
64+
if err != nil {
65+
return 0, err
66+
}
67+
68+
skip, err := dser.ShouldSkip(uintptr(offset))
69+
if err != nil {
70+
return 0, err
71+
}
72+
if skip {
73+
return d.nextValueOffset(offset, 1)
74+
}
75+
76+
return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1)
77+
}
78+
5979
func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) {
6080
newOffset := offset + 1
6181
if offset >= uint(len(d.buffer)) {
@@ -157,6 +177,68 @@ func (d *decoder) decodeFromType(
157177
}
158178
}
159179

180+
func (d *decoder) decodeFromTypeToDeserializer(
181+
dtype dataType,
182+
size uint,
183+
offset uint,
184+
dser deserializer,
185+
depth int,
186+
) (uint, error) {
187+
// For these types, size has a special meaning
188+
switch dtype {
189+
case _Bool:
190+
v, offset := d.decodeBool(size, offset)
191+
return offset, dser.Bool(v)
192+
case _Map:
193+
return d.decodeMapToDeserializer(size, offset, dser, depth)
194+
case _Pointer:
195+
pointer, newOffset, err := d.decodePointer(size, offset)
196+
if err != nil {
197+
return 0, err
198+
}
199+
_, err = d.decodeToDeserializer(pointer, dser, depth)
200+
return newOffset, err
201+
case _Slice:
202+
return d.decodeSliceToDeserializer(size, offset, dser, depth)
203+
}
204+
205+
// For the remaining types, size is the byte size
206+
if offset+size > uint(len(d.buffer)) {
207+
return 0, newOffsetError()
208+
}
209+
switch dtype {
210+
case _Bytes:
211+
v, offset := d.decodeBytes(size, offset)
212+
return offset, dser.Bytes(v)
213+
case _Float32:
214+
v, offset := d.decodeFloat32(size, offset)
215+
return offset, dser.Float32(v)
216+
case _Float64:
217+
v, offset := d.decodeFloat64(size, offset)
218+
return offset, dser.Float64(v)
219+
case _Int32:
220+
v, offset := d.decodeInt(size, offset)
221+
return offset, dser.Int32(int32(v))
222+
case _String:
223+
v, offset := d.decodeString(size, offset)
224+
return offset, dser.String(v)
225+
case _Uint16:
226+
v, offset := d.decodeUint(size, offset)
227+
return offset, dser.Uint16(uint16(v))
228+
case _Uint32:
229+
v, offset := d.decodeUint(size, offset)
230+
return offset, dser.Uint32(uint32(v))
231+
case _Uint64:
232+
v, offset := d.decodeUint(size, offset)
233+
return offset, dser.Uint64(v)
234+
case _Uint128:
235+
v, offset := d.decodeUint128(size, offset)
236+
return offset, dser.Uint128(v)
237+
default:
238+
return 0, newInvalidDatabaseError("unknown type: %d", dtype)
239+
}
240+
}
241+
160242
func (d *decoder) unmarshalBool(size, offset uint, result reflect.Value) (uint, error) {
161243
if size > 1 {
162244
return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (bool size of %v)", size)
@@ -199,6 +281,7 @@ func (d *decoder) indirect(result reflect.Value) reflect.Value {
199281
if result.IsNil() {
200282
result.Set(reflect.New(result.Type().Elem()))
201283
}
284+
202285
result = result.Elem()
203286
}
204287
return result
@@ -486,6 +569,35 @@ func (d *decoder) decodeMap(
486569
return offset, nil
487570
}
488571

572+
func (d *decoder) decodeMapToDeserializer(
573+
size uint,
574+
offset uint,
575+
dser deserializer,
576+
depth int,
577+
) (uint, error) {
578+
err := dser.StartMap(size)
579+
if err != nil {
580+
return 0, err
581+
}
582+
for i := uint(0); i < size; i++ {
583+
// TODO - implement key/value skipping?
584+
offset, err = d.decodeToDeserializer(offset, dser, depth)
585+
if err != nil {
586+
return 0, err
587+
}
588+
589+
offset, err = d.decodeToDeserializer(offset, dser, depth)
590+
if err != nil {
591+
return 0, err
592+
}
593+
}
594+
err = dser.End()
595+
if err != nil {
596+
return 0, err
597+
}
598+
return offset, nil
599+
}
600+
489601
func (d *decoder) decodePointer(
490602
size uint,
491603
offset uint,
@@ -538,6 +650,29 @@ func (d *decoder) decodeSlice(
538650
return offset, nil
539651
}
540652

653+
func (d *decoder) decodeSliceToDeserializer(
654+
size uint,
655+
offset uint,
656+
dser deserializer,
657+
depth int,
658+
) (uint, error) {
659+
err := dser.StartSlice(size)
660+
if err != nil {
661+
return 0, err
662+
}
663+
for i := uint(0); i < size; i++ {
664+
offset, err = d.decodeToDeserializer(offset, dser, depth)
665+
if err != nil {
666+
return 0, err
667+
}
668+
}
669+
err = dser.End()
670+
if err != nil {
671+
return 0, err
672+
}
673+
return offset, nil
674+
}
675+
541676
func (d *decoder) decodeString(size, offset uint) (string, uint) {
542677
newOffset := offset + size
543678
return string(d.buffer[offset:newOffset]), newOffset

deserializer.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package maxminddb
2+
3+
import "math/big"
4+
5+
// deserializer is an interface for a type that deserializes an MaxMind DB
6+
// data record to some other type. This exists as an alternative to the
7+
// standard reflection API.
8+
//
9+
// This is fundamentally different than the Unmarshaler interface that
10+
// several packages provide. A Deserializer will generally create the
11+
// final struct or value rather than unmarshaling to itself.
12+
//
13+
// This interface and the associated unmarshaling code is EXPERIMENTAL!
14+
// It is not currently covered by any Semantic Versioning guarantees.
15+
// Use at your own risk.
16+
type deserializer interface {
17+
ShouldSkip(offset uintptr) (bool, error)
18+
StartSlice(size uint) error
19+
StartMap(size uint) error
20+
End() error
21+
String(string) error
22+
Float64(float64) error
23+
Bytes([]byte) error
24+
Uint16(uint16) error
25+
Uint32(uint32) error
26+
Int32(int32) error
27+
Uint64(uint64) error
28+
Uint128(*big.Int) error
29+
Bool(bool) error
30+
Float32(float32) error
31+
}

deserializer.go deserializer_test.go

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package maxminddb
2+
3+
import (
4+
"math/big"
5+
"net"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestDecodingToDeserializer(t *testing.T) {
12+
reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
13+
require.NoError(t, err, "unexpected error while opening database: %v", err)
14+
15+
dser := testDeserializer{}
16+
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser)
17+
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
18+
19+
checkDecodingToInterface(t, dser.rv)
20+
}
21+
22+
type stackValue struct {
23+
value interface{}
24+
curNum int
25+
}
26+
27+
type testDeserializer struct {
28+
stack []*stackValue
29+
rv interface{}
30+
key *string
31+
}
32+
33+
func (d *testDeserializer) ShouldSkip(offset uintptr) (bool, error) {
34+
return false, nil
35+
}
36+
37+
func (d *testDeserializer) StartSlice(size uint) error {
38+
return d.add(make([]interface{}, size))
39+
}
40+
41+
func (d *testDeserializer) StartMap(size uint) error {
42+
return d.add(map[string]interface{}{})
43+
}
44+
45+
func (d *testDeserializer) End() error {
46+
d.stack = d.stack[:len(d.stack)-1]
47+
return nil
48+
}
49+
50+
func (d *testDeserializer) String(v string) error {
51+
return d.add(v)
52+
}
53+
54+
func (d *testDeserializer) Float64(v float64) error {
55+
return d.add(v)
56+
}
57+
58+
func (d *testDeserializer) Bytes(v []byte) error {
59+
return d.add(v)
60+
}
61+
62+
func (d *testDeserializer) Uint16(v uint16) error {
63+
return d.add(uint64(v))
64+
}
65+
66+
func (d *testDeserializer) Uint32(v uint32) error {
67+
return d.add(uint64(v))
68+
}
69+
70+
func (d *testDeserializer) Int32(v int32) error {
71+
return d.add(int(v))
72+
}
73+
74+
func (d *testDeserializer) Uint64(v uint64) error {
75+
return d.add(v)
76+
}
77+
78+
func (d *testDeserializer) Uint128(v *big.Int) error {
79+
return d.add(v)
80+
}
81+
82+
func (d *testDeserializer) Bool(v bool) error {
83+
return d.add(v)
84+
}
85+
86+
func (d *testDeserializer) Float32(v float32) error {
87+
return d.add(v)
88+
}
89+
90+
func (d *testDeserializer) add(v interface{}) error {
91+
if len(d.stack) == 0 {
92+
d.rv = v
93+
} else {
94+
top := d.stack[len(d.stack)-1]
95+
switch parent := top.value.(type) {
96+
case map[string]interface{}:
97+
if d.key == nil {
98+
key := v.(string)
99+
d.key = &key
100+
} else {
101+
parent[*d.key] = v
102+
d.key = nil
103+
}
104+
105+
case []interface{}:
106+
parent[top.curNum] = v
107+
top.curNum++
108+
default:
109+
}
110+
}
111+
112+
switch v := v.(type) {
113+
case map[string]interface{}, []interface{}:
114+
d.stack = append(d.stack, &stackValue{value: v})
115+
default:
116+
}
117+
118+
return nil
119+
}

go.sum

-8
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
44
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
55
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
66
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
7-
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
8-
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
9-
github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw=
10-
github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
11-
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
12-
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
137
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
148
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
159
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76 h1:Dho5nD6R3PcW2SH1or8vS0dszDaXRxIw55lBX7XiE5g=
1610
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
1711
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
1812
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
19-
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
20-
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
2113
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
2214
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

reader.go

+5
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ func (r *Reader) decode(offset uintptr, result interface{}) error {
227227
return errors.New("result param must be a pointer")
228228
}
229229

230+
if dser, ok := result.(deserializer); ok {
231+
_, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0)
232+
return err
233+
}
234+
230235
_, err := r.decoder.decode(uint(offset), rv, 0)
231236
return err
232237
}

reader_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ func TestDecodingToInterface(t *testing.T) {
203203
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface)
204204
require.NoError(t, err, "unexpected error while doing lookup: %v", err)
205205

206+
checkDecodingToInterface(t, recordInterface)
207+
}
208+
209+
func checkDecodingToInterface(t *testing.T, recordInterface interface{}) {
206210
record := recordInterface.(map[string]interface{})
207211
assert.Equal(t, []interface{}{uint64(1), uint64(2), uint64(3)}, record["array"])
208212
assert.Equal(t, true, record["boolean"])

0 commit comments

Comments
 (0)